Skip to content

Instantly share code, notes, and snippets.

@Evgenus
Last active February 27, 2024 12:45
Show Gist options
  • Save Evgenus/5d9b279c396d414cfcd61814c8417058 to your computer and use it in GitHub Desktop.
Save Evgenus/5d9b279c396d414cfcd61814c8417058 to your computer and use it in GitHub Desktop.
Proper SQLAlchemy transactions example
from contextlib import contextmanager
import threading
from thread import get_ident
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.declarative import declarative_base
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
################################################################################
def scopefunc():
    """Return the scoped-session registry key for the calling thread.

    The key combines "is this thread inside transaction()?" with the thread
    id, so on one thread code inside a transaction and code outside it map to
    different keys (and hence different scoped sessions).
    """
    inside = is_inside_txn()
    return "%s.%s" % (inside, get_ident())
# Flask-SQLAlchemy extension object. The "scopefunc" session option is handed
# to its scoped session, so db.session is keyed by scopefunc() — i.e. by thread
# and by whether that thread is currently inside transaction().
db = SQLAlchemy(session_options={"scopefunc": scopefunc})
# Per-thread bookkeeping: each thread keeps its own list on txn_context (under
# the attribute named by context_name) with one entry per active transaction()
# scope. A non-empty list means "inside a transaction".
txn_context = threading.local()
context_name = "db_or_conection_instance_id"


def push_txn_context():
    """Record that the current thread entered one more transaction() scope."""
    try:
        stack = getattr(txn_context, context_name)
    except AttributeError:
        # First use on this thread: create its private stack.
        stack = []
        setattr(txn_context, context_name, stack)
    stack.append(True)


def pop_txn_context():
    """Record that the current thread left its innermost transaction() scope."""
    getattr(txn_context, context_name).pop()


def is_inside_txn():
    """Return True when the current thread has at least one active scope."""
    return bool(getattr(txn_context, context_name, ()))
@contextmanager
def transaction():
    """Run the ``with`` body inside a (possibly nested) database transaction.

    The outermost ``transaction()`` on a thread starts from a freshly closed
    ``db.session``, commits it when the body succeeds and closes it afterwards.
    Nested calls open a SAVEPOINT with ``begin_nested()`` and commit (release)
    or roll back only that savepoint, leaving the enclosing transaction alive.

    Raises:
        Re-raises any ``Exception`` from the body after rolling back this
        level's work.
    """
    outer_transaction = not is_inside_txn()
    # Push before touching db.session: scopefunc() keys the scoped session on
    # is_inside_txn(), so from here on db.session is the in-transaction session.
    push_txn_context()
    try:
        if outer_transaction:
            # Discard any state left over from a previous use of this session.
            db.session.close()
            savepoint = None
        else:
            savepoint = db.session.begin_nested()
    except BaseException:
        # Setup failed: undo the push, otherwise this thread would be treated
        # as "inside a transaction" forever.
        pop_txn_context()
        raise
    try:
        yield
        # Commit only this level: the savepoint when nested, the whole session
        # transaction when outermost. Using the begin_nested() handle keeps the
        # scope correct even where Session.commit() targets the outermost txn.
        if savepoint is None:
            db.session.commit()
        else:
            savepoint.commit()
    except Exception:
        if savepoint is None:
            db.session.rollback()
        else:
            savepoint.rollback()
        raise
    finally:
        try:
            if outer_transaction:
                # Still inside the context here, so this closes the
                # in-transaction session rather than the outside one.
                db.session.close()
        finally:
            pop_txn_context()
def transactional(func):
    """Decorator: run every call of ``func`` inside ``transaction()``.

    Calls made while another transactional call is active on the same thread
    become nested transactions. The wrapper keeps ``func``'s name and
    docstring via ``functools.wraps``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with transaction():
            return func(*args, **kwargs)
    return wrapper
################################################################################
# NOTE(review): SomeObject below derives from db.Model, not from Base, so this
# declarative base appears unused in this Flask-SQLAlchemy example.
Base = declarative_base()
def print_what_is_in_session(message):
values = [obj.value for obj in db.session.query(SomeObject).all()]
print message, values
class SomeObject(db.Model):
    """Minimal demo model: table ``a`` with an integer primary key and a
    string ``value`` column."""
    __tablename__ = 'a'
    id = Column(Integer, primary_key=True)
    value = Column(String(128))  # label identifying which demo step created the row
@transactional
def A():
    """Demo outer unit of work: add "a1", call B() (which fails), then add "a2".

    B() always raises (it calls C(), which raises ValueError). The exception
    is swallowed here on purpose, so A's own transaction still commits while
    B's nested work is rolled back.
    """
    db.session.add(SomeObject(value="a1"))
    print "calling function B from the scope of A"
    try:
        B()
    except Exception:
        # Deliberately swallowed: the demo shows B's nested transaction being
        # rolled back while A's transaction carries on.
        pass
    # here everything created inside B (but not in A) should be removed from DB
    print_what_is_in_session("after B rolled-back: ")
    db.session.add(SomeObject(value="a2"))
@transactional
def B():
    """Demo middle unit of work: add "b1", call C(), then add "b2".

    C() always raises, so "b2" is never added. The exception propagates out
    of B, and transaction() rolls back B's work and re-raises.
    """
    db.session.add(SomeObject(value="b1"))
    C()
    db.session.add(SomeObject(value="b2"))  # not reached: C() always raises
@transactional
def C():
    """Demo innermost unit of work: add "c", print the session contents, then fail.

    Raises:
        ValueError: always, so transaction() rolls back C's work (and that of
            any transactional caller that lets the exception propagate).
    """
    db.session.add(SomeObject(value="c"))
    # here we can see that all objects created before this point are in the session
    print_what_is_in_session("before exception: ")
    raise ValueError(0)
################################################################################
def scenario():
    """Run the demo: call A() and then B() from outside any transaction,
    printing what the session sees after each step."""
    # calling transactional functions, which may open nested transactions
    print "calling function A from outer scope"
    # NOTE(review): "i" is added and flushed on the session used outside
    # transaction() (scopefunc() gives it a different key); this function
    # never commits it.
    db.session.add(SomeObject(value="i"))
    db.session.flush()
    A()
    print_what_is_in_session("after calling A: ")
    print "calling function B from outer scope"
    try:
        B()
    except Exception:
        pass  # B always fails (C raises); the demo only inspects the outcome
    # B was rolled back each time it failed (inside A and here), so only the
    # objects created directly in A should have been stored in the DB.
    print_what_is_in_session("at the end: ")
if __name__ == "__main__":
    app = Flask("test_app")
    app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql://root:root@localhost/test'
    app.config['SQLALCHEMY_ECHO'] = True
    db.init_app(app)
    with app.app_context():
        db.create_all()
        try:
            scenario()
        finally:
            # Clean up even if the demo fails, so the throwaway table does not
            # linger in the test database. Close the session first: scenario()
            # flushes without committing, and that open transaction should be
            # released before the table is dropped.
            db.session.close()
            db.drop_all()
from contextlib import contextmanager
import threading
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.ext.declarative import declarative_base
################################################################################
# Per-thread stack of sessions, one entry per active transaction() scope,
# stored on txn_context under the attribute named by context_name.
txn_context = threading.local()
context_name = "db_or_conection_instance_id"


def push_txn_context(session):
    """Push ``session`` as the current thread's innermost transaction session."""
    try:
        stack = getattr(txn_context, context_name)
    except AttributeError:
        # First use on this thread: create its private stack.
        stack = []
        setattr(txn_context, context_name, stack)
    stack.append(session)


def pop_txn_context():
    """Drop the current thread's innermost transaction session."""
    getattr(txn_context, context_name).pop()


def is_inside_txn():
    """Return True when the current thread has at least one active scope."""
    return bool(getattr(txn_context, context_name, ()))


# this could be reduced to some property of a thread-local object (like db.session)
def current_session():
    """Return the innermost session of the current thread, or None outside
    any transaction."""
    stack = getattr(txn_context, context_name, ())
    return stack[-1] if stack else None
@contextmanager
def transaction(bind=None):
    """Run the ``with`` body inside a (possibly nested) transaction.

    The outermost call on a thread opens a new ``Session``, commits it when the
    body succeeds and closes it afterwards. Nested calls reuse the current
    session and wrap the body in a SAVEPOINT (``begin_nested()``). That
    savepoint is committed or rolled back on its own, leaving the enclosing
    transaction intact.

    Args:
        bind: engine/connection for a new outermost session. Defaults to the
            module-level ``engine``. It is ignored for nested calls, which
            always reuse the enclosing session.

    Yields:
        The session to use inside the block (same as ``current_session()``).

    Raises:
        Re-raises any ``Exception`` from the body after rolling back this
        level's work.
    """
    savepoint = None
    if is_inside_txn():
        session = current_session()
        savepoint = session.begin_nested()
    else:
        # ``engine`` is only created under ``if __name__ == "__main__"``; pass
        # ``bind`` explicitly when using this helper from another module.
        session = Session(engine if bind is None else bind)
    push_txn_context(session)
    try:
        yield session
        # Commit only this level: the savepoint when nested, the whole session
        # transaction when outermost. Using the begin_nested() handle keeps the
        # scope correct even where Session.commit() targets the outermost txn.
        if savepoint is not None:
            savepoint.commit()
        else:
            session.commit()
    except Exception:
        if savepoint is not None:
            savepoint.rollback()
        else:
            session.rollback()
        raise
    finally:
        pop_txn_context()
        if not is_inside_txn():
            # This was the outermost scope: the session was created here.
            session.close()
def transactional(func):
    """Decorator: run every call of ``func`` inside ``transaction()``.

    Calls made while another transactional call is active on the same thread
    become nested transactions. Inside ``func``, use ``current_session()`` to
    get the active session. The wrapper keeps ``func``'s name and docstring via
    ``functools.wraps``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with transaction():
            return func(*args, **kwargs)
    return wrapper
################################################################################
# Declarative base for the ORM models of this example; its metadata is used to
# create and drop the tables in the __main__ section.
Base = declarative_base()
def print_what_is_in_session(message, session):
values = [obj.value for obj in session.query(SomeObject).all()]
print message, values
class SomeObject(Base):
    """Minimal demo model: table ``a`` with an integer primary key and a
    string ``value`` column."""
    __tablename__ = 'a'
    id = Column(Integer, primary_key=True)
    value = Column(String(128))  # label identifying which demo step created the row
@transactional
def A():
    """Demo outer unit of work on the current transaction's session: add "a1",
    call B() (which fails), then add "a2".

    B() always raises (it calls C(), which raises ValueError). The exception
    is swallowed here on purpose, so A's own transaction still commits while
    B's nested work is rolled back.
    """
    session = current_session()
    session.add(SomeObject(value="a1"))
    try:
        B()
    except Exception:
        # Deliberately swallowed: the demo shows B's nested transaction being
        # rolled back while A's transaction carries on.
        pass
    # here everything created inside B (but not in A) should be removed from DB
    print_what_is_in_session("after B rolled-back: ", session)
    session.add(SomeObject(value="a2"))
@transactional
def B():
    """Demo middle unit of work: add "b1", call C(), then add "b2".

    C() always raises, so "b2" is never added. The exception propagates out
    of B, and transaction() rolls back B's work and re-raises.
    """
    session = current_session()
    session.add(SomeObject(value="b1"))
    C()
    session.add(SomeObject(value="b2"))  # not reached: C() always raises
@transactional
def C():
    """Demo innermost unit of work: add "c", print the session contents, then fail.

    Raises:
        ValueError: always, so transaction() rolls back C's work (and that of
            any transactional caller that lets the exception propagate).
    """
    session = current_session()
    session.add(SomeObject(value="c"))
    # here we can see that all objects created before this point are in the session
    print_what_is_in_session("before exception: ", session)
    raise ValueError(0)
################################################################################
if __name__ == "__main__":
engine = create_engine('mysql://root:root@localhost/test', echo=True)
Base.metadata.create_all(engine)
connection = engine.connect()
# Existing external session created by something (like Flask-SQLAlchemy)
session = Session(engine)
# calling transactional function with maybe nested transactions
print "calling function A from outer scope"
A()
print_what_is_in_session("after calling A: ", session)
print "calling function B from outer scope"
try:
B()
except Exception:
pass
# As long as B was rolled back and catched in A only things created
# inside A was stored into DB
print_what_is_in_session("at the end: ", session)
session.close()
Base.metadata.drop_all(engine)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment