Skip to content

Instantly share code, notes, and snippets.

@smokey42
Last active December 19, 2017 11:01
Show Gist options
  • Save smokey42/6be5568d10a5075e8d6d30b01260857e to your computer and use it in GitHub Desktop.
Save smokey42/6be5568d10a5075e8d6d30b01260857e to your computer and use it in GitHub Desktop.
Decorator to give a new db session/engine to a function and handle the commit automatically.
import contextlib
import functools
import inspect
@contextlib.contextmanager
def simple_transaction(session, close=False, commit=True):
    """
    Simple transaction manager, commits on success, rolls back on error.

    The object is treated as a session when it exposes a ``commit``
    attribute; in that case ``commit()`` / ``rollback()`` are used.
    Otherwise it is treated as a connection: ``begin()`` is called on entry
    and the literal statements ``COMMIT`` / ``ROLLBACK`` are sent through
    ``execute()``.

    :param session:
        An ORM session, connection or raw connection.
    :type close: bool
    :param close:
        Also close the passed connection on exit (success or failure),
        when set.
    :type commit: bool
    :param commit:
        `True` (default) to commit the changes. Intended for test purposes
        only.
    :returns:
        A context manager yielding the passed ``session``.
    :raises:
        Whatever the managed block (or the commit) raised, after the
        rollback has been issued.
    """
    is_session = hasattr(session, 'commit')
    try:
        if not is_session:  # connection, not a session
            session.begin()
        yield session
        if commit:
            if is_session:
                session.commit()
            else:
                session.execute('COMMIT')
    except BaseException:
        # Catch BaseException rather than Exception so that interrupts
        # (KeyboardInterrupt, SystemExit, ...) also roll the transaction
        # back instead of leaving it open when ``close`` is False.
        if is_session:
            session.rollback()
        else:
            session.execute('ROLLBACK')
        raise
    finally:
        if close:
            session.close()
def with_database(session_factory, param='db', init=None):
    """Decorate a function to provide it with its own database session.

    If the function has a parameter named `db`, this will be used to
    pass the newly constructed session to the function.

    If the wrapper is passed a parameter which is named `db` (positionally
    or by keyword), the value will be passed on, and no new session will be
    created.

    If the session has been created by the wrapper, it will be committed in
    case there are no errors and is rolled back if the wrapped function
    raises an exception. It is closed afterwards in either case.

    If the optional key-value argument `commit` is present and is `False`,
    the session will not be committed, even if it has been created by the
    wrapper. The `commit` argument is not removed from the keyword
    arguments, so the wrapped function receives it as well.

    :param session_factory:
        The ORM session factory to use. Called without arguments.
    :param param:
        Which parameter of the target function to pass the db reference to.
        Defaults to `db`.
    :param init:
        Function which will be called initially, but only if the connection
        has been created by this wrapper. The function has to take one
        parameter, the created session. If it raises, the session is closed
        and the exception is propagated.
    :returns:
        A wrapped function which handles the session transaction logic.
    :raises TypeError:
        When the wrapped function is called but has no parameter named
        ``param``.
    """
    def wrapper(func):
        # Inspect the innermost undecorated function: the intermediate
        # wrappers only expose ``(*args, **kw)``, which says nothing about
        # the real parameter names when several decorators are stacked.
        target = func
        while hasattr(target, '__orig_func__'):
            target = target.__orig_func__

        # ``inspect.getargspec`` was removed in Python 3.11 and cannot
        # describe keyword-only parameters; fall back to it only where
        # ``getfullargspec`` does not exist.
        get_spec = getattr(inspect, 'getfullargspec', None)
        if get_spec is None:
            get_spec = inspect.getargspec
        arg_spec = get_spec(target)

        positional = list(arg_spec.args)
        keyword_only = list(getattr(arg_spec, 'kwonlyargs', None) or ())
        param_pos = positional.index(param) if param in positional else None
        accepts_param = param_pos is not None or param in keyword_only

        @functools.wraps(func)
        def db_wrapper(*args, **kw):
            """The actual wrapper."""
            # Check before creating anything so no session is leaked.
            if not accepts_param:
                raise TypeError(
                    "Function needs a parameter which is named '%s'." % param)

            supplied_positionally = (
                param_pos is not None and len(args) > param_pos)
            if supplied_positionally or param in kw:
                # A session was provided explicitly; the caller is
                # responsible for committing and closing it.
                return func(*args, **kw)

            # Commit/rollback and close the session only if it was created
            # by the wrapper.
            commit = kw.get('commit', True)
            db = session_factory()
            if init is not None:
                try:
                    init(db)
                except BaseException:
                    # The transaction manager has not taken over yet, so
                    # release the session here.
                    db.close()
                    raise
            # Inject by keyword: this is correct regardless of which of the
            # other parameters were passed positionally or by keyword.
            kw[param] = db
            with simple_transaction(db, close=True, commit=commit):
                return func(*args, **kw)

        db_wrapper.__orig_func__ = func
        return db_wrapper
    return wrapper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment