SQLAlchemy helpers and mixins
import contextlib | |
from contextlib import contextmanager | |
from sqlalchemy import create_engine | |
from sqlalchemy.orm import sessionmaker | |
from sqlalchemy.pool import NullPool | |
from sqlalchemy_utils import database_exists, create_database | |
from common.logging import get_logger | |
from ..constants import ( | |
CLIENT_DB_HOST, | |
CLIENT_DB_NAME, | |
CLIENT_DB_PASSWORD, | |
CLIENT_DB_PORT, | |
CLIENT_DB_USERNAME, | |
) | |
# Module-wide singletons; they stay None until setup_db() / get_session()
# populate them.
__session_factory = __session = __engine = None
# Set by setup_db(); _drop_tables() and _clear_tables() refuse to run
# unless this flag is true or the caller passes force=True.
__is_test = False
def setup_db(*, is_test=False, **db_config):
    """Create the module-wide engine, database, tables and session once.

    The call is a no-op as soon as an engine already exists.

    Args:
        is_test: Stored in the module-level ``__is_test`` flag, which
            ``_drop_tables`` and ``_clear_tables`` check before doing
            destructive work.
        **db_config: May contain ``connection_kwargs`` (forwarded to
            ``create_engine``) and ``session_kwargs`` (forwarded to
            ``get_session``). Missing or ``None`` values are treated as
            empty dictionaries.
    """
    global __engine
    global __is_test

    if __engine:
        return

    __is_test = is_test
    connection_string = get_connection_string()

    # Work on copies so the caller's dictionaries are never mutated by the
    # poolclass override below.
    connection_kwargs = dict(db_config.get('connection_kwargs') or {})
    # we always want to close connections
    connection_kwargs['poolclass'] = NullPool
    session_kwargs = dict(db_config.get('session_kwargs') or {})

    __engine = create_engine(connection_string, **connection_kwargs)
    if not database_exists(__engine.url):  # pragma: no cover
        print('Creating database: %s' % (__engine.url, ))
        create_database(__engine.url)

    create_tables()
    get_session(**session_kwargs)
def get_connection_string(**kwargs):
    """Return a connection string for sqlalchemy::

        dialect+driver://username:password@host:port/database

    The username and password are percent-encoded so that reserved
    characters in the credentials (``@``, ``:``, ``/`` ...) cannot be
    mistaken for URL delimiters.

    ``**kwargs`` is accepted for call compatibility but is not used.
    """
    from urllib.parse import quote

    # NOTE(review): assumes the CLIENT_DB_* constants hold raw, not already
    # percent-encoded, credentials -- confirm against the configuration.
    return 'postgresql://%s:%s@%s:%s/%s' % (
        quote(str(CLIENT_DB_USERNAME), safe=''),
        quote(str(CLIENT_DB_PASSWORD), safe=''),
        CLIENT_DB_HOST,
        CLIENT_DB_PORT,
        CLIENT_DB_NAME,
    )
def close_db():  # pragma: no cover
    """Commit pending work on the shared session, then close it.

    Does nothing when no session has been created. A failing commit is
    rolled back instead of being raised; the session is closed either way.
    """
    if not __session:
        return
    try:
        __session.commit()
    except Exception:
        # Best-effort teardown: discard what could not be committed, but let
        # KeyboardInterrupt / SystemExit propagate (after the close below).
        __session.rollback()
    finally:
        __session.close()
def commit_session(_raise=True):  # pragma: no cover
    """Commit the shared session, rolling back if the commit fails.

    Does nothing when no session has been created.

    Args:
        _raise: When true (the default) the commit error is re-raised after
            the rollback; when false it is swallowed.
    """
    if not __session:
        return
    try:
        __session.commit()
    except Exception:
        __session.rollback()
        if _raise:
            raise
def _get_metadata():
    """Return the ``metadata`` object of the declarative ``Base``."""
    # Imported at call time rather than at module import time; presumably
    # this avoids a circular import with the mixins module.
    from .mixins import Base

    metadata = Base.metadata
    return metadata
def create_tables():
    """Issue ``create_all`` for the model metadata on the module engine."""
    assert __engine
    _get_metadata().create_all(__engine)
def _drop_tables(*, force=False):
    """Drop every model table; only runs in test mode or with ``force=True``."""
    if __is_test or force:
        assert __engine
        _get_metadata().drop_all(__engine)
def _clear_tables(*, force=False):
    """Delete all rows from every model table, keeping the tables themselves.

    Only runs in test mode or when ``force=True`` is passed. Tables are
    visited in reverse ``sorted_tables`` order. Clearing is best-effort: a
    table whose DELETE fails is logged and skipped.

    Args:
        force: Run even when the module is not in test mode.
    """
    if not __is_test and not force:
        return
    assert __engine
    meta = _get_metadata()
    with contextlib.closing(__engine.connect()) as con:
        # Commits when the block completes, rolls back if it is left by an
        # exception.
        with con.begin():
            for table in reversed(meta.sorted_tables):
                try:
                    # One SAVEPOINT per table: on PostgreSQL a failed
                    # statement aborts the enclosing transaction, which
                    # would otherwise make every later DELETE fail as well.
                    with con.begin_nested():
                        con.execute(table.delete())
                except Exception as exc:
                    get_logger().info(
                        '_clear_tables: skipped %s (%s)' % (table.name, exc)
                    )
def get_session(**kwargs):
    """Return the shared session, creating it on first use.

    ``setup_db()`` is invoked first so an engine exists. ``kwargs`` are only
    passed to ``sessionmaker`` when the session factory is first built.
    """
    global __session
    global __session_factory

    setup_db()
    assert __engine

    if __session is None:
        if __session_factory is None:  # pragma: no cover
            __session_factory = sessionmaker(bind=__engine, **kwargs)
        __session = __session_factory()
    return __session
def session_getter(func):
    """Decorator to get a session and inject it as the first argument in a function

    The wrapped call runs inside ``dbtransaction()``, so the session is
    committed when ``func`` returns and rolled back when it raises.
    """
    from functools import wraps

    # wraps() keeps func's __name__, __doc__ and friends on the wrapper.
    @wraps(func)
    def wrapper(*args, **kwargs):
        with dbtransaction() as session:
            return func(session, *args, **kwargs)
    return wrapper
def session_committer(func):
    """Decorator that calls ``commit_session()`` after ``func`` finishes.

    The commit happens in a ``finally`` clause, so it runs whether ``func``
    returned or raised.
    """
    from functools import wraps

    # wraps() keeps func's __name__, __doc__ and friends on the wrapper.
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            log = get_logger()
            log.info('session_committer')
            commit_session()
    return wrapper
@contextmanager
def dbtransaction():
    """Context manager yielding the shared session.

    Commits when the ``with`` body finishes normally; on any exception
    (including one raised by the commit itself) rolls back and re-raises.
    """
    active_session = get_session()
    try:
        yield active_session
        active_session.commit()
    except BaseException:
        active_session.rollback()
        raise
import re | |
from sqlalchemy.ext.declarative import ( | |
declarative_base, | |
declared_attr, | |
) | |
from sqlalchemy import ( | |
MetaData, | |
Index, | |
) | |
from . import ( | |
commit_session, | |
get_session, | |
) | |
# Name templates for indexes (ix), unique constraints (uq), foreign keys (fk)
# and primary keys (pk); handed to MetaData as its naming convention.
INDEX_CONVENTION = dict(
    ix="ix_%(table_name)s_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)

# Shared metadata and the declarative base every model derives from.
METADATA = MetaData(naming_convention=INDEX_CONVENTION)
Base = declarative_base(metadata=METADATA)
def class_name_to_underscores(name):
    """Convert a CamelCase class name to lower-case snake_case.

    For example ``HTTPResponse`` becomes ``http_response``.
    """
    snake = name
    # First pattern: split before a capitalised word (e.g. "PResponse").
    # Second pattern: split between a lower-case letter or digit and a capital.
    for pattern in ('(.)([A-Z][a-z]+)', '([a-z0-9])([A-Z])'):
        snake = re.sub(pattern, r'\1_\2', snake)
    return snake.lower()
class ModelMixin:
    """Mixin to give models some helpful functionality"""

    @declared_attr
    def __tablename__(cls):
        """Automatically create table names using a convention from the name of the model class

        The class name is converted to snake_case and pluralised: a trailing
        ``y`` becomes ``ies``, names already ending in ``s`` are left alone,
        and anything else gets an ``s`` appended.
        """
        name = class_name_to_underscores(cls.__name__)
        if name.endswith('y'):
            # Replace exactly one trailing "y"; rstrip('y') would strip a
            # whole run of them.
            # NOTE(review): vowel + "y" names (e.g. "key") also become
            # "...ies"; left unchanged so existing table names stay stable.
            name = name[:-1] + 'ies'
        elif not name.endswith('s'):
            name = name + 's'
        return name

    def save(self, *, commit=True):
        """Add in a standard way of saving a model instance

        Args:
            commit: When true (the default) the shared session is committed
                right after the instance is added to it.
        """
        session = get_session()
        session.add(self)

        if commit:
            commit_session()

    @classmethod
    def flush(cls):  # pragma: no cover
        """Flush the shared session."""
        get_session().flush()

    @classmethod
    def _get_index(cls, name, *columns, **kwargs):  # pragma: no cover
        """Build an ``Index`` called ``name`` over the named column attributes.

        Args:
            name: Name of the index.
            *columns: Attribute names on the model class.
            **kwargs: Passed through to ``Index``.
        """
        cols = [getattr(cls, c) for c in columns]
        return Index(name, *cols, **kwargs)

    @classmethod
    def generate_index(cls, *columns, **kwargs):  # pragma: no cover
        """Build an index named from the ``ix`` template in ``INDEX_CONVENTION``.

        The column names are joined with underscores to form the label part
        of the index name.
        """
        name_mapping = {
            'table_name': cls.__tablename__,
            'column_0_label': '_'.join(columns),
        }
        name = INDEX_CONVENTION['ix'] % name_mapping
        return cls._get_index(name, *columns, **kwargs)

    @classmethod
    def generate_unique_index(cls, *columns, **kwargs):  # pragma: no cover
        """Build a unique index named from the ``uq`` template in ``INDEX_CONVENTION``.

        ``unique`` is always forced to true, whatever the caller passed.
        """
        name_mapping = {
            'table_name': cls.__tablename__,
            'column_0_name': '_'.join(columns),
        }
        name = INDEX_CONVENTION['uq'] % name_mapping
        kwargs['unique'] = True
        return cls._get_index(name, *columns, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment