Last active
March 10, 2021 08:51
-
-
Save mahmoud/10f6b6b0a9c5860030693357124131df to your computer and use it in GitHub Desktop.
Excerpted tidbits from our Django app's conftest during the Django 1 to 2 and Python 2 to 3 migration ("TR" stands for "Tech Refresh", circa mid 2020), with parts related to wrapping the test session in a DB transaction, and other parts related to bulk-skipping functionality. Tested on py2.7/3.6, Django 1.11/2.0.x, pytest 4.6.11.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Common test fixtures for pytest | |
""" | |
from __future__ import print_function, unicode_literals | |
import os | |
from itertools import groupby | |
import pytest | |
from django.test import TransactionTestCase, TestCase | |
from django.test.testcases import connections_support_transactions | |
from boltons.dictutils import OMD | |
from boltons.fileutils import atomic_save | |
from six import PY3 | |
from pytest_django.plugin import _blocking_manager, validate_django_db | |
from django.db.backends.base.base import BaseDatabaseWrapper | |
# Side effects on import. Can't be helped, we need to unblock all DB
# accesses
_blocking_manager.unblock()
# NOTE(review): overwriting pytest-django's private _blocking_wrapper with
# ensure_connection appears intended to make any later (re-)blocking harmless
# — it would just bootstrap the DB connection instead of raising. This relies
# on plugin internals; confirm when upgrading pytest-django.
_blocking_manager._blocking_wrapper = BaseDatabaseWrapper.ensure_connection
# "Tech Refresh" (TR) environment == running under Python 3, i.e. the
# py2->py3 / Django 1->2 migration target. Gates all TR skip-state logic below.
IS_TR_ENV = bool(PY3)
# CLI option names for splitting the test suite across parallel CI partitions.
TEST_PARTITION_OPTION = "--test_partition"
TOTAL_TESTS_OPTION = "--total_test_partitions"
def _patchedTearDownClass(cls):
    """
    Overrides tearDownClass from TestCase

    Rolls back the class-level atomics the way Django's
    ``TestCase.tearDownClass`` does, but then calls
    ``super(TransactionTestCase, cls)`` — deliberately skipping
    ``TransactionTestCase.tearDownClass`` in the MRO, whose teardown
    truncates the database (see the note and Django source link in
    ``run_transaction_test_cases_after_all_other_tests`` below).
    """
    if connections_support_transactions():
        cls._rollback_atomics(cls.cls_atomics)
    super(TransactionTestCase, cls).tearDownClass()
# Monkeypatch: every django TestCase in the suite now tears down via the
# patched version above.
TestCase.tearDownClass = classmethod(_patchedTearDownClass)
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db):
    """
    Grant DB access to every test by depending on the ``db`` fixture.

    Ideally tests would opt in individually, but we currently don't know
    which tests touch the database and which don't, so all of them get
    access. Requesting ``db`` also makes pytest-django wrap each test in
    Django's TestCase-style atomic transaction:
    https://github.com/pytest-dev/pytest-django/blob/ce5d5bc0b29748ed411b9d683a33e1b13d98e17f/pytest_django/fixtures.py#L149
    """
@pytest.fixture(scope='session')
def _transaction_wrap(django_db_setup):
    # wraps the following client data setup in an exterior transaction separate
    # from the transaction around the individual tests themselves
    # Session-scoped: opens Django TestCase class-level atomics once for the
    # whole run. NOTE(review): nothing here rolls these atomics back;
    # presumably teardown happens via the patched tearDownClass above —
    # confirm.
    TestCase._enter_atomics()
def pytest_addoption(parser):
    """
    Register command-line options for test partitioning and Tech Refresh
    (TR) skip-state management.

    Environment-variable defaults are converted to int explicitly here:
    argparse only reparses *string* defaults through ``type=``, and an
    empty env var (e.g. ``TOTAL_TEST_PARTITIONS=""``) would previously
    crash option parsing with ``int("")``. The ``or fallback`` treats an
    empty env var the same as unset.
    """
    parser.addoption(TEST_PARTITION_OPTION, type=int,
                     default=int(os.getenv('TEST_PARTITION') or 0))
    parser.addoption(TOTAL_TESTS_OPTION, type=int,
                     default=int(os.getenv('TOTAL_TEST_PARTITIONS') or 1))
    parser.addoption('--tr-regen-state', action="store_true", default=False,
                     help="regenerate list of skipped/failing tests into a new skipfile, for tech refresh")
    parser.addoption('--tr-no-autoskip', action='store_true', default=False,
                     help='do not autoskip known failing tests based on tr_test_state.txt')
    parser.addoption('--tr-recheck', action='store_true', default=False,
                     help='update TR state file for any passing tests. meant to be used with'
                     ' test pattern filtering and implies --tr-no-autoskip')
    return
def pytest_collection_modifyitems(config, items):
    """
    Reorder collected tests, slice out this runner's partition, and apply
    skip/xfail markers.

    Phases:
      1. sort items (transaction-ness, module, test-case class name),
      2. keep only the test-case groups assigned to this partition
         (--test_partition of --total_test_partitions),
      3. skip tests marked needs_isolation unless --needs-isolation given,
      4. xfail-mark tests recorded as failing in tr_test_state.txt.
    """
    def get_marker_transaction(test):
        # The `transaction` flag of an explicit @pytest.mark.django_db
        # marker, or None when the marker is absent.
        marker = test.get_closest_marker('django_db')
        if marker:
            transaction, _ = validate_django_db(marker)
            return transaction
        return None
    def has_fixture(test, fixture):
        # True when the test (directly or transitively) requests the
        # named fixture.
        funcargnames = getattr(test, 'funcargnames', None)
        return funcargnames and fixture in funcargnames
    def run_transaction_test_cases_after_all_other_tests(test):
        """
        Detect if a test case is marked as a transaction test case, and
        if so, make sure to run it last since transaction test cases
        truncate the database (and thus leave no data for "non-transaction"
        test cases to act on.)
        Part of the teardown for djangos TransactionTestCase does this:
        https://github.com/django/django/blob/b61ea56789a5825bd2961a335cb82f65e09f1614/django/test/testcases.py#L1000
        """
        is_test_case_subclass = getattr(
            test, 'cls', None) and issubclass(test.cls, TestCase)
        is_transaction_test_case_subclass = getattr(
            test, 'cls', None) and issubclass(test.cls, TransactionTestCase)
        # TestCase is itself a TransactionTestCase subclass, so the
        # TestCase check must come first.
        if is_test_case_subclass or get_marker_transaction(test) is False:
            return 0
        elif is_transaction_test_case_subclass or get_marker_transaction(test) is True:
            return 1
        elif has_fixture(test, 'transactional_db') or has_fixture(test, 'live_server'):
            # live_server uses transactional_db. So same truncation.
            return 1
        elif has_fixture(test, 'db'):
            return 0
        return 0
    def sort_by_app_and_test_folder_name(test):
        # Dotted module path of the test's class (or bare function).
        if test.cls:
            return test.cls.__module__
        elif test.function:
            return test.function.__module__
    def group_by_test_case(test):
        # Class name, or "" for module-level test functions.
        if test.cls:
            return test.cls.__name__
        else:
            return ""
    key_funcs = (run_transaction_test_cases_after_all_other_tests,
                 sort_by_app_and_test_folder_name,
                 group_by_test_case,)
    # NOTE(review): successive stable sorts make the LAST key the primary
    # one, so the final order is grouped by test-case name and the
    # transaction-last ordering only survives as a tie-breaker within a
    # class/module. Confirm this precedence is intended (iterating
    # reversed(key_funcs) would make transaction-ness primary). The
    # groupby() below does require items grouped by test case, so changing
    # this order is not a no-op.
    for key_func in key_funcs:
        items.sort(key=key_func)
    TEST_NUMBER = config.getoption(TEST_PARTITION_OPTION)
    TOTAL_TESTS = config.getoption(TOTAL_TESTS_OPTION)
    # assign each group of test cases a number
    temp_new_list = []
    for index, (group_by_name, tests) in enumerate(groupby(items, group_by_test_case)):
        if group_by_name == "":
            # Ungrouped (module-level) test functions are partitioned
            # individually...
            for inner_index, test in enumerate(tests):
                if inner_index % TOTAL_TESTS == TEST_NUMBER:
                    temp_new_list.append(test)
        else:
            # ...while a test-case class always stays whole on one partition.
            if index % TOTAL_TESTS == TEST_NUMBER:
                for test in tests:
                    temp_new_list.append(test)
    items[:] = temp_new_list
    # --needs-isolation is registered elsewhere (not in this file's
    # pytest_addoption).
    if not config.getoption("--needs-isolation"):
        skip_needs_isolation = pytest.mark.skip(
            reason="need --needs-isolation option to run")
        for item in items:
            if "needs_isolation" in item.keywords:
                item.add_marker(skip_needs_isolation)
    # Autoskip is on by default in the TR env; any of the three TR options
    # disables it.
    skip_tr_tests = not (config.option.tr_recheck or config.option.tr_no_autoskip or config.option.tr_regen_state)
    if IS_TR_ENV and skip_tr_tests:
        # config.rootdir is a py.path.local; `+` appends the string.
        trf_path = config.rootdir + '/tr_test_state.txt'
        try:
            trf = TestResultFile.from_path(trf_path)
        except OSError:
            pass  # no state file yet: nothing to auto-skip
        else:
            skip_tr_failure = pytest.mark.xfail(
                reason='known failure related to tech refresh')
            fns = set(trf.get_failing_nodeids())
            for item in items:
                if item.nodeid in fns:
                    item.add_marker(skip_tr_failure)
    return
# Outcomes per test nodeid, in recording order. OMD is a boltons multidict:
# .add() appends a value for a key instead of overwriting.
_all_reports = OMD()
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_logreport(report):
    # this function runs 3x for each test: setup, call, teardown.
    # this approach ensures that if any phase fails, this test stays marked as failed
    # (OMD.get() returns the most recently added value for a key, so once
    # 'failed' is recorded, later phase outcomes are not appended over it)
    if _all_reports.get(report.nodeid) != 'failed':
        _all_reports.add(report.nodeid, report.outcome)
    return
@pytest.fixture(scope="session", autouse=True)
def tr_state_save(request):
    """
    Session-scoped autouse fixture that persists Tech Refresh (TR) test
    outcomes to ``tr_test_state.txt``.

    Pre-yield: validates unsafe option combinations. Post-yield: either
    regenerates the state file (--tr-regen-state) or merges the newly
    observed outcomes into the existing one (--tr-recheck).
    """
    session = request.node
    config = session.config
    if config.option.tr_regen_state:
        if config.option.tr_recheck:
            raise SystemExit('--tr-regen-state is mutually exclusive with --tr-recheck')
        if config.option.keyword:
            # -k filtering means only a subset ran; regenerating would
            # silently drop state for all filtered-out tests.
            raise SystemExit('refusing to regenerate TR state while running a subset of tests')
    yield
    if not IS_TR_ENV or not _all_reports:
        return  # not tech refreshing / process aborted
    results = _all_reports.items()
    # config.rootdir is a py.path.local; `+` appends the string.
    path = config.rootdir + '/tr_test_state.txt'
    if config.option.tr_regen_state:
        if len(_all_reports) != session.testscollected:
            raise SystemExit('refusing to regenerate TR state with incomplete test run')
        new_trf = TestResultFile(path, results)
        new_trf.save()
    if config.option.tr_recheck:
        try:
            trf = TestResultFile.from_path(path)
        except OSError:
            # Nothing to merge into. Previously this fell through and
            # raised NameError on the unbound `trf` below.
            print('no existing test result file at %r, nothing to recheck against' % path)
        else:
            trf.update(results)
            trf.save()
    return
class TestResultFile(object):
    """
    Reader/writer for the TR test-state file.

    File format: optional leading '#'-prefixed intro comment lines, then
    one line per recorded outcome of the form ``<outcome> - <nodeid>``
    (e.g. ``failed - tests/test_foo.py::test_bar``).

    Results live in a boltons OMD (ordered multidict), so one nodeid may
    carry several recorded outcomes; a test counts as failing if *any* of
    its outcomes is 'failed'.
    """
    def __init__(self, path, results, intro_lines=()):
        # sorted() normalizes ordering so saved files diff cleanly
        self.results = OMD(sorted(results))
        self.path = path
        self.intro_lines = intro_lines
    def get_failing_nodeids(self):
        # OMD.items() yields every (key, value) pair including duplicate
        # keys, so any recorded 'failed' outcome surfaces here.
        return [nodeid for nodeid, res in self.results.items() if res == 'failed']
    def update(self, new_results):
        # Merge newly observed (nodeid, outcome) pairs; .add() appends
        # rather than overwrites, then re-sort for stable file output.
        for nodeid, outcome in new_results:
            self.results.add(nodeid, outcome)
        self.results = self.results.sorted()
    @classmethod
    def from_path(cls, path):
        # Alternate constructor: parse an existing state file.
        # Propagates OSError (e.g. file missing) to the caller.
        with open(path) as f:
            contents = f.read()
        contents_lines = contents.splitlines()
        intro_lines = []
        results = []
        intro_done = False
        for line in contents_lines:
            line = line.strip()
            if not line:
                continue
            if not intro_done and line.startswith('#'):
                # strip the comment prefix ('# ' or bare '#')
                intro_lines.append(line[2:] if line.startswith('# ') else line[1:])
            else:
                intro_done = True
                # partition on the FIRST ' - '; the nodeid is the remainder
                result, _, nodeid = line.partition(' - ')
                results.append((nodeid, result))
        return cls(path, results, intro_lines=intro_lines)
    def save(self):
        # Atomically write intro comment lines, then one line per outcome.
        lines = []
        for line in self.intro_lines:
            lines.append('# %s\n' % line)
        for nodeid, result in self.results.items():
            lines.append('%s - %s\n' % (result, nodeid))
        with atomic_save(self.path) as f:
            # atomic_save opens the temp file in binary mode, hence encode
            f.writelines([line.encode('utf8') for line in lines])
        return
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment