A set of helpers to work with SQLAlchemy setup, connections, transactions, etc.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib
import functools
from contextlib import contextmanager
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import (
    NullPool,
    QueuePool,
    StaticPool,
)
from sqlalchemy_utils import database_exists, create_database

from ..constants import (
    DB_ENGINE,
    DB_HOST,
    DB_NAME,
    DB_PASSWORD,
    DB_PORT,
    DB_USERNAME,
    SQLITE,
    POSTGRESQL,
)
# Process-wide singletons; all start unset and are filled in lazily by
# setup_db() (engine) and get_session() (session factory and session).
__session_factory = __session = __engine = None
# Set by setup_db(is_test=...); switches PostgreSQL to a ``test_`` database
# name and allows _drop_tables()/_clear_tables() to run without ``force``.
__is_test = False
def _get_metadata():
    """Return the ``metadata`` attribute of the declarative ``Base`` in ``.mixins``.

    The import happens at call time rather than at module load.
    NOTE(review): presumably to avoid a circular import with ``.mixins`` — confirm.
    """
    from .mixins import Base
    metadata = Base.metadata
    return metadata
def setup_db(*, is_test=False, **db_config):
    """Main setup function for our database.

    Performs the initial DB connection, creates the database if it does not
    exist, creates tables as needed and initialises the global session. Once an
    engine exists, further calls return immediately.

    is_test: when true, PostgreSQL connections use a ``test_``-prefixed
        database name, and _drop_tables()/_clear_tables() run without ``force``.
    db_config: optional ``connection_kwargs`` (forwarded to ``create_engine``)
        and ``session_kwargs`` (forwarded to ``get_session``). The caller's
        dicts are not modified.
    """
    global __engine
    global __is_test
    if __engine:
        return
    __is_test = is_test
    connection_string = get_connection_string()
    # Build a new dict instead of updating the caller's one in place.
    # In a serverless environment use a static pool, which is a single-connection
    # pool per Lambda; it always overrides any caller-supplied ``poolclass``.
    connection_kwargs = {
        **(db_config.get('connection_kwargs') or {}),
        'poolclass': StaticPool,
    }
    session_kwargs = db_config.get('session_kwargs') or {}
    __engine = create_engine(connection_string, **connection_kwargs)
    print('Connected to: %s' % (__engine.url, ))
    try:
        if not database_exists(__engine.url):  # pragma: no cover
            print('Creating database: %s' % (__engine.url, ))
            create_database(__engine.url)
        create_tables()
    except BaseException:
        # Forget the half-initialised engine. Otherwise every later call would
        # return early and leave callers with a missing database or tables.
        __engine = None
        raise
    get_session(**session_kwargs)
def get_connection_string(**kwargs):
    """Return a connection string for sqlalchemy::

        dialect+driver://username:password@host:port/database

    SQLite: the optional ``filename`` keyword selects the database file. When
    it is missing or empty, the URL is ``sqlite://`` (an in-memory database).
    A filename without a leading ``/`` (e.g. ``'app.db'``) gets one, which
    gives ``'sqlite:///app.db'``.

    PostgreSQL: settings come from ``..constants``. In test mode
    (``setup_db(is_test=True)``) the database name gets a ``test_`` prefix.
    The prefix is applied per call and does not modify the module-level
    DB_NAME. Username and password are percent-encoded, so characters such as
    ``@``, ``:`` or ``/`` cannot corrupt the URL.

    Raises ValueError if DB_ENGINE is neither SQLITE nor POSTGRESQL.
    """
    if DB_ENGINE not in (SQLITE, POSTGRESQL):
        raise ValueError(
            'Invalid database engine specified: %s. Only sqlite'
            ' and postgresql are supported' % (DB_ENGINE, ))
    if DB_ENGINE == SQLITE:
        filename = str(kwargs.get('filename') or '')
        # The file path must follow the third '/'. Without the leading slash,
        # 'sqlite://app.db' is parsed as a host name, not as a file path.
        if filename and not filename.startswith('/'):
            filename = '/' + filename
        return 'sqlite://%s' % (filename, )
    db_name = DB_NAME
    if __is_test and not db_name.startswith('test_'):
        db_name = 'test_%s' % (db_name, )
    return 'postgresql://%s:%s@%s:%s/%s' % (
        quote(str(DB_USERNAME), safe=''),
        quote(str(DB_PASSWORD), safe=''),
        DB_HOST,
        DB_PORT,
        db_name,
    )
def close_db():  # pragma: no cover
    """Commit and close the global session, if one has been created.

    If the commit raises an ``Exception``, the session is rolled back and the
    error is suppressed (best-effort shutdown). The session is closed in every
    case. Non-``Exception`` errors raised during the commit, such as
    ``KeyboardInterrupt`` or ``SystemExit``, propagate after the close.
    """
    if __session is None:
        return
    try:
        __session.commit()
    except Exception:
        __session.rollback()
    finally:
        __session.close()
def commit_session(_raise=False):  # pragma: no cover
    """Commit the global session, if one exists; roll back if the commit fails.

    A failed commit (any ``Exception``) is always rolled back. It is then
    re-raised only when ``_raise`` is true; otherwise it is silently dropped.
    """
    current = __session
    if current:
        try:
            current.commit()
        except Exception:
            current.rollback()
            if _raise:
                raise
def create_tables():
    """Run ``create_all`` on the models' metadata against the global engine.

    The engine must already exist; this is checked with an ``assert``.
    """
    assert __engine
    _get_metadata().create_all(__engine)
def get_session(**kwargs):
    """Main API for connecting to the DB via the SQLAlchemy session.

    Clients should use this for any DB interactions, as it will connect and set
    up the database as needed. After initialization the global session is
    returned, so this is safe to call multiple times in a single thread.

    kwargs: forwarded to ``sessionmaker`` when the session factory is first
        built, including when this call is the one that triggers
        ``setup_db``. Once a session exists they are ignored.
    """
    global __session
    global __session_factory
    # Pass the kwargs through setup_db. On the very first call setup_db builds
    # the session itself (via get_session), and without this the caller's
    # kwargs would never reach sessionmaker.
    setup_db(session_kwargs=kwargs)
    assert __engine
    if __session is not None:
        return __session
    if __session_factory is None:  # pragma: no cover
        __session_factory = sessionmaker(bind=__engine, **kwargs)
    __session = __session_factory()
    return __session
def session_committer(func):
    """Decorator to commit the DB session.

    Use this on high-level functions such as handlers. ``commit_session()`` is
    then called after ``func`` whether it returns or raises. A failing commit
    is rolled back and suppressed by ``commit_session()``. Exceptions raised by
    ``func`` itself still propagate. The wrapper keeps ``func``'s name and
    docstring.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            # NOTE(review): this also commits when func raised, so partial
            # work is persisted. Confirm this is intended rather than a rollback.
            commit_session()
    return wrapper
def session_getter(func):
    """Decorator to get a session and inject it as the first argument in a function.

    The session comes from ``get_session()``, which sets up the database on
    first use. The wrapper keeps ``func``'s name and docstring.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        session = get_session()
        return func(session, *args, **kwargs)
    return wrapper
@contextmanager
def dbtransaction():  # pragma: no cover
    """Context manager that yields the global session and commits it afterwards.

    If the ``with`` body or the commit raises anything (including
    ``BaseException`` subclasses), the session is rolled back and the
    exception is re-raised.
    """
    txn_session = get_session()
    try:
        yield txn_session
        txn_session.commit()
    except BaseException:
        txn_session.rollback()
        raise
def _drop_tables(*, force=False):
    """Drop every table on the models' metadata using the global engine.

    Does nothing unless setup_db() ran with ``is_test=True`` or ``force`` is
    given. Presumably this guards a real database against accidental drops.
    """
    if __is_test or force:
        assert __engine
        _get_metadata().drop_all(__engine)
def _clear_tables(*, force=False):
    """Delete all rows from every table on the models' metadata, in one transaction.

    Does nothing unless setup_db() ran with ``is_test=True`` or ``force`` is
    given. A delete that raises an ``Exception`` is skipped (best effort).
    ``KeyboardInterrupt``, ``SystemExit`` and other non-``Exception`` errors
    propagate, and the connection is still closed.
    """
    if not __is_test and not force:
        return
    assert __engine
    meta = _get_metadata()
    with contextlib.closing(__engine.connect()) as con:
        trans = con.begin()
        # sorted_tables lists parents before children, so iterate in reverse
        # and delete dependent rows first.
        for table in reversed(meta.sorted_tables):
            # NOTE(review): on PostgreSQL a failed statement aborts the whole
            # transaction, so later deletes and the commit will likely have no
            # effect. Consider a SAVEPOINT per table if failures are expected.
            with contextlib.suppress(Exception):
                con.execute(table.delete())
        trans.commit()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment