Skip to content

Instantly share code, notes, and snippets.

@exhuma
Last active September 26, 2020 10:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save exhuma/c135551cbb0ad4bb993efea69d1084c9 to your computer and use it in GitHub Desktop.
Test-Harness for SQLAlchemy unit-tests
"""
Helper functions for unit-testing with SQLAlchemy
This provides a context-manager "rb_session" which
creates a new session that ignores all ".commit()"
calls. This might not work with all databases. It
has been tested with PostgreSQL. Verify that the
commits are really ignored if you use any other DB.
"""
from contextlib import contextmanager
from os.path import dirname, join, relpath
from typing import Any, Dict, Iterator, List, Optional, Tuple
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
# Each entry pairs a seed-file name with the bind parameters used to render it.
TSeedFiles = List[Tuple[str, Dict[str, Any]]]


class SeedException(Exception):
    """Raised when loading a seed file into the database fails."""
def load_seed_files(
    session: Session,
    seed_files: TSeedFiles,
    seed_dir: Optional[str] = None,
) -> None:
    """
    Loads a list of seeds into the database.

    Each seed can be supplied with a dictionary of variables for templated
    seeds. The variables are passed to :py:meth:`sqlalchemy.orm.Session.execute`
    as the ``params`` argument, so they are referenced as ``:name`` bind
    parameters inside the SQL file.

    So, assuming the SQL file contains the following:

    .. code-block:: sql

        INSERT INTO mytable (foo, bar) VALUES ('hello', :inserted);

    The following snippet can be used to load it, including any variables.

    >>> session = get_session()
    >>> load_seed_files(session, [
    ...     ('myseed.sql', {'inserted': datetime(2019, 1, 1)}),
    ... ])

    :param session: Session used to execute the seeds. It is committed after
        each seed file has been executed.
    :param seed_files: ``(file name, variables)`` pairs, loaded in order.
    :param seed_dir: Directory containing the seed files. Defaults to
        ``data/seeds`` next to this module.
    :raises SeedException: If executing a seed file fails.
    """
    # Local import keeps the module's import block unchanged. Plain-string SQL
    # in ``Session.execute`` is deprecated in SQLAlchemy 1.4 and rejected in
    # 2.0; ``text()`` is accepted by all of them.
    from sqlalchemy import text

    if seed_dir is None:
        seed_dir = join(dirname(__file__), 'data', 'seeds')
    for seed_file, variables in seed_files:
        seed_path = join(seed_dir, seed_file)
        # Open the path directly: converting it with ``relpath`` adds nothing
        # and raises ValueError on Windows when CWD is on another drive.
        with open(seed_path, encoding='utf8') as fptr:
            data = fptr.read()
        try:
            session.execute(text(data), params=variables)
        except Exception as exc:
            # Avoid printing the whole seed as the error; the first two lines
            # of the driver message are sufficient. ``from None`` below drops
            # the original exception for the same reason.
            lines = str(exc).splitlines()
            simplified_error = ' '.join(lines[:2])
            # ``session.bind`` may be None (or lack ``.engine``); don't let an
            # AttributeError here mask the actual seed failure.
            bind = session.bind
            target = getattr(bind, 'engine', bind)
            raise SeedException(
                'Unable to import seed file %r into %r: %s'
                % (seed_path, target, simplified_error)
            ) from None
        session.commit()  # type: ignore
@contextmanager
def rb_session(
    dsn: str, seed_files: Optional[TSeedFiles] = None
) -> Iterator[Session]:
    """
    A simple context-manager that wraps a database session that will never
    commit.

    The session is bound to a single connection on which an outer transaction
    is opened. Any "commit" calls on the yielded session (including those made
    while loading *seed_files*) are meant to be ignored: the outer transaction
    is rolled back when the context exits. This has been tested with
    PostgreSQL only.

    :param dsn: SQLAlchemy database URL to connect to.
    :param seed_files: Optional seeds loaded via :py:func:`load_seed_files`
        before the session is yielded.
    """
    seed_files = seed_files or []
    engine = create_engine(dsn)
    connection = engine.connect()
    try:
        trans = connection.begin()
        session = Session(bind=connection)
        try:
            # Seed inside the ``try`` so a failing seed still rolls back and
            # releases the session and connection.
            load_seed_files(session, seed_files)
            yield session
        finally:
            # Order follows SQLAlchemy's "join a Session into an external
            # transaction" recipe: close the session, then roll back.
            session.close()  # type: ignore
            trans.rollback()  # type: ignore
    finally:
        connection.close()
        # Every call creates its own engine; release its connection pool.
        engine.dispose()
'''
This module contains pytest-definitions (fixtures & co) which are shared across
all tests in the project.
'''
# pylint: disable=redefined-outer-name
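#
# Needed because pytest fixtures are requested by naming them as test-function
# arguments, which shadows the module-level fixture functions.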
import logging
from pytest import fixture
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
def nuke_all_tables(session: Session) -> None:
    """
    Truncates all tables i.e nuking all records.

    Every table in the ``public`` and ``history`` schemas except
    ``alembic_version`` is truncated with ``CASCADE``, and the session is
    committed. The table list comes from PostgreSQL's ``pg_tables`` catalog,
    so this is PostgreSQL-only.
    """
    # Local import keeps the module's import block unchanged. Plain-string SQL
    # in ``Session.execute`` is deprecated in SQLAlchemy 1.4 and rejected in
    # 2.0; ``text()`` is accepted by all of them.
    from sqlalchemy import text

    # quote_ident() makes mixed-case / reserved-word names safe to splice into
    # the TRUNCATE statement below.
    result = session.execute(text(
        """\
        SELECT quote_ident(schemaname), quote_ident(tablename) FROM pg_tables
        WHERE schemaname in ('public', 'history') AND
        tablename != 'alembic_version';
        """))
    table_names = [
        "%s.%s" % (schema, table) for schema, table in result.fetchall()
    ]
    # With no tables this would render "TRUNCATE  CASCADE;", a syntax error.
    if table_names:
        session.execute(
            text("TRUNCATE %s CASCADE;" % (", ".join(table_names)))
        )
    session.commit()
@fixture
def rb_session():
    """
    Returns a session which will always be rolled back and deletes
    all data from accidental commits.

    On teardown the session is rolled back, all tables are truncated via
    :py:func:`nuke_all_tables`, and the session and its engine are released,
    even if truncating fails.
    """
    # NOTE(review): ``configs`` is unused, and neither ``get_configs`` nor
    # ``DSN`` is defined in this module; presumably the DSN is meant to come
    # from ``configs`` — confirm.
    configs = get_configs()  # noqa: F841
    engine = create_engine(DSN)
    session = Session(bind=engine)
    try:
        yield session
    finally:
        try:
            session.rollback()  # type: ignore
            nuke_all_tables(session)
        finally:
            # Always release the session and the per-fixture engine's pool.
            session.close()  # type: ignore
            engine.dispose()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment