Created
March 31, 2024 12:03
-
-
Save Johann150/30366a696532998d9feb8313ea2966ac to your computer and use it in GitHub Desktop.
python dataclasses ORM
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sqlite3 | |
from typing import Optional | |
from pprint import pprint | |
import orm | |
# Fixture: an in-memory SQLite database with two seeded tables, installed as
# the ORM's shared connection. Rows are returned as sqlite3.Row objects.
connection = sqlite3.connect(":memory:")
connection.row_factory = sqlite3.Row
connection.executescript("""
CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT NOT NULL) STRICT;
INSERT INTO foo (id, name) VALUES (0, 'bar'), (1, 'baz');
CREATE TABLE thing (id INTEGER PRIMARY KEY, name TEXT NOT NULL, description TEXT) STRICT;
INSERT INTO thing (id, name, description) VALUES (0, 'asdf', NULL), (1, 'qwert', 'hjkl');
""")
connection.commit()
orm.db = connection
################ | |
@orm.DbObject('foo', ['id'])
class Foo:
    """Row of the ``foo`` table, identified by ``id``."""

    id: Optional[int]  # None until the object has been saved (inserted)
    name: str
@orm.DbObject('thing', ['id'])
class Thing:
    """A row of the ``thing`` table, identified by ``id``; ``description`` may be NULL."""

    id: Optional[int]  # None until the object has been saved (inserted)
    name: str
    description: Optional[str]  # NULL in the database maps to None
################ | |
# --- load an existing row, modify it and save it back ---
f = Foo.load(name='baz')
assert f.id == 1
pprint(f)
f.name = 'booz'
f.save()
assert f.id == 1
assert f.name == 'booz'
assert Foo.load(id=1).name == 'booz'
pprint(f)

# --- insert a new row: the database assigns the primary key ---
g = Foo(id=None, name='booz')
assert g.id is None
g.save()
assert g.id == 2
assert g.name == 'booz'
pprint(g)

# --- filtered load_all ---
data = Foo.load_all(name='booz')
assert len(data) == 2
pprint(data)

# --- delete a row ---
g = Foo.load(id=2)
assert g is not None
g.delete()
assert Foo.load(id=2) is None
data = Foo.load_all()
assert len(data) == 2
pprint(data)

# --- a NULL column comes back as None ---
thing = Thing.load(id=0)
assert thing is not None
assert thing.id == 0
assert thing.description is None
pprint(thing)

# --- primary keys are read-only after construction ---
things = Thing.load_all()
pprint(things)
try:
    things[0].id = 1  # not allowed: must raise AttributeError
except AttributeError:
    pass
else:
    # Do not go on to save(): with id=1 it would overwrite row 1 with row 0's data.
    # An explicit raise (unlike `assert False`) also survives `python -O`.
    raise AssertionError('assigning to a primary key field did not raise AttributeError')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional, Type | |
import functools | |
import dataclasses | |
# Module-level database connection shared by every class decorated with
# `DbObject`. Assign an open connection (e.g. `orm.db = sqlite3.connect(...)`)
# before calling `load`, `load_all`, `save` or `delete`; the methods look it up
# at call time, so it may be assigned after the classes are defined.
db = None


def _quote_identifier(identifier):
    """Return *identifier* as a double-quoted SQL identifier (embedded quotes doubled)."""
    return '"' + str(identifier).replace('"', '""') + '"'


def DbObject(tablename, primary_key_keys):
    """
    Decorate a class so it can `load`, `load_all`, `save` and `delete` database rows.

    The class is turned into a dataclass whose field names must match columns of
    the table `tablename`. Primary key fields cannot be assigned to or deleted
    once an instance has been constructed.

    :param tablename: name of the table that stores the objects.
    :param primary_key_keys: name of the primary key field, or a list/tuple of
        field names for a composite primary key.
    :raises TypeError: if no primary key is given, or (when the decorator is
        applied) if a primary key is not a field of the class.
    """
    if isinstance(primary_key_keys, (list, tuple)):
        primary_keys = list(primary_key_keys)
    else:
        primary_keys = [primary_key_keys]
    if len(primary_keys) == 0:
        raise TypeError('no primary_keys defined')
    quoted_table = _quote_identifier(tablename)

    def equals_clause(names):
        """Build `"a" = :a AND "b" = :b ...` for the given field names."""
        return ' AND '.join(f'{_quote_identifier(name)} = :{name}' for name in names)

    def decorator(cls: Type) -> Type:
        cls = dataclasses.dataclass(cls)
        field_names = [f.name for f in dataclasses.fields(cls)]
        for key in primary_keys:
            if key not in field_names:
                raise TypeError(f'primary key {repr(key)} is not present as a field of {cls.__name__}')

        # Select exactly the dataclass fields (not `*`): extra table columns are
        # ignored, and rows can be mapped onto fields whatever the row_factory is.
        column_list = ', '.join(_quote_identifier(name) for name in field_names)
        select_sql = f'SELECT {column_list} FROM {quoted_table}'

        def check_kwargs(klass, method, kwargs):
            """Raise TypeError for any keyword that is not a field of the class."""
            for name in kwargs:
                if name not in field_names:
                    raise TypeError(f'{klass.__name__}.{method}() got an unexpected keyword argument {repr(name)}')

        def from_row(klass, row):
            """Construct an instance of *klass* from one result row of `select_sql`."""
            if hasattr(row, 'keys'):  # sqlite3.Row or another mapping-like row
                return klass(**{key: row[key] for key in row.keys()})
            return klass(**dict(zip(field_names, row)))  # plain tuple row

        def is_new(self):
            """True if any primary key is None; such objects are treated as not yet inserted."""
            return any(getattr(self, key) is None for key in primary_keys)

        # prevent assigning to any attributes which are defined as primary keys
        old_setattr = cls.__setattr__

        @functools.wraps(old_setattr)
        def setattr_wrapper(self, name, value):
            if name in primary_keys:
                raise AttributeError(f'cannot assign to primary key field {repr(name)}')
            return old_setattr(self, name, value)  # type: ignore
        cls.__setattr__ = setattr_wrapper

        # prevent deleting any primary key fields
        old_delattr = cls.__delattr__

        @functools.wraps(old_delattr)
        def delattr_wrapper(self, name):
            if name in primary_keys:
                raise AttributeError(f'cannot delete primary key field {repr(name)}')
            return old_delattr(self, name)  # type: ignore
        cls.__delattr__ = delattr_wrapper

        # Wrap the constructor: the dataclass __init__ must be allowed to set the
        # primary keys, so the guard is lifted while it runs. `finally` ensures the
        # guard is restored even when construction raises (e.g. a missing argument).
        # NOTE: the guard is lifted class-wide, so this is not thread-safe.
        old_init = cls.__init__

        @functools.wraps(old_init)
        def init_wrapper(self, *args, **kwargs):
            cls.__setattr__ = old_setattr
            try:
                return old_init(self, *args, **kwargs)
            finally:
                cls.__setattr__ = setattr_wrapper
        cls.__init__ = init_wrapper

        @classmethod
        def load(cls, **kwargs):
            """
            Load an object from the database using the given kwargs.
            Using a kwarg which is not a field of the class raises a `TypeError`.
            With no kwargs, any row of the table may be returned.
            If no object matches, `None` is returned.
            If an object matches, it will be constructed and then returned.
            If more than one object matches, an arbitrary one will be returned.
            """
            check_kwargs(cls, 'load', kwargs)
            # An empty filter would produce `WHERE ` (a syntax error); match any row instead.
            where = equals_clause(kwargs) or '1'
            row = db.execute(f'{select_sql} WHERE {where}', kwargs).fetchone()
            return None if row is None else from_row(cls, row)

        @classmethod
        def load_all(cls, _where: str = '1', **kwargs):
            """
            Load all objects from the database which match the kwargs.
            Using a kwarg which is not a field of the class raises a `TypeError`.
            If necessary, a custom WHERE clause can be submitted with the `_where` parameter.
            When using `_where`, the caller is responsible for not introducing SQL injection.
            If both `_where` and other kwargs are specified, all will be applied with an `AND` conjunction.
            """
            check_kwargs(cls, 'load_all', kwargs)
            where = f'({_where})'
            if kwargs:
                where += f' AND ({equals_clause(kwargs)})'
            sql = f'{select_sql} WHERE ({where})'
            return [from_row(cls, row) for row in db.execute(sql, kwargs)]

        def _update(self):
            """Overwrite every column of the row matching this object's primary keys, then commit."""
            assignments = ', '.join(f'{_quote_identifier(name)} = :{name}' for name in field_names)
            sql = f'UPDATE {quoted_table} SET {assignments} WHERE {equals_clause(primary_keys)}'
            db.execute(sql, dataclasses.asdict(self))
            db.commit()

        def _insert(self):
            """Insert this object as a new row, commit, and return the cursor's lastrowid."""
            values = dataclasses.asdict(self)
            # Omit primary keys that are still None so the database can assign them.
            columns = [
                name for name in field_names
                if not (name in primary_keys and values[name] is None)
            ]
            if columns:
                column_sql = ', '.join(_quote_identifier(name) for name in columns)
                placeholders = ', '.join(f':{name}' for name in columns)
                sql = f'INSERT INTO {quoted_table} ({column_sql}) VALUES ({placeholders})'
            else:
                # `INSERT INTO t () VALUES ()` is not valid SQL.
                sql = f'INSERT INTO {quoted_table} DEFAULT VALUES'
            rowid = db.execute(sql, values).lastrowid
            db.commit()
            return rowid

        def save(self):
            """
            Write the object to the database and commit.

            If any primary key is None the object is inserted. With a single
            primary key, the generated rowid is then stored on the object.
            Otherwise the existing row is updated.
            """
            if is_new(self):
                rowid = _insert(self)
                if len(primary_keys) == 1:
                    # Bypass the primary-key guard: adopting the generated key is intended.
                    old_setattr(self, primary_keys[0], rowid)  # type: ignore
                # FIXME: what to do if there is no rowid (composite primary keys
                # are never filled in from the database here)
            else:
                _update(self)

        def delete(self):
            """
            Delete the row matching this object's primary keys and commit.

            :raises ValueError: if any primary key is None (the object is new).
            """
            if is_new(self):
                raise ValueError('tried to delete a new object')
            sql = f'DELETE FROM {quoted_table} WHERE {equals_clause(primary_keys)}'
            db.execute(sql, dataclasses.asdict(self))
            db.commit()

        cls.load = load
        cls.load_all = load_all
        cls.save = save
        cls.delete = delete
        return cls

    return decorator
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment