Skip to content

Instantly share code, notes, and snippets.

@brianz
Last active February 15, 2021 10:29
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save brianz/feedc052d64212b6576fa42dd6dcadab to your computer and use it in GitHub Desktop.
Save brianz/feedc052d64212b6576fa42dd6dcadab to your computer and use it in GitHub Desktop.
SQLAlchemy helpers and mixins
import contextlib
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy_utils import database_exists, create_database
from common.logging import get_logger
from ..constants import (
CLIENT_DB_HOST,
CLIENT_DB_NAME,
CLIENT_DB_PASSWORD,
CLIENT_DB_PORT,
CLIENT_DB_USERNAME,
)
# Lazily-populated module state:
#   __engine          - Engine built once by setup_db()
#   __session_factory - sessionmaker bound to __engine, built by get_session()
#   __session         - the single shared Session returned by get_session()
#   __is_test         - copied from setup_db(is_test=...); lets _drop_tables /
#                       _clear_tables run without force=True
__session_factory = __session = __engine = None
__is_test = False
def setup_db(*, is_test=False, **db_config):
    """Create the module-level engine (once), its tables and the shared session.

    Once an engine exists, further calls return immediately without
    applying any of the arguments.

    Args:
        is_test: stored in module state; when True, ``_drop_tables`` and
            ``_clear_tables`` run without ``force=True``.
        **db_config: optional ``connection_kwargs`` (forwarded to
            ``create_engine``; ``poolclass`` is always overridden with
            ``NullPool``) and ``session_kwargs`` (forwarded to
            ``get_session`` and from there to ``sessionmaker``).
    """
    global __engine
    global __is_test

    if __engine:
        return

    __is_test = is_test

    connection_string = get_connection_string()

    # Copy, so the caller's dict is not mutated by the poolclass override;
    # ``or {}`` also tolerates an explicit ``connection_kwargs=None``.
    connection_kwargs = dict(db_config.get('connection_kwargs') or {})
    # we always want to close connections
    connection_kwargs['poolclass'] = NullPool

    session_kwargs = db_config.get('session_kwargs') or {}

    __engine = create_engine(connection_string, **connection_kwargs)

    if not database_exists(__engine.url):  # pragma: no cover
        # Print only the database name: rendering the whole URL can
        # include the password.
        print('Creating database: %s' % (__engine.url.database, ))
        create_database(__engine.url)

    create_tables()
    get_session(**session_kwargs)
def get_connection_string(**kwargs):
    """Return a connection string for sqlalchemy::

        dialect+driver://username:password@host:port/database

    The username and password are percent-encoded, so characters such as
    ``@``, ``:`` or ``/`` in credentials cannot corrupt the URL. Plain
    alphanumeric credentials render exactly as before.

    ``kwargs`` is accepted for backward compatibility but is currently unused.
    """
    # Local import keeps the fix self-contained; quote(..., safe='') encodes
    # every reserved character (including '/', '+' and space), which
    # SQLAlchemy decodes back when it parses the URL.
    from urllib.parse import quote

    return 'postgresql://%s:%s@%s:%s/%s' % (
        quote(str(CLIENT_DB_USERNAME), safe=''),
        quote(str(CLIENT_DB_PASSWORD), safe=''),
        CLIENT_DB_HOST,
        CLIENT_DB_PORT,
        CLIENT_DB_NAME,
    )
def close_db():  # pragma: no cover
    """Best-effort commit of the shared session, then close it.

    If the commit fails, the session is rolled back and the error is not
    re-raised. The session is closed in every case. This is a no-op when no
    session has been created yet.
    """
    if not __session:
        return
    try:
        __session.commit()
    except Exception:
        # Deliberately best-effort: discard pending work instead of raising.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate once the session has been closed.
        __session.rollback()
    finally:
        __session.close()
def commit_session(_raise=True):  # pragma: no cover
    """Commit the shared session, rolling back if the commit fails.

    Args:
        _raise: when True (default), re-raise the commit error after the
            rollback; when False, swallow it.

    This is a no-op when no session has been created yet.
    """
    if not __session:
        return
    try:
        __session.commit()
    except Exception:
        __session.rollback()
        if _raise:
            raise
def _get_metadata():
    """Return the MetaData object attached to the declarative ``Base``."""
    # Imported lazily: the mixins module does ``from . import commit_session,
    # get_session``, so importing it at module load would presumably be
    # circular.
    from .mixins import Base as declarative_root
    shared_metadata = declarative_root.metadata
    return shared_metadata
def create_tables():
    """Run ``create_all`` for the shared declarative metadata on the module engine.

    Requires ``setup_db()`` to have created the engine first.
    """
    assert __engine
    _get_metadata().create_all(__engine)
def _drop_tables(*, force=False):
    """Run ``drop_all`` for the shared declarative metadata on the module engine.

    This is a safety-guarded helper. It does nothing unless ``setup_db`` was
    called with ``is_test=True`` or the caller passes ``force=True``.
    """
    if not (__is_test or force):
        return
    assert __engine
    _get_metadata().drop_all(__engine)
def _clear_tables(*, force=False):
    """Delete every row from every table in the shared metadata.

    It does nothing unless ``setup_db`` was called with ``is_test=True`` or
    ``force=True`` is passed. Tables are visited in
    ``reversed(metadata.sorted_tables)`` order, so referencing tables are
    emptied before the tables they reference.

    Each DELETE runs inside its own SAVEPOINT. A statement that fails (for
    example, a table that was never created) is rolled back and skipped on
    its own. Without this, on PostgreSQL a single error aborts the whole
    enclosing transaction, and every later DELETE and the final commit
    silently fail.
    """
    if not __is_test and not force:
        return
    assert __engine
    meta = _get_metadata()
    with contextlib.closing(__engine.connect()) as con:
        # The outer transaction commits on normal exit and rolls back if an
        # exception escapes.
        with con.begin():
            for table in reversed(meta.sorted_tables):
                savepoint = con.begin_nested()
                try:
                    con.execute(table.delete())
                except Exception:
                    # Best-effort: skip tables that cannot be cleared without
                    # poisoning the outer transaction.
                    savepoint.rollback()
                else:
                    savepoint.commit()
def get_session(**kwargs):
    """Return the module-wide shared session, creating it on first use.

    ``setup_db()`` is invoked first so that the engine exists. ``kwargs`` are
    passed to ``sessionmaker`` only when the session factory is first built.
    Once a session exists, they are ignored and the existing session is
    returned.
    """
    global __session
    global __session_factory

    # Forward kwargs: if this call is what bootstraps the database,
    # setup_db() builds the shared session itself (through get_session).
    # Without forwarding, that session would use empty options and the
    # caller's kwargs would be silently dropped by the early return below.
    setup_db(session_kwargs=kwargs)
    assert __engine

    if __session is not None:
        return __session

    if __session_factory is None:  # pragma: no cover
        __session_factory = sessionmaker(bind=__engine, **kwargs)
    __session = __session_factory()
    return __session
def session_getter(func):
    """Decorator to get a session and inject it as the first argument in a function.

    Each call runs inside ``dbtransaction()``. The session is committed when
    ``func`` returns, and rolled back (with the error re-raised) when it
    raises. ``functools.wraps`` preserves the wrapped function's name,
    docstring and ``__wrapped__`` for introspection and debugging.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with dbtransaction() as session:
            return func(session, *args, **kwargs)
    return wrapper
def session_committer(func):
    """Decorator that calls ``commit_session()`` after ``func`` finishes.

    The commit runs in a ``finally`` clause, so it happens whether ``func``
    returns or raises, and an ``info`` log line is emitted first.
    ``functools.wraps`` preserves the wrapped function's metadata.

    NOTE(review): when ``func`` raised and the commit then fails,
    ``commit_session`` re-raises. That commit error replaces ``func``'s
    exception (the original survives only as ``__context__``). Committing
    after a failure may also persist partial work. Confirm both are intended.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            log = get_logger()
            log.info('session_committer')
            commit_session()
    return wrapper
@contextmanager
def dbtransaction():
    """Context manager yielding the shared session inside a unit of work.

    On normal exit the session is committed. If the ``with`` body or the
    commit itself raises, the session is rolled back and the exception
    propagates unchanged.
    """
    current = get_session()
    try:
        yield current
        current.commit()
    except BaseException:
        current.rollback()
        raise
import re
from sqlalchemy.ext.declarative import (
declarative_base,
declared_attr,
)
from sqlalchemy import (
MetaData,
Index,
)
from . import (
commit_session,
get_session,
)
# Naming templates passed to MetaData(naming_convention=...). The "ix" and
# "uq" entries are also reused by ModelMixin.generate_index /
# generate_unique_index to build explicit index names.
INDEX_CONVENTION = dict(
    ix="ix_%(table_name)s_%(column_0_label)s",   # indexes
    uq="uq_%(table_name)s_%(column_0_name)s",    # unique constraints
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",  # foreign keys
    pk="pk_%(table_name)s",                      # primary keys
)

# Metadata shared by every model, so all of them use the naming convention.
METADATA = MetaData(naming_convention=INDEX_CONVENTION)
# Declarative base class; models deriving from it register on METADATA.
Base = declarative_base(metadata=METADATA)
def class_name_to_underscores(name):
    """Convert a CamelCase (or camelCase) class name to snake_case.

    >>> class_name_to_underscores('ModelMixin')
    'model_mixin'
    >>> class_name_to_underscores('HTTPResponse')
    'http_response'
    """
    converted = name
    # First split before capitalised words, then between a lowercase/digit
    # and a following capital; the order of the two passes matters.
    for pattern in ('(.)([A-Z][a-z]+)', '([a-z0-9])([A-Z])'):
        converted = re.sub(pattern, r'\1_\2', converted)
    return converted.lower()
class ModelMixin:
    """Mixin to give models some helpful functionality.

    It adds:
    - an automatic ``__tablename__``;
    - a ``save()`` helper that uses the shared session from ``get_session()``;
    - classmethods that build ``Index`` objects named after ``INDEX_CONVENTION``.
    """

    @declared_attr
    def __tablename__(cls):
        """Automatically create table names using a convention from the name of the model class.

        ``FooBar`` -> ``foo_bars``, ``Category`` -> ``categories``; names that
        already end in ``s`` are left unchanged.
        """
        name = class_name_to_underscores(cls.__name__)
        if name.endswith('y'):
            # NOTE(review): this also turns vowel+y endings into "ies"
            # ("Survey" -> "surveies"), and rstrip removes *every* trailing
            # "y". Left unchanged on purpose: changing the rule would rename
            # tables that already exist in deployed databases.
            name = name.rstrip('y') + 'ies'
        elif not name.endswith('s'):
            name = name + 's'
        return name

    def save(self, *, commit=True):
        """Add this instance to the shared session and optionally commit.

        Args:
            commit: when True (default), call ``commit_session()`` right after
                adding the instance.
        """
        session = get_session()
        session.add(self)
        if commit:
            commit_session()

    @classmethod
    def flush(cls):  # pragma: no cover
        """Call ``flush()`` on the shared session returned by ``get_session()``."""
        get_session().flush()

    @classmethod
    def _get_index(cls, name, *columns, **kwargs):  # pragma: no cover
        """Build an ``Index`` called *name* over the given model attributes.

        Args:
            name: explicit index name.
            *columns: attribute names on this model, resolved with ``getattr``.
            **kwargs: forwarded to ``Index`` (for example ``unique``).
        """
        cols = [getattr(cls, column) for column in columns]
        return Index(name, *cols, **kwargs)

    @classmethod
    def generate_index(cls, *columns, **kwargs):  # pragma: no cover
        """Build an index named with the ``ix`` template of ``INDEX_CONVENTION``.

        The ``column_0_label`` slot is filled with the column names joined by
        ``_``. For example, on table ``foos``, ``generate_index('a', 'b')`` is
        named ``ix_foos_a_b``.
        """
        name_mapping = {
            'table_name': cls.__tablename__,
            'column_0_label': '_'.join(columns),
        }
        name = INDEX_CONVENTION['ix'] % name_mapping
        return cls._get_index(name, *columns, **kwargs)

    @classmethod
    def generate_unique_index(cls, *columns, **kwargs):  # pragma: no cover
        """Build a unique index named with the ``uq`` template of ``INDEX_CONVENTION``.

        The ``column_0_name`` slot is filled with the column names joined by
        ``_``. ``unique`` is set to ``True`` unless the caller already passed
        a truthy value for it.
        """
        name_mapping = {
            'table_name': cls.__tablename__,
            'column_0_name': '_'.join(columns),
        }
        name = INDEX_CONVENTION['uq'] % name_mapping
        if not kwargs.get('unique'):
            kwargs['unique'] = True
        return cls._get_index(name, *columns, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment