Skip to content

Instantly share code, notes, and snippets.

@djrobstep
Last active January 28, 2022 17:32
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save djrobstep/4f4c04bf602c2da3549450f4914c7fc8 to your computer and use it in GitHub Desktop.
Save djrobstep/4f4c04bf602c2da3549450f4914c7fc8 to your computer and use it in GitHub Desktop.
orm.py: An ORM for the ORM haters
import inspect
import pprint
import textwrap
from collections.abc import Iterable
from dataclasses import asdict
from dataclasses import field as dataclass_field
from dataclasses import fields as dataclass_fields
from dataclasses import make_dataclass
import results
# Shared `results` database handle for the "library" PostgreSQL database named
# in the URL; the module-level helpers use it for queries and inspection.
db = results.db("postgresql:///library")
# Placeholder for the schema inspection result: init_db() assigns db.inspect()
# to it and class_for_table() reads its `.tables` mapping.
i = None
# DDL for the demo schema; init_db() runs it when the `book` table cannot be
# queried. `if not exists` keeps the script from aborting at its very first
# statement on databases where the pgcrypto extension is already installed.
SCHEMA_SQL = """\
create extension if not exists pgcrypto;
create table book(
book_id uuid default gen_random_uuid() primary key,
title varchar,
author varchar
);
create table borrower(
borrower_id uuid default gen_random_uuid() primary key,
name varchar
);
create table issue(
issue_id uuid default gen_random_uuid() primary key,
book_id uuid,
borrower_id uuid,
issue_date timestamptz default now()
);
"""
class Entity:
    """Base class for the table-backed dataclasses built by ``make_class``.

    The table name is taken to be the lower-cased class name, and its id
    column is taken to be ``<table>_id``.
    """

    def save(self, db):
        """Insert this object as a row, or upsert it when its id is set.

        Fields whose value is ``None`` are left out of the statement. When no
        id was set beforehand, the id of the returned row is copied back onto
        ``self``.
        """
        table = type(self).__name__.lower()
        pk = f"{table}_id"
        values = {
            name: value
            for name, value in asdict(self).items()
            if value is not None
        }

        # None values were dropped above, so "key missing" and "id unset" are
        # one and the same case.
        has_id = pk in values
        conflict_columns = [pk] if has_id else None

        row = db.insert(table, values, upsert_on=conflict_columns).one()
        if not has_id:
            setattr(self, pk, getattr(row, pk))

    def __str__(self):
        """Return the class name followed by an indented, pretty-printed dict.

        The dict holds the dataclass fields plus every other public attribute
        that is not a bound method (for example objects attached with
        ``setattr`` after construction).
        """
        extras = {}
        for name in dir(self):
            if name.startswith("_"):
                continue
            value = getattr(self, name)
            if not inspect.ismethod(value):
                extras[name] = value

        # Field order first; attributes that are not fields follow in dir()
        # order.
        shown = {**asdict(self), **extras}
        body = textwrap.indent(pprint.pformat(shown, sort_dicts=False), " ")
        return f"{self.__class__.__name__}:\n{body}"
def make_class(inspected_table):
    """Build an ``Entity`` dataclass that mirrors an inspected table.

    The class is named after the title-cased table name. Every column becomes
    a field typed with the column's ``pytype`` and defaulting to ``None``.
    """
    field_specs = [
        (column.name, column.pytype, dataclass_field(default=None))
        for column in inspected_table.columns.values()
    ]
    return make_dataclass(
        inspected_table.name.title(), field_specs, bases=(Entity,)
    )
def class_for_table(table_name):
    """Return a dataclass for the named table of the ``public`` schema.

    The table is looked up in the module-level inspection result ``i`` (set
    by ``init_db``) under its quoted, schema-qualified key.
    """
    qualified_name = f'"public"."{table_name}"'
    return make_class(i.tables[qualified_name])
def query_to_objects(query, _class, linked_classes=None):
    """Run ``query`` and turn its rows into ``_class`` instances.

    Rows are grouped by the columns matching the fields of ``_class``, giving
    one main object per group (built from the group's first row). For every
    class in ``linked_classes`` one instance is built from each row of the
    group, and the objects are cross-linked:

    * each linked object gets an attribute named after ``_class``
      (lower-cased) that points back at the main object;
    * if the main object has a ``<linked>_id`` attribute, the first linked
      object is attached to it as ``<linked>``; otherwise, if the linked class
      has that id field, the whole list is attached as ``<linked>s``.

    Returns the list of main objects.
    """
    if not linked_classes:
        linked_classes = []

    result_rows = db.ss(query)
    back_reference = _class.__name__.lower()
    own_fields = [f.name for f in dataclass_fields(_class)]

    paired = []
    for group in result_rows.grouped_by(columns=own_fields).values():
        owner_values = {name: group[0][name] for name in own_fields}
        paired.append((_class(**owner_values), group))

    for owner, group_rows in paired:
        for linked_class in linked_classes:
            linked_fields = [f.name for f in dataclass_fields(linked_class)]
            linked_name = linked_class.__name__.lower()
            linked_id = f"{linked_name}_id"

            children = []
            for row in group_rows:
                child = linked_class(**{name: row[name] for name in linked_fields})
                setattr(child, back_reference, owner)
                children.append(child)

            if hasattr(owner, linked_id):  # one-to-one
                setattr(owner, linked_name, children[0])
            elif linked_id in linked_fields:  # one-to-many
                setattr(owner, f"{linked_name}s", children)

    return [owner for owner, _ in paired]
def init_db():
    """Prepare the database and store its inspection result in global ``i``.

    Probes the ``book`` table; if the probe raises, the schema in
    ``SCHEMA_SQL`` is executed. All rows are then deleted from ``book``,
    ``issue`` and ``borrower`` before the database is inspected.
    """
    global i

    try:
        db.ss("select * from book")
    except Exception:
        # NOTE(review): any failure of the probe query, not only a missing
        # table, leads to the schema script being run.
        db.raw(SCHEMA_SQL)

    for table in "book issue borrower".split():
        db.ss(f"delete from {table};")

    i = db.inspect()
def main():
    """Demo run: save one book, two borrowers and two issues, then print them.

    The joined issue/book/borrower query is materialised twice: once as
    ``Issue`` objects with their linked book and borrower, and once as
    ``Book`` objects with their linked issues and borrowers.
    """
    Book = class_for_table("book")
    Issue = class_for_table("issue")
    Borrower = class_for_table("borrower")

    book = Book(title="A Tale of Two Cities", author="Charles Dickens")
    book.save(db)

    borrowers = []
    for name in ("Alice Ellison", "Sven Svensson"):
        person = Borrower(name=name)
        person.save(db)
        borrowers.append(person)

    for person in borrowers:
        Issue(book_id=book.book_id, borrower_id=person.borrower_id).save(db)

    # Plain single-table load; the result is not used further.
    books = query_to_objects("select * from book", Book)

    joined_query = (
        "select\n"
        "*\n"
        "from\n"
        "issue\n"
        "join book using (book_id)\n"
        "join borrower using (borrower_id)\n"
    )

    for issue in query_to_objects(joined_query, Issue, [Book, Borrower]):
        print(issue)

    books = query_to_objects(joined_query, Book, [Issue, Borrower])
    for linked_book in books:
        print(linked_book)
# Script entry point. init_db() runs first because main() builds its classes
# through class_for_table(), which reads the inspection result init_db() stores.
if __name__ == "__main__":
    init_db()
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment