Skip to content

Instantly share code, notes, and snippets.

@alkemann
Created February 12, 2023 21:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alkemann/61b9bcfab5a98504dc678d1c04b71c51 to your computer and use it in GitHub Desktop.
Save alkemann/61b9bcfab5a98504dc678d1c04b71c51 to your computer and use it in GitHub Desktop.
Dataclass with db in basemodel
import sqlite3
from contextlib import contextmanager
from util import BaseModel
from dataclasses import dataclass
from util import get_db
import dataclasses
from typing import Iterable
import logging
# Module logger named after the module's import name (logging convention),
# so it joins the dotted logger hierarchy instead of being keyed by a file path.
log = logging.getLogger(__name__)
@dataclass
class Blog(BaseModel):
    """A blog record persisted in the ``blogs`` table through BaseModel."""

    __table__ = 'blogs'  # table name used when building SQL for this model

    # Declaration order is significant: fetched rows are turned back into
    # instances positionally, in the same order the columns are selected.
    id: int
    name: str
    body: str
@contextmanager
def get_db(file_name: str = "database.db"):
    """Open a SQLite connection to *file_name* and yield a cursor on it.

    The transaction is committed only when the ``with`` body finishes without
    raising; if it raises, the transaction is rolled back and the exception
    propagates. The cursor and the connection are always closed afterwards.

    Args:
        file_name: Path of the SQLite database file to connect to.

    Yields:
        sqlite3.Cursor: a cursor on the open connection.
    """
    connection = sqlite3.connect(file_name)
    try:
        cursor = connection.cursor()
        try:
            yield cursor
            # Only reached when the with-body completed normally.
            connection.commit()
        finally:
            cursor.close()
    except BaseException:
        # Never persist half-finished work. BaseException (not Exception) so
        # that KeyboardInterrupt or closing a consuming generator also discards
        # the open transaction.
        connection.rollback()
        raise
    finally:
        connection.close()
class BaseModel:
    """Mixin that adds simple SQLite CRUD helpers to a ``@dataclass`` model.

    Subclasses must be dataclasses (``dataclasses.fields`` is used to find the
    columns) and must define a ``__table__`` class attribute naming their
    table. ``__pk__`` names the primary-key column and defaults to ``'id'``.
    Columns are listed in field declaration order, and fetched rows are turned
    back into instances positionally (``cls(*row)``).

    NOTE(review): table, column and key names are interpolated directly into
    the SQL text (only values are bound as ``?`` parameters), so they must
    come from trusted class definitions, never from user input.
    """

    __pk__ = 'id'

    @classmethod
    def fields(cls) -> str:
        """Return the comma-separated column list, e.g. ``"id, name, body"``."""
        return ", ".join(f.name for f in dataclasses.fields(cls))

    @classmethod
    def fields_to_update(cls) -> str:
        """Return the ``SET`` clause body, e.g. ``"id = ?, name = ?, body = ?"``."""
        return ", ".join(f"{f.name} = ?" for f in dataclasses.fields(cls))

    @classmethod
    def placeholders(cls) -> str:
        """Return one ``?`` bind placeholder per field, e.g. ``"?, ?, ?"``."""
        return ", ".join("?" for _ in dataclasses.fields(cls))

    def _values(self) -> tuple:
        """Return this instance's field values in the same order as ``fields()``."""
        # Read through dataclasses.fields rather than self.__dict__ so the
        # order always matches the column list, and so extra attributes set on
        # the instance (or a slots=True dataclass) cannot corrupt the params.
        return tuple(getattr(self, f.name) for f in dataclasses.fields(self))

    @classmethod
    def getById(cls, id: int):
        """Return the one instance whose primary key equals *id*.

        Raises:
            LookupError: if no row matches. It subclasses ``Exception``, so
                ``except Exception`` handlers still catch it.
        """
        query = f"SELECT {cls.fields()} FROM {cls.__table__} WHERE {cls.__pk__} = ?"
        params = [id]
        log.debug(query)
        log.debug(params)
        with get_db() as db:
            row = db.execute(query, params).fetchone()
        if row is None:
            raise LookupError(f"Could not find {id} of type {cls.__name__}")
        return cls(*row)

    @classmethod
    def list(cls) -> Iterable:
        """Lazily yield one instance per row of the model's table."""
        query = f"SELECT {cls.fields()} FROM {cls.__table__}"
        log.debug(query)
        with get_db() as db:
            for row in db.execute(query):
                yield cls(*row)

    def update(self) -> bool:
        """Write every field of this instance back to its row.

        The row is located by the ``__pk__`` column. Returns True when exactly
        one row was updated.
        """
        # SET values first (in fields_to_update() order), then the WHERE key.
        params = [*self._values(), getattr(self, self.__pk__)]
        query = f"UPDATE {self.__table__} SET {self.fields_to_update()} WHERE {self.__pk__} = ?"
        log.debug(query)
        log.debug(params)
        with get_db() as db:
            result = db.execute(query, params)
            updated = result.rowcount == 1
        return updated

    def insert(self) -> bool:
        """Insert this instance as a new row; return True if one row was added."""
        # One placeholder per field, so models with any number of fields work.
        query = f"INSERT INTO {self.__table__} ({self.fields()}) VALUES ({self.placeholders()})"
        params = [*self._values()]
        log.debug(query)
        log.debug(params)
        with get_db() as db:
            result = db.execute(query, params)
            inserted = result.rowcount == 1
        return inserted

    def getHasMany(self, model, join_table: str, other_foreign_key: str, foreign_key: str) -> Iterable:
        """Yield the *model* instances related to this one through *join_table*.

        Selects *other_foreign_key* from the rows of *join_table* whose
        *foreign_key* equals this instance's primary key, then loads each
        related record with ``model.getById`` (one query per related id).
        """
        query = f"SELECT {other_foreign_key} FROM {join_table} WHERE {foreign_key} = ?"
        params = [getattr(self, self.__pk__)]
        log.debug(query)
        log.debug(params)
        with get_db() as db:
            related_ids = [row[0] for row in db.execute(query, params)]
        # Load the related records only after the join-table connection is
        # closed, so two connections are never held open at the same time.
        for related_id in related_ids:
            yield model.getById(related_id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment