Last active
February 15, 2021 10:29
-
-
Save brianz/feedc052d64212b6576fa42dd6dcadab to your computer and use it in GitHub Desktop.
SQLAlchemy helpers and mixins
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
from contextlib import contextmanager | |
from sqlalchemy import create_engine | |
from sqlalchemy.orm import sessionmaker | |
from sqlalchemy.pool import NullPool | |
from sqlalchemy_utils import database_exists, create_database | |
from common.logging import get_logger | |
from ..constants import ( | |
CLIENT_DB_HOST, | |
CLIENT_DB_NAME, | |
CLIENT_DB_PASSWORD, | |
CLIENT_DB_PORT, | |
CLIENT_DB_USERNAME, | |
) | |
# Lazily-initialised, module-wide database state. setup_db() fills in the
# engine, and get_session() fills in the session factory and cached session.
__session_factory = __session = __engine = None
# Recorded by setup_db(is_test=...); lets _drop_tables/_clear_tables run
# without force=True.
__is_test = False
def setup_db(*, is_test=False, **db_config):
    """Create the module-wide engine, database, tables and session once.

    Later calls return immediately once the engine exists.

    Keyword Args:
        is_test: Stored module-wide; when True, ``_drop_tables`` and
            ``_clear_tables`` may run without ``force=True``.
        connection_kwargs: Optional dict of extra keyword arguments for
            ``create_engine``. ``poolclass`` is always forced to ``NullPool``.
        session_kwargs: Optional dict of keyword arguments forwarded to
            ``get_session`` (and from there to ``sessionmaker``).

    Side effects: creates the database if it does not exist, creates all
    tables on the models' metadata, and opens the shared session.
    """
    global __engine
    global __is_test

    if __engine:
        return

    __is_test = is_test
    connection_string = get_connection_string()

    # Copy the caller's dicts so that forcing ``poolclass`` below never
    # mutates an object the caller passed in (and may reuse elsewhere).
    connection_kwargs = dict(db_config.get('connection_kwargs') or {})
    # we always want to close connections
    connection_kwargs['poolclass'] = NullPool
    session_kwargs = dict(db_config.get('session_kwargs') or {})

    __engine = create_engine(connection_string, **connection_kwargs)

    if not database_exists(__engine.url):  # pragma: no cover
        # repr() of a SQLAlchemy URL masks the password, so credentials are
        # not echoed to stdout (str() includes it on SQLAlchemy < 2.0).
        print('Creating database: %r' % (__engine.url, ))
        create_database(__engine.url)

    create_tables()
    get_session(**session_kwargs)
def get_connection_string(**kwargs):
    """Return a PostgreSQL connection string for sqlalchemy::

        dialect+driver://username:password@host:port/database

    Values come from the ``CLIENT_DB_*`` constants. The username and password
    are percent-encoded so that characters such as ``@``, ``:``, ``/`` or
    ``%`` inside them cannot corrupt the URL. Plain alphanumeric credentials
    are unchanged by the encoding.

    ``kwargs`` is accepted for backward compatibility but is currently ignored.
    """
    from urllib.parse import quote

    def _escape(value):
        # safe='' also encodes '/', and quote (unlike quote_plus) encodes a
        # space as %20, which SQLAlchemy's URL parser decodes back correctly.
        return quote(str(value), safe='')

    return 'postgresql://%s:%s@%s:%s/%s' % (
        _escape(CLIENT_DB_USERNAME),
        _escape(CLIENT_DB_PASSWORD),
        CLIENT_DB_HOST,
        CLIENT_DB_PORT,
        CLIENT_DB_NAME,
    )
def close_db():  # pragma: no cover
    """Commit and close the shared session, if one has been created.

    Best effort: if the commit fails, the session is rolled back and the error
    is logged instead of raised. The session is always closed.
    """
    if not __session:
        return
    try:
        __session.commit()
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit
        # still propagate, and logged so commit failures are no longer silent.
        __session.rollback()
        get_logger().exception('close_db: commit failed, session rolled back')
    finally:
        __session.close()
def commit_session(_raise=True):  # pragma: no cover
    """Commit the shared session, rolling it back if the commit fails.

    Does nothing if no session has been created yet.

    Args:
        _raise: When True (the default), the commit error is re-raised after
            the rollback. When False, it is swallowed.
    """
    if not __session:
        return
    try:
        __session.commit()
    except Exception:
        __session.rollback()
        if _raise:
            raise
def _get_metadata():
    """Return the MetaData object attached to the declarative ``Base``.

    The mixins module is imported lazily, inside the function, because it
    imports commit_session/get_session from this package. A top-level import
    here would therefore be circular.
    """
    from . import mixins
    return mixins.Base.metadata
def create_tables():
    """Run ``create_all`` for the models' metadata against the module engine.

    ``setup_db`` must already have created the engine.
    """
    engine = __engine
    assert engine
    _get_metadata().create_all(engine)
def _drop_tables(*, force=False):
    """Drop every table on the models' metadata.

    This is a no-op unless ``setup_db`` ran with ``is_test=True`` or the
    caller passes ``force=True``. The guard protects against dropping data
    outside of tests.
    """
    permitted = __is_test or force
    if permitted:
        assert __engine
        _get_metadata().drop_all(__engine)
def _clear_tables(*, force=False):
    """Delete all rows from every table on the models' metadata, keeping the tables.

    This is a no-op unless ``setup_db`` ran with ``is_test=True`` or the
    caller passes ``force=True``.

    Tables are emptied in reverse ``MetaData.sorted_tables`` order, so tables
    holding foreign keys are cleared before the tables they reference. A
    DELETE that fails is skipped. Each DELETE runs inside its own SAVEPOINT,
    so a failure rolls back only that statement. Previously a single failure
    aborted the whole PostgreSQL transaction, every later DELETE then failed
    silently, and nothing was cleared.
    """
    if not __is_test and not force:
        return
    assert __engine
    meta = _get_metadata()
    with contextlib.closing(__engine.connect()) as con:
        # Commits on normal exit; rolls back if anything propagates.
        with con.begin():
            for table in reversed(meta.sorted_tables):
                savepoint = con.begin_nested()
                try:
                    con.execute(table.delete())
                except Exception:
                    # Best effort: undo just this DELETE and keep going.
                    savepoint.rollback()
                else:
                    savepoint.commit()
def get_session(**kwargs):
    """Return the shared module-level session, creating it on first use.

    Calls ``setup_db`` first, so the engine, database and tables are
    initialised lazily if nobody has done so yet.

    Args:
        **kwargs: Keyword arguments for ``sessionmaker``. They are only used
            when the session factory is first created. Later calls return the
            cached session and ignore them.
    """
    global __session
    global __session_factory

    # Forward kwargs. If this call is what triggers setup_db(), setup_db
    # re-enters get_session() to build the first session. Without forwarding,
    # that factory would be created with default options and these kwargs
    # would be silently dropped.
    setup_db(session_kwargs=kwargs)
    assert __engine

    if __session is not None:
        return __session

    if __session_factory is None:  # pragma: no cover
        __session_factory = sessionmaker(bind=__engine, **kwargs)
    __session = __session_factory()
    return __session
def session_getter(func):
    """Decorator to get a session and inject it as the first argument in a function.

    The call runs inside ``dbtransaction()``. The session is committed when
    ``func`` returns, and rolled back (then the error re-raised) if it raises.
    The wrapper keeps ``func``'s name, docstring and ``__wrapped__``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with dbtransaction() as session:
            return func(session, *args, **kwargs)
    return wrapper
def session_committer(func):
    """Decorator that calls ``commit_session()`` after ``func`` finishes.

    The commit runs in a ``finally`` block, so it happens whether ``func``
    returns or raises. If the commit itself fails, ``commit_session`` rolls
    back and re-raises, and that error propagates instead of ``func``'s
    result. The wrapper keeps ``func``'s name, docstring and ``__wrapped__``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            log = get_logger()
            log.info('session_committer')
            commit_session()
    return wrapper
@contextmanager
def dbtransaction():
    """Yield the shared session and commit it when the ``with`` block exits cleanly.

    Any exception, from the block or from the commit itself, triggers a
    rollback and is then re-raised unchanged.
    """
    sess = get_session()
    try:
        yield sess
        sess.commit()
    except BaseException:
        # Explicit form of the original bare ``except:``; catches everything.
        sess.rollback()
        raise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
from sqlalchemy.ext.declarative import ( | |
declarative_base, | |
declared_attr, | |
) | |
from sqlalchemy import ( | |
MetaData, | |
Index, | |
) | |
from . import ( | |
commit_session, | |
get_session, | |
) | |
# Naming templates for indexes (ix), unique constraints (uq), foreign keys (fk)
# and primary keys (pk). They are passed to MetaData(naming_convention=...)
# and reused by ModelMixin's index helpers.
INDEX_CONVENTION = dict(
    ix="ix_%(table_name)s_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)
# Shared MetaData configured with INDEX_CONVENTION, so indexes and constraints
# without an explicit name are named from those templates.
METADATA = MetaData(
    naming_convention=INDEX_CONVENTION,
)
# Declarative base for the ORM models; its tables are registered on METADATA.
Base = declarative_base(
    metadata=METADATA,
)
def class_name_to_underscores(name):
    """Convert a CamelCase class name to snake_case.

    Examples: ``UserProfile`` -> ``user_profile``,
    ``HTTPResponse`` -> ``http_response``.
    """
    # Pass 1: split before each Capitalised word (e.g. "HTTPResponse" -> "HTTP_Response").
    split_words = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    # Pass 2: split lower/digit-to-upper boundaries (e.g. "userID" -> "user_ID").
    split_all = re.sub('([a-z0-9])([A-Z])', r'\1_\2', split_words)
    return split_all.lower()
class ModelMixin:
    """Mixin to give models some helpful functionality.

    It provides:

    * an automatic ``__tablename__``;
    * a ``save`` helper that uses the shared session;
    * helpers that build ``Index`` objects named according to
      ``INDEX_CONVENTION``.
    """

    @declared_attr
    def __tablename__(cls):
        """Automatically create table names using a convention from the name of the model class.

        The class name is converted to snake_case and then naively
        pluralised:

        * a trailing ``y`` becomes ``ies`` (``Category`` -> ``categories``);
        * otherwise ``s`` is appended unless the name already ends in ``s``
          (``UserProfile`` -> ``user_profiles``).
        """
        name = class_name_to_underscores(cls.__name__)
        if name.endswith('y'):
            # Replace only the final "y". rstrip('y') would strip every
            # trailing "y", so a name ending in "yy" lost more than one letter.
            # NOTE(review): vowel+y names still become "...eies"
            # (Survey -> surveies); left as-is because changing it would
            # rename existing tables.
            name = name[:-1] + 'ies'
        elif not name.endswith('s'):
            name = name + 's'
        return name

    def save(self, *, commit=True):
        """Add this instance to the shared session and, by default, commit.

        Args:
            commit: When True, call ``commit_session()``. A failed commit is
                rolled back and re-raised. When False, only add the instance
                to the session.
        """
        session = get_session()
        session.add(self)
        if commit:
            commit_session()

    @classmethod
    def flush(cls):  # pragma: no cover
        """Flush pending changes in the shared session without committing."""
        get_session().flush()

    @classmethod
    def _get_index(cls, name, *columns, **kwargs):  # pragma: no cover
        """Build an ``Index`` called ``name`` over the named column attributes of ``cls``."""
        cols = [getattr(cls, c) for c in columns]
        return Index(name, *cols, **kwargs)

    @classmethod
    def generate_index(cls, *columns, **kwargs):  # pragma: no cover
        """Build an index named with the ``ix`` template of ``INDEX_CONVENTION``.

        The ``column_0_label`` slot is filled with the column names joined
        by underscores.
        """
        name_mapping = {
            'table_name': cls.__tablename__,
            'column_0_label': '_'.join(columns),
        }
        name = INDEX_CONVENTION['ix'] % name_mapping
        return cls._get_index(name, *columns, **kwargs)

    @classmethod
    def generate_unique_index(cls, *columns, **kwargs):  # pragma: no cover
        """Build a unique index named with the ``uq`` template of ``INDEX_CONVENTION``.

        ``unique`` is always forced to True.
        """
        name_mapping = {
            'table_name': cls.__tablename__,
            'column_0_name': '_'.join(columns),
        }
        name = INDEX_CONVENTION['uq'] % name_mapping
        kwargs['unique'] = True
        return cls._get_index(name, *columns, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment