Last active
December 29, 2022 18:06
-
-
Save seandstewart/477998f8675f81e33d122a922bc678d1 to your computer and use it in GitHub Desktop.
A base repository in Python using sqlite3 and pypika. Bugs and optimizations are left as an exercise for the user.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import contextlib | |
import sqlite3 | |
from typing import Iterator, TypeVar, Generic | |
import pypika | |
# Type variable standing in for the model class a concrete repository handles.
_T = TypeVar("_T")


class BaseSQLiteRepository(Generic[_T]):
    """Base class for single-table repositories backed by ``sqlite3`` and ``pypika``.

    Subclasses must implement :meth:`deserialize` and :meth:`serialize`; both
    raise ``NotImplementedError`` here. Queries are built with ``pypika``,
    rendered with ``str()`` and executed on a ``sqlite3`` connection.

    Every query method accepts an optional ``connection``. When one is given
    it is used as-is and is neither committed nor closed by the repository.
    When omitted, a connection is opened for that single call, committed on
    success (rolled back on error) and then closed.
    """

    __slots__ = ("url", "table")

    # Scheme prefix accepted on ``url`` and removed before calling sqlite3.
    _URL_SCHEME = "sqlite://"

    def __init__(
        self,
        *,
        tablename: str,
        schema: str | None = None,
        url: str = "sqlite://:memory:",
    ):
        """Store the database location and build the ``pypika`` table handle.

        Args:
            tablename: Name passed to ``pypika.Table``.
            schema: Optional schema passed to ``pypika.Table``.
            url: Database location. Either a value ``sqlite3.connect``
                understands directly (a path or ``":memory:"``) or the same
                value prefixed with ``sqlite://``.
        """
        self.url = url
        self.table = pypika.Table(name=tablename, schema=schema)

    def _database_path(self) -> str:
        """Return ``self.url`` without a leading ``sqlite://`` scheme.

        ``sqlite3.connect`` treats its argument as a path, so the scheme has
        to be removed before connecting; other values pass through unchanged.
        """
        url = self.url
        if url.startswith(self._URL_SCHEME):
            return url[len(self._URL_SCHEME):]
        return url

    @contextlib.contextmanager
    def connection(
        self, *, c: sqlite3.Connection | None = None
    ) -> Iterator[sqlite3.Connection]:
        """Yield ``c`` if given, otherwise a freshly opened connection.

        A caller-supplied connection is yielded untouched. A connection opened
        here gets ``sqlite3.Row`` as its row factory, is committed when the
        ``with`` body succeeds, rolled back when it raises, and always closed.
        """
        if c is not None:
            yield c
            return
        # ``with conn`` only commits/rolls back; ``closing`` is what actually
        # releases the connection afterwards.
        with contextlib.closing(sqlite3.connect(self._database_path())) as conn:
            conn.row_factory = sqlite3.Row
            with conn:
                yield conn

    def table_columns(self) -> tuple[pypika.terms.Field, ...]:
        """Return the columns to select/return; ``table.*`` by default."""
        return (self.table.star,)

    def deserialize(self, *, row: sqlite3.Row | None) -> _T | None:
        """Convert a fetched row (or ``None``) into a model. Must be overridden."""
        raise NotImplementedError()

    def serialize(self, *, model: _T) -> dict:
        """Convert a model into a dict. Must be overridden.

        ``create`` and ``update`` use the dict keys as column names and the
        dict values as the column values.
        """
        raise NotImplementedError()

    def fetchone(
        self,
        *,
        query: pypika.queries.QueryBuilder,
        connection: sqlite3.Connection | None = None,
    ) -> sqlite3.Row | None:
        """Execute ``str(query)`` and return the first row, or ``None`` if there is none."""
        with self.connection(c=connection) as conn:
            cursor: sqlite3.Cursor = conn.execute(str(query))
            return cursor.fetchone()

    def get(
        self, *, id: int, connection: sqlite3.Connection | None = None
    ) -> _T | None:
        """Select the row whose ``id`` column equals ``id`` and deserialize it."""
        query = (
            pypika.Query.from_(self.table)
            .select(*self.table_columns())
            .where(self.table.id == id)
        )
        row = self.fetchone(query=query, connection=connection)
        return self.deserialize(row=row)

    def create(
        self, *, instance: _T, connection: sqlite3.Connection | None = None
    ) -> _T | None:
        """Insert ``instance`` and return the deserialized inserted row.

        NOTE(review): ``returning()`` is assumed to be available on the
        builder produced by ``pypika.Query``, and the SQLite library must
        support the ``RETURNING`` clause - confirm both for the target setup.
        """
        data = self.serialize(model=instance)
        query = (
            pypika.Query.into(self.table)
            # Name the columns so the insert does not depend on the dict order
            # matching the table's column order (assumes keys are column names).
            .columns(*data.keys())
            .insert(tuple(data.values()))
            .returning(*self.table_columns())
        )
        row = self.fetchone(query=query, connection=connection)
        return self.deserialize(row=row)

    def update(
        self, *, instance: _T, connection: sqlite3.Connection | None = None
    ) -> _T | None:
        """Update the row matching ``instance.id`` and return it deserialized.

        Each key/value pair from ``serialize`` becomes one ``SET`` assignment.

        NOTE(review): ``returning()`` is assumed to be available on the
        builder produced by ``pypika.Query``, and the SQLite library must
        support the ``RETURNING`` clause - confirm both for the target setup.
        """
        data = self.serialize(model=instance)
        query = pypika.Query.update(self.table)
        for column, value in data.items():
            query = query.set(column, value)
        query = query.where(self.table.id == instance.id).returning(
            *self.table_columns()
        )
        row = self.fetchone(query=query, connection=connection)
        return self.deserialize(row=row)

    def delete(
        self, *, id: int, connection: sqlite3.Connection | None = None
    ) -> int:
        """Delete the row whose ``id`` column equals ``id``.

        Returns:
            The cursor's ``rowcount`` after executing the ``DELETE``.
        """
        query = (
            pypika.Query.from_(self.table)
            .delete()
            .where(self.table.id == id)
        )
        # A DELETE without RETURNING yields no rows, so read the affected-row
        # count from the cursor instead of fetching.
        with self.connection(c=connection) as conn:
            cursor: sqlite3.Cursor = conn.execute(str(query))
            return cursor.rowcount
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment