Skip to content

Instantly share code, notes, and snippets.

@mahmoud
Last active March 10, 2021 08:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mahmoud/10f6b6b0a9c5860030693357124131df to your computer and use it in GitHub Desktop.
Save mahmoud/10f6b6b0a9c5860030693357124131df to your computer and use it in GitHub Desktop.
Excerpted tidbits from our Django app's conftest during the Django 1 to 2 and Python 2 to 3 migration ("TR" stands for "Tech Refresh", circa mid 2020), with parts related to wrapping the test session in a DB transaction, and other parts related to bulk-skipping functionality. Tested on py2.7/3.6, Django 1.11/2.0.x, pytest 4.6.11.
# -*- coding: utf-8 -*-
"""
Common test fixtures for pytest
"""
from __future__ import print_function, unicode_literals
import os
from itertools import groupby
import pytest
from django.test import TransactionTestCase, TestCase
from django.test.testcases import connections_support_transactions
from boltons.dictutils import OMD
from boltons.fileutils import atomic_save
from six import PY3
from pytest_django.plugin import _blocking_manager, validate_django_db
from django.db.backends.base.base import BaseDatabaseWrapper
# Side effects on import. Can't be helped: every test in this suite is
# allowed to touch the database, so DB access is unblocked globally here.
_blocking_manager.unblock()
# NOTE(review): presumably swaps pytest-django's "blocked" stand-in for the
# real ensure_connection so that DB access cannot be re-blocked later --
# confirm against the installed pytest-django's _blocking_manager.
_blocking_manager._blocking_wrapper = BaseDatabaseWrapper.ensure_connection
# True when running under Python 3 (six.PY3). Used further down as the switch
# for the tech-refresh ("TR") xfail/state-file bookkeeping.
IS_TR_ENV = bool(PY3)
# Names of the command-line options used to split the suite into partitions.
TEST_PARTITION_OPTION = "--test_partition"
TOTAL_TESTS_OPTION = "--total_test_partitions"
def _patchedTearDownClass(cls):
    """
    Replacement for ``TestCase.tearDownClass``.

    When ``connections_support_transactions()`` is true, the class-level
    atomic blocks stored in ``cls.cls_atomics`` are rolled back. Teardown
    then continues with whatever ``tearDownClass`` follows
    ``TransactionTestCase`` in the MRO.
    """
    supports_transactions = connections_support_transactions()
    if supports_transactions:
        class_atomics = cls.cls_atomics
        cls._rollback_atomics(class_atomics)
    parent = super(TransactionTestCase, cls)
    parent.tearDownClass()


# Install the replacement on Django's TestCase for the whole test session.
TestCase.tearDownClass = classmethod(_patchedTearDownClass)
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db):
    """
    Give every test database access by requesting the ``db`` fixture.

    The fixture is autouse and has no body of its own; depending on ``db``
    is the whole point. We probably shouldn't do this, but at this time we
    have no idea which tests use the DB and which don't.

    This also forces the default TestCase environment to be Django's
    TestCase with atomic transaction support:
    https://github.com/pytest-dev/pytest-django/blob/ce5d5bc0b29748ed411b9d683a33e1b13d98e17f/pytest_django/fixtures.py#L149
    """
@pytest.fixture(scope='session')
def _transaction_wrap(django_db_setup):
    """
    Open an exterior, session-scoped transaction.

    It wraps the client data setup that follows and is separate from the
    transaction around each individual test. Requesting ``django_db_setup``
    makes this run after the test database has been set up.

    The value returned by ``_enter_atomics()`` is not kept, so nothing in
    this fixture explicitly exits or rolls back those atomic blocks.
    """
    TestCase._enter_atomics()
def pytest_addoption(parser):
    """Register the partitioning and tech-refresh (TR) command-line options.

    The two partition options take their defaults from the
    ``TEST_PARTITION`` and ``TOTAL_TEST_PARTITIONS`` environment variables.
    An unset *or empty* variable falls back to partition 0 of 1, and the
    value is converted to ``int`` here rather than relying on the option
    parser to coerce a string default.

    Args:
        parser: the pytest option parser passed to this hook.

    Raises:
        ValueError: if one of the environment variables is set to something
            that is not an integer.
    """
    parser.addoption(
        TEST_PARTITION_OPTION, type=int,
        default=int(os.getenv('TEST_PARTITION') or 0),
        help='zero-based index of the test partition to run'
             ' (default: $TEST_PARTITION or 0)')
    parser.addoption(
        TOTAL_TESTS_OPTION, type=int,
        # "or 1" mirrors the handling above so an empty variable is treated
        # the same as an unset one.
        default=int(os.getenv('TOTAL_TEST_PARTITIONS') or 1),
        help='total number of partitions the test suite is split into'
             ' (default: $TOTAL_TEST_PARTITIONS or 1)')
    parser.addoption('--tr-regen-state', action="store_true", default=False,
                     help="regenerate list of skipped/failing tests into a new skipfile, for tech refresh")
    parser.addoption('--tr-no-autoskip', action='store_true', default=False,
                     help='do not autoskip known failing tests based on tr_test_state.txt')
    parser.addoption('--tr-recheck', action='store_true', default=False,
                     help='update TR state file for any passing tests. meant to be used with'
                          ' test pattern filtering and implies --tr-no-autoskip')
def pytest_collection_modifyitems(config, items):
    """Reorder, partition and mark the collected tests (``items``) in place.

    Steps, in order:

    1. Sort the tests so that those needing a real, truncating transaction
       run after every other test; within that, order by module and then by
       test class name.
    2. Keep only the tests that belong to this run's partition (the
       ``TEST_PARTITION_OPTION`` value out of ``TOTAL_TESTS_OPTION``).
       All tests of one class stay in the same partition; class-less tests
       are dealt out one at a time.
    3. Unless ``--needs-isolation`` is given, skip tests carrying the
       ``needs_isolation`` keyword.
    4. On a TR environment, and unless one of the ``--tr-*`` options is
       set, mark tests listed as failing in ``tr_test_state.txt`` as xfail.

    Args:
        config: the pytest config object.
        items: list of collected test items; modified in place.
    """
    def get_marker_transaction(test):
        """Return the ``transaction`` setting of the closest ``django_db``
        marker, or None when the test has no such marker."""
        marker = test.get_closest_marker('django_db')
        if marker:
            transaction, _ = validate_django_db(marker)
            return transaction
        return None

    def has_fixture(test, fixture):
        """Truthy when ``fixture`` is among the test's ``funcargnames``."""
        funcargnames = getattr(test, 'funcargnames', None)
        return funcargnames and fixture in funcargnames

    def run_transaction_test_cases_after_all_other_tests(test):
        """
        Return 1 for a transaction test case and 0 for anything else.

        Transaction test cases must run last since they truncate the
        database (and thus leave no data for "non-transaction" test cases
        to act on).
        Part of the teardown for Django's TransactionTestCase does this:
        https://github.com/django/django/blob/b61ea56789a5825bd2961a335cb82f65e09f1614/django/test/testcases.py#L1000
        """
        test_cls = getattr(test, 'cls', None)
        # TestCase is checked first: it is itself a TransactionTestCase
        # subclass but does not truncate.
        if test_cls and issubclass(test_cls, TestCase):
            return 0
        marker_transaction = get_marker_transaction(test)
        if marker_transaction is False:
            return 0
        if test_cls and issubclass(test_cls, TransactionTestCase):
            return 1
        if marker_transaction is True:
            return 1
        if has_fixture(test, 'transactional_db') or has_fixture(test, 'live_server'):
            # live_server uses transactional_db. So same truncation.
            return 1
        return 0

    def sort_by_app_and_test_folder_name(test):
        """Return the dotted module name of the test's class or function.

        Falls back to "" for items that have neither, so every key is a
        string and the sort cannot fail on mixed ``None``/``str`` keys.
        """
        test_cls = getattr(test, 'cls', None)
        if test_cls:
            return test_cls.__module__
        function = getattr(test, 'function', None)
        if function:
            return function.__module__
        return ""

    def group_by_test_case(test):
        """Return the test's class name, or "" for class-less tests."""
        test_cls = getattr(test, 'cls', None)
        return test_cls.__name__ if test_cls else ""

    def ordering_key(test):
        # A single tuple key makes the transaction flag the PRIMARY sort
        # criterion. Sorting three times in a row with a stable sort would
        # instead make the last key (class name) the primary one and let
        # transaction test cases run in the middle of the session.
        return (run_transaction_test_cases_after_all_other_tests(test),
                sort_by_app_and_test_folder_name(test),
                group_by_test_case(test))

    items.sort(key=ordering_key)

    TEST_NUMBER = config.getoption(TEST_PARTITION_OPTION)
    TOTAL_TESTS = config.getoption(TOTAL_TESTS_OPTION)

    # Deal tests out to partitions. Running counters (rather than the
    # position of each group) keep the split even when class-less tests
    # and classes alternate, as they do once the list is ordered by module.
    selected = []
    loose_tests_seen = 0
    class_groups_seen = 0
    for class_name, tests in groupby(items, group_by_test_case):
        if class_name == "":
            for test in tests:
                if loose_tests_seen % TOTAL_TESTS == TEST_NUMBER:
                    selected.append(test)
                loose_tests_seen += 1
        else:
            # Keep a whole class in one partition.
            if class_groups_seen % TOTAL_TESTS == TEST_NUMBER:
                selected.extend(tests)
            class_groups_seen += 1
    items[:] = selected

    if not config.getoption("--needs-isolation"):
        skip_needs_isolation = pytest.mark.skip(
            reason="need --needs-isolation option to run")
        for item in items:
            if "needs_isolation" in item.keywords:
                item.add_marker(skip_needs_isolation)

    skip_tr_tests = not (config.option.tr_recheck
                         or config.option.tr_no_autoskip
                         or config.option.tr_regen_state)
    if IS_TR_ENV and skip_tr_tests:
        trf_path = config.rootdir + '/tr_test_state.txt'
        try:
            trf = TestResultFile.from_path(trf_path)
        except OSError:
            # No readable state file: nothing is known to fail, so nothing
            # gets marked.
            pass
        else:
            skip_tr_failure = pytest.mark.xfail(
                reason='known failure related to tech refresh')
            fns = set(trf.get_failing_nodeids())
            for item in items:
                if item.nodeid in fns:
                    item.add_marker(skip_tr_failure)
# Outcomes reported so far, keyed by test node id. The most recently added
# value for a node id is the one treated as that test's result.
_all_reports = OMD()

# How "bad" each report outcome is. A later phase may only replace the
# recorded outcome of a test with a strictly worse one.
_OUTCOME_SEVERITY = {'passed': 0, 'skipped': 1, 'failed': 2}


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_logreport(report):
    """Record the worst outcome seen for each test across its phases.

    This hook runs up to 3x for each test: setup, call, teardown. Once any
    phase fails, the test stays marked as failed. Likewise a 'skipped'
    setup or call is no longer overwritten by the 'passed' teardown report
    that follows it, so skipped tests are recorded as skipped.

    Args:
        report: the pytest report for one phase of one test; only its
            ``nodeid`` and ``outcome`` are read.
    """
    previous = _all_reports.get(report.nodeid)
    if previous is None:
        _all_reports.add(report.nodeid, report.outcome)
        return
    # Unknown outcome strings rank lowest, so they never mask a known one.
    new_rank = _OUTCOME_SEVERITY.get(report.outcome, 0)
    old_rank = _OUTCOME_SEVERITY.get(previous, 0)
    if new_rank > old_rank:
        _all_reports.add(report.nodeid, report.outcome)
@pytest.fixture(scope="session", autouse=True)
def tr_state_save(request):
    """Session-wide fixture that writes the tech-refresh (TR) state file.

    Before the session runs, it validates the option combination when
    ``--tr-regen-state`` is given. After the session:

    * with ``--tr-regen-state``, a new state file is written from the
      collected outcomes, provided every collected test reported one;
    * with ``--tr-recheck``, the existing state file is updated with the
      outcomes of the tests that ran. If there is no existing file, a
      message is printed and nothing is written.

    Nothing is saved outside a TR environment or when no outcomes were
    recorded.

    Raises:
        SystemExit: for ``--tr-regen-state`` combined with ``--tr-recheck``
            or with a ``-k`` keyword filter, or when regenerating after an
            incomplete run.
    """
    session = request.node
    config = session.config
    if config.option.tr_regen_state:
        if config.option.tr_recheck:
            raise SystemExit('--tr-regen-state is mutually exclusive with --tr-recheck')
        if config.option.keyword:
            raise SystemExit('refusing to regenerate TR state while running a subset of tests')

    yield

    if not IS_TR_ENV or not _all_reports:
        return  # not tech refreshing / process aborted

    results = _all_reports.items()
    path = config.rootdir + '/tr_test_state.txt'

    if config.option.tr_regen_state:
        if len(_all_reports) != session.testscollected:
            raise SystemExit('refusing to regenerate TR state with incomplete test run')
        new_trf = TestResultFile(path, results)
        new_trf.save()

    if config.option.tr_recheck:
        try:
            trf = TestResultFile.from_path(path)
        except OSError:
            print('no existing test result file at %r, nothing to recheck against' % path)
        else:
            # Only reached when the file was loaded. Without this guard the
            # missing-file case fell through to an unbound ``trf``.
            trf.update(results)
            trf.save()
class TestResultFile(object):
    """In-memory form of the tech-refresh test state file.

    On disk the file is UTF-8 text: optional leading ``#`` comment lines
    (the "intro"), followed by one ``<outcome> - <nodeid>`` line per test.
    """

    def __init__(self, path, results, intro_lines=()):
        """
        Args:
            path: location that ``save()`` writes to.
            results: iterable of ``(nodeid, outcome)`` pairs; stored sorted.
            intro_lines: comment lines (without the leading ``#``) written
                at the top of the file by ``save()``.
        """
        self.results = OMD(sorted(results))
        self.path = path
        self.intro_lines = intro_lines

    def get_failing_nodeids(self):
        """Return the node ids whose recorded outcome is ``'failed'``."""
        return [nodeid for nodeid, res in self.results.items() if res == 'failed']

    def update(self, new_results):
        """Add ``(nodeid, outcome)`` pairs to the results and re-sort them.

        NOTE(review): ``OMD.sorted()`` presumably orders all stored
        (nodeid, outcome) pairs, so when a node id has both an old and a
        new outcome, the one that "wins" would be decided by alphabetical
        order of the outcome rather than by recency -- confirm against
        boltons before relying on either.
        """
        for nodeid, outcome in new_results:
            self.results.add(nodeid, outcome)
        self.results = self.results.sorted()

    @classmethod
    def from_path(cls, path):
        """Load a state file previously written by ``save()``.

        Blank lines are ignored. Leading lines starting with ``#`` become
        intro lines (with the ``#`` and one following space removed). From
        the first non-comment line onwards, every line is split on the
        first ``' - '`` into outcome and node id.

        Raises:
            OSError (py3) / IOError (py2): if the file cannot be opened.
        """
        # Function-scope import keeps this block self-contained. io.open
        # gives the same explicit-encoding API on Python 2 and 3.
        import io

        # save() writes UTF-8 bytes, so read back as UTF-8 rather than
        # whatever the locale's default encoding happens to be.
        with io.open(path, encoding='utf8') as f:
            contents = f.read()

        intro_lines = []
        results = []
        intro_done = False
        for line in contents.splitlines():
            line = line.strip()
            if not line:
                continue
            if not intro_done and line.startswith('#'):
                intro_lines.append(line[2:] if line.startswith('# ') else line[1:])
            else:
                intro_done = True
                result, _, nodeid = line.partition(' - ')
                results.append((nodeid, result))
        return cls(path, results, intro_lines=intro_lines)

    def save(self):
        """Write the intro lines and results to ``self.path`` as UTF-8.

        Uses ``atomic_save`` for the write.
        """
        lines = ['# %s\n' % line for line in self.intro_lines]
        lines.extend('%s - %s\n' % (result, nodeid)
                     for nodeid, result in self.results.items())
        with atomic_save(self.path) as f:
            f.writelines([line.encode('utf8') for line in lines])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment