Skip to content

Instantly share code, notes, and snippets.

@simonrw
Created September 20, 2015 20:24
Show Gist options
  • Save simonrw/01e4e068c5f94b6b9aa3 to your computer and use it in GitHub Desktop.
Save simonrw/01e4e068c5f94b6b9aa3 to your computer and use it in GitHub Desktop.
Database testing with transactions
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# Application-side database objects (these would normally live in `db.py`
# or a similar module rather than next to the tests).
Base, Session = declarative_base(), sessionmaker()

# NOTE(review): `sessionmaker` appears to copy its configured keyword
# arguments at the moment it is *called*, so this instance is created with
# no bind and would not see the later `Session.configure(bind=...)` calls
# made by the fixtures. Nothing visible here uses it — confirm how callers
# rely on it before changing it (e.g. to `scoped_session(Session)`).
DBSession = Session()
class Test(Base):
    """Minimal mapped model backing the ``test`` table used by the tests.

    The class name begins with ``Test``, so pytest would otherwise try to
    collect it as a test class and emit a collection warning, because the
    declarative base gives it a keyword ``__init__`` (as ``Test(value=15)``
    relies on). Setting ``__test__ = False`` opts it out of collection.
    """

    # Tell pytest this is not a test class; it is a plain class attribute,
    # not a Column, so it plays no part in the table mapping.
    __test__ = False

    __tablename__ = 'test'

    # Integer primary key.
    id = sa.Column(sa.Integer, primary_key=True)
    # Floating-point payload column.
    value = sa.Column(sa.Float)
# Possibly in `conftest.py`
@pytest.fixture(scope='session')
def testengine(request):
    """Create one in-memory SQLite engine, with the schema, for the whole test session.

    Binds the module-level ``Session`` factory to the engine and creates
    every table registered on ``Base.metadata``. When the session ends, the
    tables are dropped and the engine's connection pool is disposed, so no
    pooled connections outlive the run.

    Returns the engine.
    """
    engine = sa.create_engine('sqlite:///:memory:')
    # Registered first so the pool is released even if schema setup below
    # fails. pytest runs finalizers in reverse registration order, so this
    # one runs last, after the tables have been dropped.
    request.addfinalizer(engine.dispose)

    Session.configure(bind=engine)
    Base.metadata.create_all(engine)

    def teardown():
        Base.metadata.drop_all(engine)

    request.addfinalizer(teardown)
    return engine
@pytest.fixture
def db_transaction(request, testengine):
    """Run each test inside an outer transaction that is rolled back afterwards.

    Opens a dedicated connection, begins a transaction on it, and rebinds
    the module-level ``Session`` factory to that connection, so sessions
    created during the test use it. On teardown the transaction is rolled
    back and the connection closed. ``Session`` is then re-bound to the
    engine so it does not keep pointing at a closed connection.

    Returns the connection.
    """
    connection = testengine.connect()
    transaction = connection.begin()
    Session.configure(bind=connection)

    def teardown():
        try:
            # Discard everything written on this connection during the test.
            transaction.rollback()
        finally:
            # Undo the rebinding above (back to the session-wide engine set
            # by `testengine`), then release the connection, even if the
            # rollback failed.
            Session.configure(bind=testengine)
            connection.close()

    request.addfinalizer(teardown)
    return connection
@pytest.fixture
def session(db_transaction):
    """Yield an ORM session bound to the per-test connection, then close it.

    ``db_transaction`` is requested for its side effect: it has already
    bound the ``Session`` factory to its connection, so this session runs
    inside that fixture's outer transaction. pytest tears this fixture down
    before the one it depends on, so the session is closed before
    ``db_transaction`` rolls the transaction back.
    """
    db_session = Session()
    yield db_session
    # Release the session's state and its hold on the connection.
    db_session.close()
# ---- Tests ----------------------------------------------------------------
def test_stuff(session, db_transaction):
    """A row that is added and committed is visible to a subsequent query."""
    session.add(Test(value=15))
    session.commit()
    row_count = session.query(Test).count()
    assert row_count == 1
def test_more_stuff(session, db_transaction):
    """The ``test`` table is empty at the start of this test.

    When run after ``test_stuff``, this checks that the row committed there
    was rolled back by the ``db_transaction`` fixture rather than persisted.
    """
    assert session.query(Test).count() == 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment