Last active
January 27, 2022 03:18
-
-
Save zzzeek/8443477 to your computer and use it in GitHub Desktop.
Expands upon the SQLAlchemy "test rollback fixture" at http://docs.sqlalchemy.org/en/rel_0_9/orm/session.html#joining-a-session-into-an-external-transaction so that it also supports tests that issue any combination of ROLLBACK/COMMIT, by ensuring that the Session always runs its transactions inside a SAVEPOINT.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import Column, Integer, create_engine

try:
    # SQLAlchemy 1.4+ location; the sqlalchemy.ext.declarative variant is
    # deprecated there and warns under SQLAlchemy 2.0.
    from sqlalchemy.orm import declarative_base
except ImportError:  # SQLAlchemy < 1.4
    from sqlalchemy.ext.declarative import declarative_base

# Declarative base class that the ORM models below inherit from.
Base = declarative_base()
class Thing(Base):
    """Minimal mapped model: a ``thing`` table holding only a primary key."""

    __tablename__ = "thing"

    # integer primary key, stored in the column named "id"
    id = Column("id", Integer, primary_key=True)
# A database with a schema.  The URL may be overridden through the
# TEST_DATABASE_URL environment variable; by default it points at the
# original local PostgreSQL test database.
import os

engine = create_engine(
    os.environ.get("TEST_DATABASE_URL", "postgresql://scott:tiger@localhost/test"),
    echo=True,
)

# Start every run from a freshly created schema.
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
from unittest import TestCase
import unittest

from sqlalchemy.orm import Session
from sqlalchemy import event


class MyTests(TestCase):
    """Run each test inside an outer transaction that is always rolled back.

    The Session is bound to a single Connection on which an outer transaction
    has been begun.  Session work after the fixture load happens inside a
    SAVEPOINT (``begin_nested()``), and an ``after_transaction_end`` listener
    reopens that SAVEPOINT whenever it ends.  Tests may therefore call
    ``session.commit()`` / ``session.rollback()`` freely, while ``tearDown``
    discards everything by rolling back the outer transaction.
    """

    def setUp(self):
        # same setup from the docs: one connection, one outer transaction,
        # and a Session bound to that connection
        self.conn = engine.connect()
        self.trans = self.conn.begin()
        self.session = Session(bind=self.conn)

        # load fixture data within the scope of the transaction
        self._fixture()

        # start the session in a SAVEPOINT...
        self.session.begin_nested()

        # ...then each time that SAVEPOINT ends, reopen it
        @event.listens_for(self.session, "after_transaction_end")
        def restart_savepoint(session, transaction):
            # Only react when a SAVEPOINT whose parent is not itself nested
            # ends; inner savepoints are left alone.
            if transaction.nested and not transaction._parent.nested:
                session.begin_nested()

    def tearDown(self):
        # same teardown from the docs; the try/finally chain guarantees the
        # outer transaction is rolled back and the connection closed even if
        # an earlier step raises
        try:
            self.session.close()
        finally:
            try:
                self.trans.rollback()
            finally:
                self.conn.close()

    def _fixture(self):
        """Insert the three baseline ``Thing`` rows every test starts with."""
        self.session.add_all([
            Thing(), Thing(), Thing()
        ])
        self.session.commit()

    def test_thing_one(self):
        # run zero rollbacks
        self._test_thing(0)

    def test_thing_two(self):
        # run two extra rollbacks
        self._test_thing(2)

    def test_thing_five(self):
        # run five extra rollbacks
        self._test_thing(5)

    def _test_thing(self, extra_rollback=0):
        """Exercise commits and rollbacks through the savepoint-wrapped session.

        :param extra_rollback: number of add-then-rollback rounds run both
            before and after the committed inserts.
        """
        session = self.session

        rows = session.query(Thing).all()
        self.assertEqual(len(rows), 3)

        for _ in range(extra_rollback):
            # run N number of rollbacks; each round starts again from 3 rows
            session.add_all([Thing(), Thing(), Thing()])
            rows = session.query(Thing).all()
            self.assertEqual(len(rows), 6)
            session.rollback()

        # after rollbacks, still @ 3 rows
        rows = session.query(Thing).all()
        self.assertEqual(len(rows), 3)

        session.add_all([Thing(), Thing()])
        session.commit()

        rows = session.query(Thing).all()
        self.assertEqual(len(rows), 5)

        # one pending, uncommitted Thing on top of the 5 committed rows
        session.add(Thing())
        rows = session.query(Thing).all()
        self.assertEqual(len(rows), 6)

        for elem in range(extra_rollback):
            # run N number of rollbacks
            session.add_all([Thing(), Thing(), Thing()])
            rows = session.query(Thing).all()
            if elem > 0:
                # b.c. we rolled back that other "thing" too
                self.assertEqual(len(rows), 8)
            else:
                self.assertEqual(len(rows), 9)
            session.rollback()

        rows = session.query(Thing).all()
        if extra_rollback:
            self.assertEqual(len(rows), 5)
        else:
            self.assertEqual(len(rows), 6)
if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this script.
    unittest.main(module="__main__")
aryaniyaps
commented
Jan 27, 2022
via email
thanks for the reply! This helped me a lot, I appreciate it!
…On Tue, Jan 25, 2022 at 8:09 PM mike bayer ***@***.***> wrote:
***@***.**** commented on this gist.
------------------------------
Looks like you are missing a db_session.remove(); otherwise it works fine.
Here's a demo, try running this:
from typing import Iterator

from pytest import fixture
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session, scoped_session, sessionmaker

# In-memory SQLite engine shared by the whole demo.
engine = create_engine("sqlite://", echo=True, future=True)
# scoped_session registry; the per-test fixture binds it to a connection.
db_session = scoped_session(sessionmaker(bind=engine))
@fixture(scope="session")
def db_engine() -> Iterator[Engine]:
    """Provide the module-level engine for the whole test session."""
    yield engine
@fixture(scope="session")
def db_connection(db_engine: Engine) -> Iterator[Connection]:
    """
    Initializes the connection to the test database.

    :return: The database connection, closed when the test session ends.
    """
    # the context manager closes the connection once pytest resumes the
    # generator for teardown
    with db_engine.connect() as connection:
        yield connection
@fixture(autouse=True)
def db_transaction(db_connection: Connection) -> Iterator[Session]:
    """
    Sets up a database transaction for each test case.

    :return: The Session bound to this test's transaction; everything it does
        is rolled back at teardown.
    """
    transaction = db_connection.begin()
    # bind this test's scoped session to the shared connection so its work
    # happens inside ``transaction``
    session = db_session(bind=db_connection)
    yield session
    session.close()
    # discard the registry's session so the next test's
    # ``db_session(bind=...)`` call creates a fresh one
    db_session.remove()
    transaction.rollback()
def test_in_a_transaction(db_transaction):
    # Run a trivial statement through the per-test session, then release
    # the result immediately.
    db_transaction.execute(text("select 1")).close()
def test_also_in_a_transaction(db_transaction):
    # Same trivial statement as the first test, issued from a second test
    # that uses the per-test session.
    db_transaction.execute(text("select 1")).close()
—
Reply to this email directly, view it on GitHub
<https://gist.github.com/8443477#gistcomment-4041340>, or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AQP2YPO4IA27DENE27UES6DUX2YYNANCNFSM5MVEFKWQ>
.
You are receiving this because you commented.Message ID: <zzzeek/gist:
***@***.***>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment