Last active: September 22, 2022 18:36
Save pingiun/56b96ecec1b2c445f3c3d71dc9efd517 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import MutableMapping

from sqlalchemy import Column, PickleType
from sqlalchemy.dialects.postgresql import insert
def upsert(obj, session=None):
    """Insert ``obj`` or, if a row with the same primary key exists, update it.

    On PostgreSQL a native ``INSERT ... ON CONFLICT DO UPDATE`` statement is
    used. On SQLite (used for local testing) a select-then-insert/update
    fallback is used instead. That fallback is not atomic, so do not call it
    from several threads at once. Both paths commit the session before
    returning.

    :param obj: declarative model instance. Its ``__table__`` supplies the
        table, and its public (non-underscore) instance attributes supply the
        row values.
    :param session: session to execute on. Defaults to the module-level
        ``db_session`` so existing ``upsert(obj)`` calls keep working.
    :raises RuntimeError: if the session is bound to a dialect other than
        PostgreSQL or SQLite.
    """
    if session is None:
        session = db_session

    # SQLAlchemy model objects carry their Table object in __table__.
    table = obj.__table__
    # Keep public instance attributes only; this skips SQLAlchemy's private
    # state such as ``_sa_instance_state``.
    # NOTE(review): assumes attribute names equal column names (true for
    # KeyValueBinary-style models) - confirm for other models.
    values = {k: v for k, v in obj.__dict__.items() if not k.startswith('_')}
    # Identify primary-key columns by *name*. ``values`` is keyed by strings,
    # and testing a string against a list of Column objects never matches.
    pk_names = {c.name for c in table.primary_key.columns}
    dialect = session.bind.dialect.name

    if dialect == 'postgresql':
        # https://gist.github.com/bhtucker/c40578a2fb3ca50b324e42ef9dce58e1
        # http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#insert-on-conflict-upsert
        insert_stmt = insert(table).values(values)
        update_cols = [c.name for c in table.columns if c.name not in pk_names]
        if update_cols:
            upsert_stmt = insert_stmt.on_conflict_do_update(
                index_elements=table.primary_key.columns,
                set_={name: getattr(insert_stmt.excluded, name) for name in update_cols},
            )
        else:
            # on_conflict_do_update() rejects an empty ``set_``. A table made
            # only of primary-key columns has nothing to update on conflict.
            upsert_stmt = insert_stmt.on_conflict_do_nothing(
                index_elements=table.primary_key.columns,
            )
        session.execute(upsert_stmt)
    elif dialect == 'sqlite':
        # For local testing: "look up, then update or insert".
        # This is not atomic - use a single thread, or run postgres locally.
        keys = {k: v for k, v in values.items() if k in pk_names}
        existing = session.query(obj.__class__).filter_by(**keys)
        # Only search for an existing row when every primary-key value is
        # known. An empty filter would match (and update) every row.
        if keys and len(keys) == len(pk_names) and existing.first() is not None:
            existing.update(values)
        else:
            session.add(obj)
    else:
        raise RuntimeError("Only postgres and sqlite (for local testing) are supported")

    session.commit()
class KeyValueBinary:
    """Mixin for declarative models that store one key/value pair per row.

    ``DatabaseDict`` requires its ``table`` argument to be a subclass of this
    mixin. Both columns use ``PickleType``, so keys and values may be any
    picklable Python objects. ``DatabaseDict`` reads rows through
    ``table.query``, so the declarative base should provide a ``query``
    property, as shown below.

    Example::

        Base = declarative_base()
        Base.query = db_session.query_property()

        class EmailConversation(Base, KeyValueBinary):
            __tablename__ = 'emailconversation'
            pass
    """

    # The dictionary key, pickled; it doubles as the table's primary key.
    # NOTE(review): lookups compare pickled values, so keys whose pickles are
    # not stable across runs may fail to match - prefer simple key types.
    key = Column(PickleType, primary_key=True)
    # The dictionary value stored for ``key``, pickled.
    value = Column(PickleType)
class DatabaseDict(MutableMapping):
    """A dictionary whose items are stored as rows of a database table.

    ``table`` must be a declarative model built on ``KeyValueBinary``; each
    item is one row (``key`` column -> ``value`` column). Reads go through
    ``table.query``, so the model's base needs a ``query`` property (for
    example ``Base.query = db_session.query_property()``).

    Example::

        # db_session from the session maker
        db_dict = DatabaseDict(db_session, EmailConversation)
        db_dict['test'] = 3
        db_dict['test']  # -> 3
    """

    def __init__(self, db_session, table):
        """Bind the mapping to a session and a KeyValueBinary model.

        :param db_session: session whose ``commit()`` is called after deletions.
        :param table: model class (a KeyValueBinary subclass) backing the mapping.
        :raises RuntimeError: if ``table`` is not a KeyValueBinary subclass.
        """
        super().__init__()
        if not issubclass(table, KeyValueBinary):
            raise RuntimeError("Table must be a KeyValueBinary")
        self.db_session = db_session
        self.table = table

    def __iter__(self):
        """Iterate over the stored keys.

        Only the ``key`` column is selected, and the keys are collected into a
        list up front. The iterator is therefore a snapshot, and callers may
        delete items while iterating over it.
        """
        keys = [key for (key,) in self.table.query.with_entities(self.table.key)]
        return iter(keys)

    def __len__(self):
        """Return the number of rows in the table."""
        return self.table.query.count()

    def __getitem__(self, key):
        """Return the value stored for ``key``.

        :raises KeyError: if no row has this key.
        """
        item = self.table.query.filter(self.table.key == key).first()
        if item is None:
            raise KeyError("Key not found in the database: {}".format(key))
        return item.value

    def __delitem__(self, key):
        """Delete the row for ``key`` and commit.

        :raises KeyError: if no row has this key (the mapping contract for
            ``del d[key]``).
        """
        # Query.delete() returns the number of matched rows.
        deleted = self.table.query.filter(self.table.key == key).delete()
        self.db_session.commit()
        if not deleted:
            raise KeyError("Key not found in the database: {}".format(key))

    def __setitem__(self, key, value):
        """Insert or overwrite the row for ``key`` via ``upsert``."""
        # NOTE(review): upsert() runs on the module-level ``db_session``, not on
        # ``self.db_session``; make sure the two refer to the same session.
        upsert(self.table(key=key, value=value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.