Skip to content

Instantly share code, notes, and snippets.

@Zomatree
Created May 2, 2023 02:43
Show Gist options
  • Save Zomatree/f22af7af93caab9176b673eae8acc189 to your computer and use it in GitHub Desktop.
Save Zomatree/f22af7af93caab9176b673eae8acc189 to your computer and use it in GitHub Desktop.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, Self, TypeVar, Annotated, TypeVarTuple, get_args, overload, get_origin, reveal_type, cast
import asyncpg
T = TypeVar("T")
T_T = TypeVar("T_T", bound="Table", covariant=True)
T_OT = TypeVar("T_OT", bound="Table", covariant=True)
T_Ts = TypeVarTuple("T_Ts")
if TYPE_CHECKING:
Connection = asyncpg.Connection[asyncpg.Record]
def eval_annotation(annot: Any, locals: dict[str, Any] | None = None, globals: dict[str, Any] | None = None) -> Any:
    """Resolve a possibly-stringified annotation to its runtime object.

    Non-string annotations are returned unchanged. Strings (as produced by
    ``from __future__ import annotations``) are evaluated with ``eval``,
    looking names up in ``locals`` first, then ``globals``. When ``globals``
    is omitted, this module's global namespace is used.

    The parameter names shadow the ``locals``/``globals`` builtins. They are
    kept as-is so keyword callers keep working.

    NOTE: this runs ``eval`` on the annotation text, so it must only be given
    annotations that come from trusted source code.
    """
    if not isinstance(annot, str):
        return annot
    # ``eval`` takes (expression, globals, locals) -- in that order. Falling
    # back to the module globals (instead of None) also keeps this function's
    # own local variables out of the evaluated expression's scope.
    namespace = eval_annotation.__globals__ if globals is None else globals
    return eval(annot, namespace, locals)
class _Missing:
    """Sentinel type meaning "no value supplied" (e.g. a column with no default).

    Equality is always false -- even against itself -- so the sentinel has to
    be detected with ``is``.
    """

    def __repr__(self) -> str:
        return "<Missing>"

    def __eq__(self, _other: Any) -> Literal[False]:
        # Never equal to anything; see the class docstring.
        return False


# The single shared sentinel instance.
Missing = _Missing()
class TableMetadata:
    """Bookkeeping for one table: its SQL name, its columns and a value dict."""

    def __init__(self, name: str, columns: list[Column[Any]]) -> None:
        self.name, self.columns = name, columns
        # Starts empty; ``Table.__init__`` assigns the row values.
        self.values: dict[str, Any] = dict()
class Column(Generic[T]):
    """Descriptor for one table column that also builds query predicates.

    Accessed on the class (``Customer.id``) it returns itself, so comparison
    operators produce ``WhereQuery`` objects. Accessed on an instance it
    returns that instance's value for the column.
    """

    # Defining ``__eq__`` would otherwise set ``__hash__`` to None; keep
    # identity hashing so columns can be used as dict keys / set members.
    __hash__ = object.__hash__

    def __init__(self, table: type[Table], name: str, datatype: type, default: Any):
        self.table = table
        self.name = name
        self.datatype = datatype
        self.default = default

    @overload
    def __get__(self, instance: None, _: type[Table]) -> Self:
        ...

    @overload
    def __get__(self, instance: Table, _: type[Table]) -> T:
        ...

    def __get__(self, instance: Table | None, _: type[Table]) -> T | Self:
        if instance is None:
            return self
        values = instance._metadata.values
        if self.name in values:
            return values[self.name]
        # Fall back to the declared default when one was configured.
        if self.default is not Missing:
            return self.default
        # Descriptors report "no such attribute" with AttributeError so that
        # hasattr() and getattr(obj, name, default) behave as expected.
        raise AttributeError(f"{type(instance).__name__!r} has no value for column {self.name!r}")

    def __eq__(self, value: T | Self) -> WhereQuery:  # type: ignore
        return WhereQuery(self, value, "=")

    def __ne__(self, value: T | Self) -> WhereQuery:  # type: ignore
        return WhereQuery(self, value, "!=")

    def __lt__(self, value: T | Self) -> WhereQuery:
        return WhereQuery(self, value, "<")

    def __le__(self, value: T | Self) -> WhereQuery:
        return WhereQuery(self, value, "<=")

    def __gt__(self, value: T | Self) -> WhereQuery:
        return WhereQuery(self, value, ">")

    def __ge__(self, value: T | Self) -> WhereQuery:
        return WhereQuery(self, value, ">=")
class ColumnBuilder:
    """Fluent builder for ``Column`` objects.

    Typically placed as ``Annotated`` metadata on a ``Table`` field.
    ``Table.__init_subclass__`` then sets the name, type and table and calls
    ``build()``.
    """

    def __init__(self) -> None:
        self._name: str | None = None
        self._type: type | None = None
        self._default: Any = Missing
        self._primary: bool = False
        self._foreign: Column[Any] | None = None
        self._table: type[Table] | None = None

    def name(self, name: str) -> Self:
        self._name = name
        return self

    def type(self, type: type) -> Self:
        self._type = type
        return self

    def default(self, default: Any) -> Self:
        self._default = default
        return self

    def primary(self) -> Self:
        self._primary = True
        return self

    def foreign(self, column: Column[Any]) -> Self:
        self._foreign = column
        return self

    def table(self, table: type[Table]) -> Self:
        self._table = table
        return self

    def build(self) -> Column[Any]:
        """Create the configured ``Column``.

        Raises:
            ValueError: if the name, type or table has not been set.
        """
        if not self._name:
            raise ValueError("No name")
        if self._type is None:
            raise ValueError("No type")
        if self._table is None:
            raise ValueError("No table")
        # NOTE(review): ``_primary`` and ``_foreign`` are recorded above but
        # ``Column`` has no field for them, so they are not forwarded here.
        return Column(self._table, self._name, self._type, self._default)
class QueryBuilder(Generic[T_T]):
    """Base class for SQL query builders bound to a ``Table`` subclass.

    Subclasses implement ``build()``. The async helpers run the built query
    on an asyncpg connection and map result rows back onto ``table``.
    """

    def __init__(self, table: type[T_T]) -> None:
        self.table = table

    def build(self) -> tuple[str, list[Any]]:
        """Return the SQL text and its positional (``$n``) parameters."""
        raise NotImplementedError

    async def execute(self, conn: Connection) -> int:
        """Run the query and return the affected-row count.

        ``conn.execute`` returns the server's command tag, e.g. ``"UPDATE 3"``
        or ``"INSERT 0 3"``. The row count is the last token; for INSERT the
        middle token is an OID. Tags without a count (``"CREATE TABLE"``)
        yield 0.
        """
        query, parameters = self.build()
        status = await conn.execute(query, *parameters)
        count = status.rpartition(" ")[2]
        return int(count) if count.isdigit() else 0

    async def fetch(self, conn: Connection) -> list[T_T]:
        """Run the query and build one ``table`` instance per returned row."""
        query, parameters = self.build()
        records = await conn.fetch(query, *parameters)
        return [self.table(**record) for record in records]

    async def fetchone(self, conn: Connection) -> T_T | None:
        """Run the query and build an instance from the first row, or return None."""
        query, parameters = self.build()
        record = await conn.fetchrow(query, *parameters)
        if record is None:
            return None
        return self.table(**record)
class SelectQueryBuilder(QueryBuilder[T_T]):
    """``select`` over a single table, with optional ``and``-combined predicates."""

    def __init__(self, table: type[T_T]) -> None:
        super().__init__(table)
        self._wheres: list[WhereQuery] = []

    def where(self, query: WhereQuery) -> Self:
        """Add a predicate (combined with ``and``) and return ``self`` for chaining."""
        self._wheres.append(query)
        return self

    def build(self) -> tuple[str, list[Any]]:
        metadata = self.table._metadata
        columns = ", ".join(column.name for column in metadata.columns)
        parameters: list[Any] = []
        predicates: list[str] = []
        for where in self._wheres:
            if isinstance(where.value, Column):
                # Column-to-column comparison: reference the other column
                # inline rather than binding the Column object as a parameter.
                rhs = f"{where.value.table._metadata.name}.{where.value.name}"
            else:
                parameters.append(where.value)
                # PostgreSQL placeholders are 1-based: $1, $2, ...
                rhs = f"${len(parameters)}"
            predicates.append(f"{where.column.name} {where.op} {rhs}")
        query = f"select {columns} from {metadata.name}"
        if predicates:
            query += f" where {' and '.join(predicates)}"
        return query, parameters

    def join(self, query: SelectQueryBuilder[T_OT]) -> JoinSelectQueryBuilder[T_T, T_OT]:
        """Inner-join another table's select; its predicates become the ``on`` clause."""
        return JoinSelectQueryBuilder(self, query)
class JoinSelectQueryBuilder(SelectQueryBuilder[T_T], Generic[T_T, *T_Ts]):
def __init__(self, select_query: SelectQueryBuilder[T_T], join: SelectQueryBuilder[Any]):
self._wheres = select_query._wheres
self.table = select_query.table
self.joins: list[SelectQueryBuilder[Table]] = [join]
def join(self, query: SelectQueryBuilder[T_OT]) -> JoinSelectQueryBuilder[T_T, *T_Ts, T_OT]:
self.joins.append(query)
return cast(JoinSelectQueryBuilder[T_T, *T_Ts, T_OT], self)
def build(self) -> tuple[str, list[str]]:
columns: list[str] = []
values: list[Any] = []
for table in [self.table] + [join.table for join in self.joins]:
for column in table._metadata.columns:
columns.append(f"{table._metadata.name}.{column.name} as {table._metadata.name}_{column.name}")
joins: list[str] = []
for join in self.joins:
wheres: list[str] = []
for where in join._wheres:
if isinstance(where.value, Column):
value = f"{where.value.table._metadata.name}.{where.value.name}"
else:
value = f"${len(values) + 1}"
values.append(where.value)
wheres.append(f"{where.column.table._metadata.name}.{where.column.name} {where.op} {value}")
joins.append(f"inner join {join.table._metadata.name} on {' and '.join(wheres)}")
wheres = []
for where in self._wheres:
if isinstance(where.value, Column):
value = f"{where.value.table._metadata.name}.{where.value.name}"
else:
value = f"${len(values) + 1}"
values.append(where.value)
wheres.append(f"{where.column.table._metadata.name}.{where.column.name} {where.op} {value}")
where_clause = f"where {' and '.join(wheres)}" if wheres else ""
query = f"select {','.join(columns)} from {self.table._metadata.name} {' '.join(joins)} {where_clause}"
return query, values
async def fetchone(self, conn: Connection) -> tuple[T_T, *T_Ts] | None:
query, parameters = self.build()
row = await conn.fetchrow(query, *parameters)
if row:
collections: dict[str, dict[str, Any]] = {}
for column, value in row.items():
table_name, *rest = column.split("_")
collections.setdefault(table_name, {})["_".join(rest)] = value
return cast(tuple[T_T, *T_Ts], [join.table(**collections[join.table._metadata.name]) for join in self.joins])
async def fetch(self, conn: Connection) -> list[tuple[T_T, *T_Ts]]:
query, parameters = self.build()
rows = await conn.fetch(query, *parameters)
output: list[tuple[Table, ...]] = []
for row in rows:
collections: dict[str, dict[str, Any]] = {}
for column, value in row.items():
table_name, *rest = column.split("_")
collections.setdefault(table_name, {})["_".join(rest)] = value
output.append(tuple(join.table(**collections[join.table._metadata.name]) for join in self.joins))
return cast(list[tuple[T_T, *T_Ts]], output)
class InsertQueryBuilder(QueryBuilder[T_T]):
    """Build a single-row ``insert`` statement for ``table``.

    When ``values`` (column name -> value) is given, only the columns present
    in it are inserted; PostgreSQL applies column defaults to the rest.
    Without ``values``, every column's value is read from an attribute of
    the builder itself. That fallback only works for subclasses that set
    such attributes.
    """

    def __init__(self, table: type[T_T], values: dict[str, Any] | None = None) -> None:
        super().__init__(table)
        self._values = values

    def build(self) -> tuple[str, list[Any]]:
        """Return the ``insert`` SQL and its parameters.

        Raises:
            ValueError: if ``values`` names unknown columns or is empty.
        """
        metadata = self.table._metadata
        if self._values is None:
            names = [column.name for column in metadata.columns]
            parameters = [getattr(self, name) for name in names]
        else:
            known = {column.name for column in metadata.columns}
            unknown = sorted(set(self._values) - known)
            if unknown:
                raise ValueError(f"unknown column(s) for {metadata.name}: {', '.join(unknown)}")
            # Keep the table's declared column order for a stable statement.
            names = [column.name for column in metadata.columns if column.name in self._values]
            if not names:
                raise ValueError(f"no values to insert into {metadata.name}")
            parameters = [self._values[name] for name in names]
        # PostgreSQL placeholders are 1-based: $1, $2, ...
        placeholders = ", ".join(f"${i}" for i in range(1, len(names) + 1))
        return f"insert into {metadata.name} ({', '.join(names)}) values ({placeholders})", parameters
class WhereQuery:
    """A single comparison predicate: ``<column> <op> <value>``.

    ``value`` is either a literal or another ``Column``; the query builders
    decide how each is rendered.
    """

    def __init__(self, column: Column[Any], value: Any, op: str):
        self.column, self.value, self.op = column, value, op
class Table:
    """Base class for models mapped to a database table.

    Subclasses declare columns as annotations. A plain annotation looks like
    ``name: Text`` (i.e. ``Column[str]``). An annotation with a
    ``ColumnBuilder`` as ``Annotated`` metadata looks like
    ``id: Annotated[Int, ColumnBuilder().primary()]``.

    ``__init_subclass__`` replaces each annotation with a ``Column``
    descriptor and stores the table's ``TableMetadata``. Pass ``table_name=``
    in the class statement to override the SQL table name, which defaults to
    the class name.
    """

    _metadata: ClassVar[TableMetadata]

    def __init_subclass__(cls, *, table_name: str | None = None) -> None:
        import sys

        # Resolve string annotations in the namespace of the module that
        # defines the subclass, so model modules can use their own aliases.
        module = sys.modules.get(cls.__module__)
        namespace = vars(module) if module is not None else None

        columns: list[Column[Any]] = []
        for key, raw_annotation in cls.__annotations__.items():
            annotation = eval_annotation(raw_annotation, globals=namespace)
            if get_origin(annotation) is ClassVar:
                # Class-level settings, not columns.
                continue
            if get_origin(annotation) is Annotated:
                column_annotation, *extras = get_args(annotation)
                # Use the first ColumnBuilder in the metadata; other metadata is ignored.
                builder = next((extra for extra in extras if isinstance(extra, ColumnBuilder)), ColumnBuilder())
            else:
                column_annotation, builder = annotation, ColumnBuilder()
            # ``column_annotation`` is ``Column[T]``; the column's datatype is ``T``.
            (datatype,) = get_args(eval_annotation(column_annotation, globals=namespace))
            column = builder.name(key).type(datatype).table(cls).build()
            columns.append(column)
            setattr(cls, key, column)
        cls._metadata = TableMetadata(table_name or cls.__name__, columns)

    def __init__(self, **kwargs: Any):
        # Give each instance its own TableMetadata (sharing the class's name and
        # column list) so column values are per-instance. Assigning to
        # ``self._metadata.values`` directly would mutate the class-level
        # metadata, and every instance of the model would share one value dict.
        class_metadata = type(self)._metadata
        self._metadata = TableMetadata(class_metadata.name, class_metadata.columns)  # type: ignore[misc]
        self._metadata.values = kwargs

    @classmethod
    def select(cls) -> SelectQueryBuilder[Self]:
        """Start a ``select`` over this table."""
        return SelectQueryBuilder(cls)

    @classmethod
    def where(cls, where: WhereQuery) -> SelectQueryBuilder[Self]:
        """Shorthand for ``cls.select().where(where)``."""
        return cls.select().where(where)
# Shorthand column annotations for model fields, e.g. ``name: Text``.
# For a plain ``name: Text`` annotation, Table.__init_subclass__ uses the type
# argument (``str`` / ``int``) as the column's datatype.
Text = Column[str]
Int = Column[int]
class Customer(Table, table_name="customers"):
    """Example model mapped to the ``customers`` table."""

    # Declared as primary via ColumnBuilder.primary().
    id: Annotated[Int, ColumnBuilder().primary()]
    name: Text
class Item(Table, table_name="items"):
    """Example model mapped to the ``items`` table."""

    # Declared as primary via ColumnBuilder.primary().
    id: Annotated[Int, ColumnBuilder().primary()]
    name: Text
class Order(Table, table_name="orders"):
    """Example model mapped to the ``orders`` table, linking a customer to an item."""

    id: Annotated[Int, ColumnBuilder().primary()]
    # Declared with ColumnBuilder.foreign(...) pointing at Customer.id / Item.id.
    # NOTE(review): ColumnBuilder.build does not forward the foreign reference
    # to Column, so it is not used when generating SQL.
    customer: Annotated[Int, ColumnBuilder().foreign(Customer.id)]
    item: Annotated[Int, ColumnBuilder().foreign(Item.id)]
async def main(db: asyncpg.Connection[asyncpg.Record]):
    """Demo: select orders joined with their customer and their item named "Chair"."""
    query = (
        Order.select()
        .join(Customer.where(Order.customer == Customer.id))
        # ``Item`` (the class), not ``Item()``: ``where`` is a classmethod, so
        # constructing a model instance here was accidental.
        .join(Item.where(Order.item == Item.id).where(Item.name == "Chair"))
    )
    # reveal_type is a type-checker aid; at runtime it prints the value's type
    # and returns the value unchanged.
    reveal_type(await query.fetchone(db))
    reveal_type(await query.fetch(db))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment