Skip to content

Instantly share code, notes, and snippets.

@brianz
Created August 31, 2018 22:33
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brianz/2d7001bf1b0bafa9379303aa2da4cdeb to your computer and use it in GitHub Desktop.
Save brianz/2d7001bf1b0bafa9379303aa2da4cdeb to your computer and use it in GitHub Desktop.
A set of helpers to work with SQLAlchemy setup, connections, transactions, etc.
import contextlib
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import (
NullPool,
QueuePool,
StaticPool,
)
from sqlalchemy_utils import database_exists, create_database
from ..constants import (
DB_ENGINE,
DB_HOST,
DB_NAME,
DB_PASSWORD,
DB_PORT,
DB_USERNAME,
SQLITE,
POSTGRESQL,
)
# Lazily-initialised module state shared by the helpers below.
# setup_db() creates the engine and get_session() the factory/session; both
# reuse the cached objects on subsequent calls.
__session_factory = __session = __engine = None
# Set from setup_db(is_test=...). It enables the test_ database-name prefix in
# get_connection_string and the _drop_tables/_clear_tables helpers.
__is_test = False
def _get_metadata():
    """Return the ``metadata`` attribute of the declarative ``Base`` in ``.mixins``.

    The create/drop/clear helpers in this module operate on this metadata.
    """
    # Imported lazily inside the function. This is presumably done to avoid
    # an import cycle with .mixins (confirm).
    from .mixins import Base as model_base
    metadata = model_base.metadata
    return metadata
def setup_db(*, is_test=False, **db_config):
    """Main setup function for our database.

    This performs the initial DB connection (once) and also creates the
    database and tables as needed, then initialises the module-level session.
    Later calls are no-ops because the engine is cached.

    Args:
        is_test: For PostgreSQL, target a ``test_``-prefixed database name.
            Also enables the ``_drop_tables``/``_clear_tables`` helpers.
        **db_config: Optional settings:
            ``connection_kwargs`` -- extra keyword arguments for
            ``create_engine``; ``poolclass`` is always forced to
            ``StaticPool``.
            ``session_kwargs`` -- keyword arguments for ``sessionmaker``.
    """
    global __engine
    global __is_test

    if __engine:
        return

    __is_test = is_test
    connection_string = get_connection_string()

    # Copy rather than update in place so the caller's dict is not mutated.
    # In a serverless environment use a static pool, which is a
    # single-connection pool per Lambda.
    connection_kwargs = dict(db_config.get('connection_kwargs') or {})
    connection_kwargs['poolclass'] = StaticPool
    session_kwargs = db_config.get('session_kwargs') or {}

    __engine = create_engine(connection_string, **connection_kwargs)
    # repr() of a SQLAlchemy URL masks the password. str() does not on
    # SQLAlchemy < 2.0, which would leak credentials into the logs.
    print('Connected to: %s' % (repr(__engine.url), ))
    if not database_exists(__engine.url):  # pragma: no cover
        print('Creating database: %s' % (repr(__engine.url), ))
        create_database(__engine.url)
    create_tables()
    get_session(**session_kwargs)
def get_connection_string(**kwargs):
    """Return a connection string for sqlalchemy::

        dialect+driver://username:password@host:port/database

    The dialect is chosen by ``DB_ENGINE``:

    * SQLite: ``kwargs['filename']`` names the database file. A missing or
      empty filename gives an in-memory database (``sqlite://``).
    * PostgreSQL: built from the ``DB_*`` constants. When ``setup_db`` ran
      with ``is_test=True``, the database name is prefixed with ``test_``.
      This rebinds the module-level ``DB_NAME``.

    Raises:
        ValueError: if ``DB_ENGINE`` is neither SQLite nor PostgreSQL.
    """
    global DB_NAME
    from urllib.parse import quote

    if DB_ENGINE not in (SQLITE, POSTGRESQL):
        raise ValueError(
            'Invalid database engine specified: %s. Only sqlite'
            ' and postgresql are supported' % (DB_ENGINE, ))

    if DB_ENGINE == SQLITE:
        filename = kwargs.get('filename', '')
        if not filename:
            # A missing filename creates an in-memory db.
            return 'sqlite://'
        # SQLite file URLs need three slashes (four for an absolute path).
        # 'sqlite://<name>' would be parsed as a host and rejected.
        return 'sqlite:///%s' % (filename, )

    if __is_test and not DB_NAME.startswith('test_'):
        DB_NAME = 'test_%s' % (DB_NAME, )

    # Percent-encode the credentials so that characters such as '@', ':' or
    # '/' in a password cannot corrupt the URL.
    # NOTE(review): values that were already URL-encoded will now be
    # double-encoded.
    return 'postgresql://%s:%s@%s:%s/%s' % (
        quote(str(DB_USERNAME), safe=''),
        quote(str(DB_PASSWORD), safe=''),
        DB_HOST,
        DB_PORT,
        DB_NAME,
    )
def close_db():  # pragma: no cover
    """Commit and close the module-level session, if one exists.

    A failed commit is rolled back and reported rather than propagated. The
    session is closed in every case.
    """
    if __session is None:
        return
    try:
        __session.commit()
    except Exception as e:
        # Best-effort shutdown: undo the failed transaction instead of
        # raising. Catching Exception rather than using a bare except lets
        # KeyboardInterrupt and SystemExit propagate.
        __session.rollback()
        print('Failed to commit session on close, rolled back: %r' % (e, ))
    finally:
        __session.close()
def commit_session(_raise=False):  # pragma: no cover
    """Commit the module-level session, rolling back on failure.

    Does nothing if no session has been created yet.

    Args:
        _raise: If true, re-raise the commit error after rolling back. If
            false (the default), the error is reported and swallowed.
    """
    if __session is None:
        return
    try:
        __session.commit()
    except Exception as e:
        __session.rollback()
        if _raise:
            raise
        # Don't lose the failure silently: the rollback above discarded the
        # transaction's changes.
        print('Failed to commit session, rolled back: %r' % (e, ))
def create_tables():
    """Run ``create_all`` for the model metadata against the module engine.

    The engine must already exist (``setup_db`` creates it). By default,
    SQLAlchemy's ``create_all`` skips tables that already exist.
    """
    assert __engine
    _get_metadata().create_all(__engine)
def get_session(**kwargs):
    """Main API for connection to the DB via the SQLAlchemy session.

    Clients should use this for any DB interactions, as it will connect and
    set up the database as needed. After initialization the global session
    is returned, so this is safe to call multiple times in a single thread.

    Args:
        **kwargs: Keyword arguments for ``sessionmaker``. They only take
            effect when the session factory is first created; later calls
            ignore them.
    """
    global __session
    global __session_factory

    # Forward kwargs through setup_db. If this call is what triggers setup,
    # setup_db creates the session itself, and the kwargs would otherwise be
    # silently dropped. When the engine already exists this is a no-op.
    setup_db(session_kwargs=kwargs)
    assert __engine

    if __session is not None:
        return __session
    if __session_factory is None:  # pragma: no cover
        __session_factory = sessionmaker(bind=__engine, **kwargs)
    __session = __session_factory()
    return __session
def session_committer(func):
    """Decorator to commit the DB session.

    Use this on high-level functions such as handlers. After ``func`` returns
    or raises, ``commit_session()`` is always called; it commits, or rolls
    back on failure. It is called with ``_raise=False``, so commit errors do
    not propagate to the caller.
    """
    from functools import wraps

    # wraps() keeps func's __name__/__doc__ for logging and introspection.
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            commit_session()
    return wrapper
def session_getter(func):
    """Decorator to get a session and inject it as the first argument in a function.

    The decorated function is called as ``func(session, *args, **kwargs)``,
    where ``session`` comes from ``get_session()``.
    """
    from functools import wraps

    # wraps() keeps func's __name__/__doc__ for logging and introspection.
    @wraps(func)
    def wrapper(*args, **kwargs):
        session = get_session()
        return func(session, *args, **kwargs)
    return wrapper
@contextmanager
def dbtransaction():  # pragma: no cover
    """Context manager that yields the shared session inside a transaction.

    The session is committed when the ``with`` body finishes without error.
    If the body (or the commit itself) raises, the session is rolled back
    and the exception is re-raised.
    """
    db_session = get_session()
    try:
        yield db_session
        db_session.commit()
    except BaseException:
        db_session.rollback()
        raise
def _drop_tables(*, force=False):
    """Drop every table in the model metadata (test helper).

    Does nothing unless the DB was set up with ``is_test=True`` or ``force``
    is true. This guards against dropping tables in a non-test database.
    """
    if __is_test or force:
        assert __engine
        _get_metadata().drop_all(__engine)
def _clear_tables(*, force=False):
    """Delete all rows from every table in the model metadata (test helper).

    Does nothing unless the DB was set up with ``is_test=True`` or ``force``
    is true. Tables are emptied in reverse ``sorted_tables`` order (dependent
    tables before the tables they reference) inside one transaction. A
    failing DELETE is reported and skipped rather than aborting the clear.
    """
    if not __is_test and not force:
        return
    assert __engine
    meta = _get_metadata()
    with contextlib.closing(__engine.connect()) as con:
        trans = con.begin()
        for table in reversed(meta.sorted_tables):
            try:
                con.execute(table.delete())
            except Exception as e:
                # Best-effort: keep going with the remaining tables. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                # NOTE(review): on PostgreSQL a failed statement presumably
                # aborts the whole transaction, so later DELETEs would also
                # fail and nothing would be persisted. Confirm whether a
                # per-table savepoint is needed.
                print('Failed to clear table %s: %r' % (table.name, e))
        trans.commit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment