Skip to content

Instantly share code, notes, and snippets.

@zzzeek
Last active January 27, 2022 03:18
Show Gist options
  • Save zzzeek/8443477 to your computer and use it in GitHub Desktop.
Save zzzeek/8443477 to your computer and use it in GitHub Desktop.
Expands upon the SQLAlchemy "test rollback fixture" at http://docs.sqlalchemy.org/en/rel_0_9/orm/session.html#joining-a-session-into-an-external-transaction so that it also supports tests containing any combination of ROLLBACK/COMMIT, by ensuring that the Session always runs its transactions inside a SAVEPOINT.
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()


# a model
class Thing(Base):
    """Minimal mapped class: a ``thing`` table with an integer primary key."""

    __tablename__ = 'thing'

    id = Column(Integer, primary_key=True)


# a database w a schema.  NOTE: the schema is dropped and recreated when this
# module is imported, so every run starts from empty tables.
engine = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
from unittest import TestCase
import unittest
from sqlalchemy.orm import Session
from sqlalchemy import event
class MyTests(TestCase):
    """Tests that may freely COMMIT and ROLLBACK without leaking state.

    Each test runs on a dedicated connection inside an outer transaction
    that ``tearDown`` always rolls back.  The Session itself is kept inside a
    SAVEPOINT: whenever that SAVEPOINT ends (because the test committed or
    rolled back), an event hook immediately opens a fresh one, so the test's
    own COMMIT/ROLLBACK calls never reach the outer transaction.
    """

    def setUp(self):
        # same setup from the docs: an external transaction on a dedicated
        # connection, with the Session bound to that connection
        self.conn = engine.connect()
        self.trans = self.conn.begin()
        self.session = Session(bind=self.conn)

        # load fixture data within the scope of the transaction
        self._fixture()

        # start the session in a SAVEPOINT...
        self.session.begin_nested()

        # then each time that SAVEPOINT ends, reopen it
        @event.listens_for(self.session, "after_transaction_end")
        def restart_savepoint(session, transaction):
            # Only react when the outermost SAVEPOINT ends (its parent is not
            # itself a SAVEPOINT); SAVEPOINTs nested deeper are left alone.
            # NOTE(review): ``_parent`` is a private attribute; newer
            # SQLAlchemy versions also expose ``transaction.parent`` --
            # confirm against the installed version before switching.
            if transaction.nested and not transaction._parent.nested:
                session.begin_nested()

    def tearDown(self):
        # same teardown from the docs: discard the Session, then roll back
        # the outer transaction so nothing from this test persists
        self.session.close()
        self.trans.rollback()
        self.conn.close()

    def _fixture(self):
        """Insert three ``Thing`` rows inside the outer transaction."""
        self.session.add_all([
            Thing(), Thing(), Thing()
        ])
        self.session.commit()

    def test_thing_one(self):
        # run zero rollbacks
        self._test_thing(0)

    def test_thing_two(self):
        # run two extra rollbacks
        self._test_thing(2)

    def test_thing_five(self):
        # run five extra rollbacks
        self._test_thing(5)

    def _test_thing(self, extra_rollback=0):
        """Exercise COMMIT and ROLLBACK against the SAVEPOINT-wrapped Session.

        :param extra_rollback: number of add-then-rollback rounds performed
            both before and after the commit in the middle of the test.
        """
        session = self.session
        rows = session.query(Thing).all()
        self.assertEqual(len(rows), 3)

        for _ in range(extra_rollback):
            # run N number of rollbacks
            session.add_all([Thing(), Thing(), Thing()])
            rows = session.query(Thing).all()
            self.assertEqual(len(rows), 6)
            session.rollback()

        # after rollbacks, still @ 3 rows
        rows = session.query(Thing).all()
        self.assertEqual(len(rows), 3)

        session.add_all([Thing(), Thing()])
        session.commit()

        rows = session.query(Thing).all()
        self.assertEqual(len(rows), 5)

        # added but not committed; the first rollback below discards it
        session.add(Thing())
        rows = session.query(Thing).all()
        self.assertEqual(len(rows), 6)

        for elem in range(extra_rollback):
            # run N number of rollbacks
            session.add_all([Thing(), Thing(), Thing()])
            rows = session.query(Thing).all()
            if elem > 0:
                # b.c. we rolled back that other "thing" too
                self.assertEqual(len(rows), 8)
            else:
                self.assertEqual(len(rows), 9)
            session.rollback()

        rows = session.query(Thing).all()
        if extra_rollback:
            self.assertEqual(len(rows), 5)
        else:
            self.assertEqual(len(rows), 6)
# Run the test case when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
@aryaniyaps
Copy link

@zzzeek thanks for clarifying!

The issue is, I have my sessionmaker in a database.py file, like this:

# database.py
from sqlalchemy import create_engine, MetaData
from sqlalchemy.orm import (
    declarative_base, 
    scoped_session, 
    sessionmaker
)

from app.config import DEBUG, DATABASE_URL

# Naming templates for indexes ("ix"), unique ("uq"), foreign-key ("fk")
# and primary-key ("pk") constraints, so generated names are deterministic.
_NAMING_CONVENTION = {
    "ix": "ix_%(column_0_label)s",
    "uq": "uq_%(table_name)s_%(column_0_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}

metadata = MetaData(naming_convention=_NAMING_CONVENTION)

Base = declarative_base(metadata=metadata)

engine = create_engine(url=DATABASE_URL, future=True, echo=DEBUG)

# Application-wide scoped Session registry built on a factory bound to `engine`.
_session_factory = sessionmaker(bind=engine)
db_session = scoped_session(_session_factory)

And I have setup my pytest fixtures like this:

@fixture(scope="session")
def db_engine() -> Iterator[Engine]:
    """Create the schema once per test session and remove it afterwards.

    Tables are created directly from ``Base.metadata`` and Alembic is stamped
    with ``head``; on teardown the tables are dropped and the Alembic
    version stamp is purged.
    """
    cfg = Config("alembic.ini")

    # setup: build the tables, then record them as being at "head"
    Base.metadata.create_all(bind=engine)
    stamp(cfg, revision="head")

    yield engine

    # teardown: undo the setup in the same order
    Base.metadata.drop_all(bind=engine)
    stamp(cfg, revision=None, purge=True)


@fixture(scope="session")
def db_connection(db_engine: Engine) -> Iterator[Connection]:
    """
    Initializes the connection to
    the test database.

    One connection is opened for the whole test session.  Using it as a
    context manager guarantees it is closed even if the fixture generator
    is finalized without being resumed normally.

    :return: The database connection.
    """
    with db_engine.connect() as connection:
        yield connection


@fixture(autouse=True)
def db_transaction(db_connection: Connection) -> Iterator[Session]:
    """
    Sets up a database transaction for each test case.

    An outer transaction is begun on the shared connection and the global
    scoped Session is bound to that connection, so application code that
    uses ``db_session`` during the test sees the same Session.  Everything
    is rolled back afterwards.

    :return: The Session bound to the per-test transaction.
    """
    transaction = db_connection.begin()
    session = db_session(bind=db_connection)
    try:
        yield session
    finally:
        session.close()
        # Drop the Session from the scoped_session registry.  Without this,
        # the next test's ``db_session(bind=...)`` call finds an existing
        # Session, and scoped_session refuses to accept new arguments.
        db_session.remove()
        transaction.rollback()

I have a couple of problems with this:

  1. I get an error when this line:
    transaction = db_connection.begin()

is called, probably because the connection has already begun at that point.

  2. I have my session inside a transaction, but how do I supply it to the rest of the project (views, services, ...)?
    I could patch the session defined in database.py, but is there a better option, such as changing the bind
    of the sessionmaker?

@zzzeek
Copy link
Author

zzzeek commented Jan 25, 2022

I see nothing wrong with the code, and I'm not able to spot any place where the connection would be implicitly beginning, so you should be able to call begin() without a problem.

You are already using the global scoped session in your fixture, so while that fixture is in effect, that's the session your whole application will get whenever it refers to the db_session global.

@zzzeek
Copy link
Author

zzzeek commented Jan 25, 2022

It looks like you are missing a db_session.remove(); otherwise it works fine. Here's a demo — try running this:

from typing import Iterator
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from pytest import fixture

# In-memory SQLite engine (no path after "sqlite://"), echoing SQL.
engine = create_engine("sqlite://", echo=True, future=True)

# Global scoped Session registry, analogous to the application's database.py.
_session_factory = sessionmaker(bind=engine)
db_session = scoped_session(_session_factory)


@fixture(scope="session")
def db_engine() -> Iterator[Engine]:
    """Hand the module-level engine to every test in the session."""
    shared_engine = engine
    yield shared_engine


@fixture(scope="session")
def db_connection(db_engine: Engine) -> Iterator[Connection]:
    """
    Initializes the connection to
    the test database.

    One connection is opened for the whole test session.  Using it as a
    context manager guarantees it is closed even if the fixture generator
    is finalized without being resumed normally.

    :return: The database connection.
    """
    with db_engine.connect() as connection:
        yield connection


@fixture(autouse=True)
def db_transaction(db_connection: Connection) -> Iterator[Session]:
    """
    Sets up a database transaction for each test case.

    Begins an outer transaction on the shared connection and binds the
    global scoped Session to that connection.  After the test, the Session
    is closed and removed from the registry, and the outer transaction is
    rolled back.

    :return: The Session bound to the per-test transaction.
    """
    outer_txn = db_connection.begin()
    assert outer_txn is not None
    test_session = db_session(bind=db_connection)

    yield test_session

    # teardown: release the Session, clear the registry, undo all changes
    test_session.close()
    db_session.remove()
    outer_txn.rollback()


def test_in_a_transaction(db_transaction):
    """Run a trivial query through the per-test Session."""
    db_transaction.execute(text("select 1")).close()

def test_also_in_a_transaction(db_transaction):
    """Repeat the trivial query to show the fixture works for a second test."""
    db_transaction.execute(text("select 1")).close()

@aryaniyaps
Copy link

aryaniyaps commented Jan 27, 2022 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment