Skip to content

Instantly share code, notes, and snippets.

@pingiun
Last active September 22, 2022 18:36
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pingiun/56b96ecec1b2c445f3c3d71dc9efd517 to your computer and use it in GitHub Desktop.
Save pingiun/56b96ecec1b2c445f3c3d71dc9efd517 to your computer and use it in GitHub Desktop.
from collections.abc import MutableMapping

from sqlalchemy import Column, PickleType
from sqlalchemy.dialects.postgresql import insert
def upsert(obj, session=None):
    """Insert `obj` into its table, or update the existing row with the same primary key.

    On PostgreSQL this is a single ``INSERT ... ON CONFLICT DO UPDATE`` statement.
    SQLite (used for local testing) gets a select-then-update/insert fallback instead.
    That fallback issues separate statements, so it is not atomic: do not run it from
    several threads at once, or just run postgres locally.

    The change is committed before returning, on both backends.

    :param obj: an instance of a declarative SQLAlchemy model (must have ``__table__``).
    :param session: session to use; defaults to the module-level ``db_session``.
    :raises RuntimeError: if the session is bound to any other database dialect.
    """
    if session is None:
        # Keep the original behaviour of running against the global session.
        session = db_session
    # SQLalchemy model objects have the table object in __table__
    table = obj.__table__
    # Compare primary-key columns by name: the attribute names in `values` are
    # strings, and comparing strings with Column objects never matches.
    pk_names = {column.name for column in table.primary_key.columns}
    # Values set on the instance; '_'-prefixed entries (e.g. _sa_instance_state)
    # are SQLAlchemy bookkeeping, not columns.
    values = {k: v for k, v in obj.__dict__.items() if not k.startswith('_')}

    dialect = session.bind.dialect.name
    if dialect == 'postgresql':
        # https://gist.github.com/bhtucker/c40578a2fb3ca50b324e42ef9dce58e1
        update_cols = [c.name for c in table.columns if c.name not in pk_names]
        # http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#insert-on-conflict-upsert
        stmt = insert(table).values(values)
        stmt_on_conflict = stmt.on_conflict_do_update(
            index_elements=list(table.primary_key.columns),
            set_={name: getattr(stmt.excluded, name) for name in update_cols},
        )
        session.execute(stmt_on_conflict)
    elif dialect == 'sqlite':
        # For local testing we use the standard test then insert/update method.
        keys = {k: v for k, v in values.items() if k in pk_names}
        # Without a value for every primary-key column the object cannot match an
        # existing row; an empty filter would match (and overwrite) the whole table.
        if len(keys) == len(pk_names):
            existing = session.query(type(obj)).filter_by(**keys)
            if existing.first() is not None:
                existing.update(values)
            else:
                session.add(obj)
        else:
            session.add(obj)
    else:
        raise RuntimeError("Only postgres and sqlite (for local testing) are supported")
    # Commit on every supported backend (previously only the sqlite path committed).
    session.commit()
class KeyValueBinary:
    """Mixin giving a declarative model a pickled ``key`` / ``value`` pair of columns.

    Subclass it together with your declarative base to get a table that can back a
    python dict (in the form of conversationhandler.DatabaseDict):
    ```
    Base = declarative_base()
    Base.query = db_session.query_property()

    class EmailConversation(Base, KeyValueBinary):
        __tablename__ = 'emailconversation'
    ```

    NOTE(review): both columns use PickleType, so anything stored must be picklable,
    and values are unpickled on read — only use tables you trust.
    """

    # Primary key: the dict key, stored pickled.
    key = Column(PickleType(), primary_key=True)
    # The dict value, stored pickled; nullable like any non-key column.
    value = Column(PickleType(), nullable=True)
class DatabaseDict(MutableMapping):
    """A dictionary that is stored in a database table.

    Every operation goes straight to the database: reads query the table through
    ``db_session``, and deletions are committed immediately.

    Example:
    ```
    # db_session from the session maker
    db_dict = DatabaseDict(db_session, EmailConversation)
    db_dict['test'] = 3
    ```

    NOTE(review): keys live in a PickleType column, so lookups compare pickled
    forms; keys that are equal in Python but pickle differently may not be found.
    """

    def __init__(self, db_session, table):
        """Wrap `table` (a KeyValueBinary model) using `db_session` for all queries.

        :raises RuntimeError: if `table` is not a KeyValueBinary subclass.
        """
        super().__init__()
        if not issubclass(table, KeyValueBinary):
            raise RuntimeError("Table must be a KeyValueBinary")
        self.db_session = db_session
        self.table = table

    def _rows_for(self, key):
        """Return a query selecting the row(s) whose key equals `key`."""
        return self.db_session.query(self.table).filter(self.table.key == key)

    def __iter__(self):
        # Materialise the keys first: the returned iterator is a snapshot, so
        # deleting entries while iterating over it is safe.
        keys = [key for (key,) in self.db_session.query(self.table.key).all()]
        return iter(keys)

    def __len__(self):
        return self.db_session.query(self.table).count()

    def __getitem__(self, key):
        item = self._rows_for(key).first()
        if item is None:
            raise KeyError("Key not found in the database: {}".format(key))
        return item.value

    def __delitem__(self, key):
        # Query.delete() returns the number of matched rows.
        deleted = self._rows_for(key).delete()
        self.db_session.commit()
        # The mapping contract requires `del d[missing]` to raise KeyError.
        if deleted == 0:
            raise KeyError("Key not found in the database: {}".format(key))

    def __setitem__(self, key, value):
        # NOTE(review): upsert() called without a session uses the module-level
        # db_session, not self.db_session — these are presumably the same session.
        upsert(self.table(key=key, value=value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment