Skip to content

Instantly share code, notes, and snippets.

@Johann150
Created March 31, 2024 12:03
Show Gist options
  • Save Johann150/30366a696532998d9feb8313ea2966ac to your computer and use it in GitHub Desktop.
Save Johann150/30366a696532998d9feb8313ea2966ac to your computer and use it in GitHub Desktop.
python dataclasses ORM
import sqlite3
from typing import Optional
from pprint import pprint
import orm
# Give the ORM a throwaway in-memory SQLite database for this demo.
orm.db = sqlite3.connect(":memory:")
# sqlite3.Row makes fetched rows addressable by column name.
orm.db.row_factory = sqlite3.Row
# Two small STRICT tables, each seeded with two rows.
orm.db.executescript("""
CREATE TABLE foo (id INTEGER PRIMARY KEY, name TEXT NOT NULL) STRICT;
INSERT INTO foo (id, name) VALUES (0, 'bar'), (1, 'baz');
CREATE TABLE thing (id INTEGER PRIMARY KEY, name TEXT NOT NULL, description TEXT) STRICT;
INSERT INTO thing (id, name, description) VALUES (0, 'asdf', NULL), (1, 'qwert', 'hjkl');
""")
orm.db.commit()
################
@orm.DbObject('foo', 'id')
class Foo:
    """Example model stored in the ``foo`` table and identified by ``id``."""

    id: Optional[int]  # None for an object that has not been saved yet
    name: str
@orm.DbObject('thing', 'id')
class Thing:
    """A Thing: one row of the ``thing`` table; ``description`` may be NULL."""

    id: Optional[int]  # None until the first save() inserts the row
    name: str
    description: Optional[str]
################
# Load by a non-key column, modify, and save back (UPDATE path).
f = Foo.load(name='baz')
assert f is not None
assert f.id == 1
pprint(f)
f.name = 'booz'
f.save()
assert f.id == 1
assert f.name == 'booz'
assert Foo.load(id=1).name == 'booz'
pprint(f)

# A new object (id=None) is INSERTed and receives its id from the database.
g = Foo(id=None, name='booz')
assert g.id is None
g.save()
assert g.id == 2
assert g.name == 'booz'
pprint(g)

data = Foo.load_all(name='booz')
assert len(data) == 2
pprint(data)

# Delete the freshly inserted row again.
g = Foo.load(id=2)
assert g is not None
g.delete()
data = Foo.load_all()
assert len(data) == 2
pprint(data)

thing = Thing.load(id=0)
assert thing is not None
assert thing.id == 0
pprint(thing)

things = Thing.load_all()
pprint(things)
# Primary key fields are read-only after construction. Only the assignment
# itself is inside the try, so an unrelated AttributeError (e.g. orm.db not
# being set) cannot make this check pass by accident.
try:
    things[0].id = 1
except AttributeError:
    pass
else:
    # Fail loudly instead of saving: a save() here would write to the row with id 1.
    raise AssertionError('assigning to a primary key field should raise AttributeError')
from typing import Optional, Type
import functools
import dataclasses
# Connection shared by every class decorated with `DbObject`. It must be
# assigned (e.g. `orm.db = sqlite3.connect(...)`) before any load/save/delete.
db = None


def DbObject(tablename, primary_key_keys):
    """
    Decorate a class to be able to `load` from, `load_all` and `save` to a database.

    The decorated class is turned into a dataclass whose fields map one-to-one
    onto the columns of `tablename`. Primary key fields may only be set through
    the constructor (and by `save` after an INSERT); assigning to or deleting
    them afterwards raises `AttributeError`.

    :param tablename: name of the backing table.
    :param primary_key_keys: a field name, or a list/tuple of field names,
        that together form the primary key.
    :raises TypeError: if no primary key is given, or a primary key is not a
        field of the decorated class.
    """
    if isinstance(primary_key_keys, (list, tuple)):
        primary_keys = list(primary_key_keys)
    else:
        primary_keys = [primary_key_keys]
    if not primary_keys:
        raise TypeError('no primary_keys defined')

    # WHERE clause selecting the row that backs a given object; its
    # placeholders are filled from that object's own field values.
    pk_where = ' AND '.join(f'"{key}" = :{key}' for key in primary_keys)

    def decorator(cls: Type) -> Type:
        cls = dataclasses.dataclass(cls)
        field_names = [f.name for f in dataclasses.fields(cls)]
        for key in primary_keys:
            if key not in field_names:
                raise TypeError(f'primary key {repr(key)} is not present as a field of {cls.__name__}')

        # prevent assigning to any attributes which are defined as primary keys
        old_setattr = cls.__setattr__

        @functools.wraps(old_setattr)
        def setattr_wrapper(self, name, value):
            if name in primary_keys:
                raise AttributeError(f'cannot assign to primary key field {repr(name)}')
            return old_setattr(self, name, value)  # type: ignore

        cls.__setattr__ = setattr_wrapper

        # prevent deleting any primary key fields
        old_delattr = cls.__delattr__

        @functools.wraps(old_delattr)
        def delattr_wrapper(self, name):
            if name in primary_keys:
                raise AttributeError(f'cannot delete primary key field {repr(name)}')
            return old_delattr(self, name)  # type: ignore

        cls.__delattr__ = delattr_wrapper

        # Wrap the constructor: the generated __init__ assigns every field
        # (primary keys included) through __setattr__, so the guard is lifted
        # while it runs.
        old_init = cls.__init__

        @functools.wraps(old_init)
        def init_wrapper(self, *args, **kwargs):
            cls.__setattr__ = old_setattr
            try:
                return old_init(self, *args, **kwargs)
            finally:
                # Restore the guard even if __init__ raised (e.g. bad
                # arguments); otherwise primary keys would stay writable.
                cls.__setattr__ = setattr_wrapper

        cls.__init__ = init_wrapper

        def _reject_unknown(klass, method, kwargs):
            """Raise TypeError, like a normal call would, for kwargs that are not fields."""
            for name in kwargs:
                if name not in field_names:
                    raise TypeError(f'{klass.__name__}.{method}() got an unexpected keyword argument {repr(name)}')

        def _build(klass, cursor, row):
            """Construct `klass` from one fetched row, matching columns to fields by name."""
            if hasattr(row, 'keys'):
                # mapping-like rows (e.g. sqlite3.Row) carry their own column names
                values = {key: row[key] for key in row.keys()}
            else:
                # plain tuple rows: take the column names from the cursor
                values = dict(zip((column[0] for column in cursor.description), row))
            return klass(**values)

        def _is_new(self):
            """Return True if any primary key field is still None (never inserted)."""
            return any(getattr(self, key) is None for key in primary_keys)

        @classmethod
        def load(cls, **kwargs):
            """
            Load an object from the database using the given kwargs.

            Each kwarg names a field that must equal the given value; all are
            combined with AND. Without kwargs, every row matches.
            If no object matches, `None` is returned.
            If an object matches, it will be constructed and then returned.
            If more than one object matches, an arbitrary one will be returned.

            :raises TypeError: if a kwarg does not name a field.
            """
            _reject_unknown(cls, 'load', kwargs)
            # '1' keeps the statement valid when no filters are given.
            sql_where = ' AND '.join(f'"{name}" = :{name}' for name in kwargs) or '1'
            cursor = db.execute(f'SELECT * FROM "{tablename}" WHERE {sql_where}', kwargs)
            row = cursor.fetchone()
            if row is None:
                return None
            return _build(cls, cursor, row)

        @classmethod
        def load_all(cls, _where: str = '1', **kwargs):
            """
            Load all objects from the database which match the kwargs.

            Using a kwarg which is not defined as a field will raise a `TypeError`.
            If necessary, a custom WHERE clause can be submitted with the `_where` parameter.
            When using `_where`, the caller is responsible for not introducing SQL injection.
            If both `_where` and other kwargs are specified, all will be applied with an `AND` conjunction.
            """
            _reject_unknown(cls, 'load_all', kwargs)
            conditions = [f'({_where})']
            conditions.extend(f'("{name}" = :{name})' for name in kwargs)
            where = ' AND '.join(conditions)
            cursor = db.execute(f'SELECT * FROM "{tablename}" WHERE ({where})', kwargs)
            return [_build(cls, cursor, row) for row in cursor]

        def _update(self):
            """UPDATE the row with this object's primary key; return the affected row count."""
            sql_fields = ', '.join(f'"{name}" = :{name}' for name in field_names)
            sql = f'UPDATE "{tablename}" SET {sql_fields} WHERE {pk_where}'
            rowcount = db.execute(sql, dataclasses.asdict(self)).rowcount
            db.commit()
            return rowcount

        def _insert(self):
            """
            INSERT this object as a new row and return the cursor's `lastrowid`.

            Primary key fields that are still None are left out of the column list.
            """
            sql_fields = [
                name
                for name in field_names
                if not (name in primary_keys and getattr(self, name) is None)
            ]
            if sql_fields:
                columns = ', '.join(f'"{name}"' for name in sql_fields)
                placeholders = ', '.join(f':{name}' for name in sql_fields)
                sql = f'INSERT INTO "{tablename}" ({columns}) VALUES ({placeholders})'
            else:
                # Nothing to insert explicitly; "() VALUES ()" would be a syntax error.
                sql = f'INSERT INTO "{tablename}" DEFAULT VALUES'
            rowid = db.execute(sql, dataclasses.asdict(self)).lastrowid
            db.commit()
            return rowid

        def save(self):
            """
            Write this object to the database.

            If any primary key is still None, the object is INSERTed; with a
            single primary key, the new `lastrowid` is then stored on it.
            Otherwise the row with this key is UPDATEd, or INSERTed if no
            such row exists yet.
            """
            if _is_new(self):
                rowid = _insert(self)
                if len(primary_keys) == 1:
                    # bypass the guard: the database has just assigned the key
                    old_setattr(self, primary_keys[0], rowid)  # type: ignore
                # FIXME: what to do if there is no rowid (e.g. composite keys)
            elif _update(self) == 0:
                # An explicit key that is not in the table yet: UPDATE alone
                # would silently write nothing.
                _insert(self)

        def delete(self):
            """
            Delete the row with this object's primary key.

            :raises ValueError: if a primary key is still None (never saved).
            """
            if _is_new(self):
                raise ValueError('tried to delete a new object')
            db.execute(f'DELETE FROM "{tablename}" WHERE {pk_where}', dataclasses.asdict(self))
            db.commit()

        cls.load = load
        cls.load_all = load_all
        cls.save = save
        cls.delete = delete
        return cls

    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment