conftest.py with every trick I know for pytest
'''
[tool.poe.tasks.test]
cmd = "pytest"
env = { OBJC_DISABLE_INITIALIZE_FORK_SAFETY = "YES", DISABLE_SPRING = "true" }  # pytest-parallel

[tool.pytest.ini_options]
log_format = "[%(levelname)-5s %(asctime)s %(name)s@%(filename)s:%(lineno)d %(funcName)s] %(message)s"
log_auto_indent = 2
addopts = """
    --color=yes
    --capture=no
    --code-highlight=yes
    --tb=short
    -vv
    --doctest-modules
    --ignore=alembic --ignore=app/common/extra_logger.py
    -x
    --disable-warnings -p no:warnings
"""
# --workers=8 --tests-per-worker=1  # pytest-parallel
filterwarnings = [
    "ignore::DeprecationWarning"
]
python_classes = "Test* test_*"
python_functions = "test_* benchmark_*"
markers = [
    "llm_eval: An LLM evaluation. Slow, expensive, probably runs many times for statistical significance.",
]
'''
import os
import subprocess
import time
from typing import Literal

import pytest
import rich
from _pytest.nodes import Item
from _pytest.reports import TestReport
from _pytest.runner import CallInfo
from pydantic import BaseModel
from rich.align import Align
from rich.box import HORIZONTALS as HORIZONTAL_LINES
from rich.panel import Panel
from rich.traceback import Traceback

# NOTE: `db`, `docker`, and `general_settings` used below are project-local helpers;
# they are assumed to be importable in this project (e.g. provided via `conftest_internals`).

pytest_plugins = "conftest_internals"  # A module with an __init__.py importing other files.
"""The "pytest_plugins" variable is how you split up a big conftest.py file into multiple files."""

console = rich.get_console()


def panel(
    text,
    box=HORIZONTAL_LINES,
    subtitle_align: Literal["left", "center", "right"] = "right",
    style="bright_blue",
    **kwargs,
):
    """A centered Rich panel bounded by horizontal lines, used for the test banners below."""
    return Panel(
        Align(text, "center"),
        box=box,
        subtitle_align=subtitle_align,
        style=style,
        **kwargs,
    )

def pytest_runtestloop(session):
    """Bring up the dockerized test database, run migrations, and optionally empty it before the test loop."""
    import dotenv

    dotenv_file = "app/.tests.env"
    dotenv.load_dotenv(dotenv_file, verbose=True, override=True)
    general_settings.reload_in_place(dotenv_file)  # Bad practice.
    container = docker.get_test_container()  # See github.com/giladbarnea/pytest-docker.
    container_was_already_running: bool = container and docker.is_container_running(container)
    if not container_was_already_running:
        ok, exitcode = docker.docker_up("app/docker-compose-tests.yaml", env_file=dotenv_file, verbosity=2)
        if not ok:
            raise RuntimeError(f"Failed docker-compose up with exitcode: {exitcode}")
    db.wait_for_test_db()
    if not container_was_already_running:
        # It takes a little longer if we just started the container
        time.sleep(2)
    ok, exitcode = db.run_test_db_migrations()
    if not ok:
        raise RuntimeError("Failed to run migrations")
    empty_db = session.config.getoption("--empty-db")
    should_empty_db = empty_db in ("on-start", "on-start-and-finish")
    if should_empty_db:
        db.empty_test_db()

def pytest_sessionstart(session: pytest.Session):
    # I don't think this works. Check out runtestloop.
    os.environ.update(
        IN_PYTEST="true", DEBUG="true", DISABLE_SHORT_LOGS="true", FORCE_COLOR="true"
    )


def pytest_runtest_logstart(nodeid: str, location: tuple[str, int | None, str]):
    file, line_number, test_name = location
    test_name = test_name.replace("[", r"\[")
    console.print(
        "\n",
        panel(f"[bright_blue]Running: [bright_white]{test_name}[/]", subtitle=f"[dim white]{file}:{line_number}[/]"),
    )

def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config, items: list[pytest.Item]):
    """Skip parametrized items whose numeric parametrization index falls outside --krange."""
    # Note: `-k` filtering hasn't been applied yet, so `items` contains all collected tests.
    # TL;DR: breaks if you combine --krange with `-k`.
    krange = config.getoption("krange")
    if not krange:
        return
    kmin, _, kmax = krange.partition("..")
    kmin = int(kmin) if kmin else None
    kmax = int(kmax) if kmax else None
    for item in items:
        if item.name[-1] != "]":
            continue
        item_name = item.name.rstrip("]")
        if not item_name[-1].isdigit():
            # Parametrization id isn't numeric (e.g. "test_foo[prod]"); krange can't apply to it.
            continue
        i = len(item_name) - 1
        while item_name[i].isdigit():
            i -= 1
        parametrized_key = int(item_name[i + 1 :])
        below_minimum: bool = kmin is not None and parametrized_key < kmin
        above_maximum: bool = kmax is not None and parametrized_key > kmax
        if below_minimum or above_maximum:
            item.add_marker(pytest.mark.skip(reason=f"Not in krange: {kmin}..{kmax}"))
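
# Illustrative usage (not part of the gist): with a purely numeric parametrization such as
#
#     @pytest.mark.parametrize("n", range(10))
#     def test_foo(n): ...
#
# which collects test_foo[0] .. test_foo[9], running `pytest --krange=2..5` skips every item
# whose bracketed index falls outside 2..5; `--krange=2..` and `--krange=..5` leave one end open.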

def pytest_runtest_makereport(item: Item, call: CallInfo) -> TestReport | None:
    test_passed = call.when == "call" and call.excinfo is None
    if test_passed:
        everpassed: list = item.config.cache.get("everpassed", [])
        if item.nodeid not in everpassed:
            everpassed.append(item.nodeid)
        item.config.cache.set("everpassed", everpassed)
        return
    test_failed: bool = call.when == "call" and call.excinfo is not None
    if not test_failed:
        return
    test_title = f"{item.fspath.basename} {item.name}"
    line_number = item.location[1]
    console.print(
        panel(
            f"Failed: [bright_white]{test_title}[/]\n{call.excinfo.typename}: {call.excinfo.value}",
            subtitle=f"[dim white]{item.fspath.basename}:{line_number}[/]",
            style="bright_red",
        )
    )
    print_traceback(call.excinfo.value)
    console.print(panel("Resuming tests"))

def pytest_report_teststatus(report: TestReport, config):  # pylint: disable=unused-argument
    test_title = f"{report.location[0]} {report.head_line}"
    if report.when == "call":
        if report.passed:
            return "passed", f"\n✔ {test_title}", "✔ PASSED"
        if report.failed:
            return "failed", f"\n✘ {test_title}", "✘ FAILED"
        if report.skipped:
            return "skipped", f"\n⚠ {test_title}", "⚠ SKIPPED"
    return None

def print_traceback(
    exception: BaseException,
    *args,
    extra_lines=5,
    show_locals=True,
    locals_max_string=1000,
    locals_max_length=1000,
    width=console.width,
    suppress=(".venv",),
    console=console,
):
    """Render an exception with Rich's Traceback (including locals) and print it."""
    traceback = Traceback.from_exception(
        type(exception),
        exception,
        traceback=exception.__traceback__,
        extra_lines=extra_lines,
        show_locals=show_locals,
        locals_max_string=locals_max_string,
        locals_max_length=locals_max_length,
        suppress=suppress,
        width=width,
    )
    console.print(traceback, *args)

def llm_eval(function):
    """
    Decorator to mark a test as an LLM evaluation.
    Will be skipped unless `--allow-evals` or `--only-evals` is passed to pytest.
    """
    # [tool.pytest.ini_options] markers = [ "llm_eval: An LLM evaluation." ]
    return pytest.mark.llm_eval(function)


llm_eval_decorator_name = llm_eval.__name__
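
# Illustrative usage (the test name and body are made up):
#
#     @llm_eval
#     def test_summary_is_grounded():
#         ...  # slow/expensive LLM call; only runs with --allow-evals or --only-evals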

def pytest_configure(config: pytest.Config):
    if not config.option.allow_evals and not config.option.only_evals:
        config.option.markexpr = f"not {llm_eval_decorator_name}"
    elif config.option.only_evals:
        config.option.markexpr = llm_eval_decorator_name


def pytest_runtest_setup(item: Item):
    # --only-evals implies evals should run, so don't skip in that case either.
    evals_enabled = item.config.option.allow_evals or item.config.option.only_evals
    if item.get_closest_marker(llm_eval_decorator_name) and not evals_enabled:
        pytest.skip("LLM evaluation tests are disabled. Use --allow-evals to run them.")

def pytest_addoption(parser):
    parser.addoption(
        "--allow-evals",
        action="store_true",
        dest="allow_evals",
        default=False,
        help=(
            "Also run LLM evaluation tests (off by default).\n"
            f"Evals are tests decorated with @{llm_eval_decorator_name}."
        ),
    )
    parser.addoption(
        "--only-evals",
        action="store_true",
        dest="only_evals",
        default=False,
        help=(
            "Only run LLM evaluation tests (off by default).\n"
            f"Evals are tests decorated with @{llm_eval_decorator_name}."
        ),
    )
    parser.addoption(
        "--krange",
        action="store",
        dest="krange",
        default=None,
        help='e.g. "--krange=2..5", "--krange=2..", "--krange=..5"',
    )
    # Referenced by pytest_runtestloop above; the default of None ("never empty") is an assumption.
    parser.addoption(
        "--empty-db",
        action="store",
        dest="empty_db",
        default=None,
        help='When to wipe the test database: "on-start" or "on-start-and-finish" (default: never).',
    )

def pytest_cmdline_main(config):
    # Always exclude tests matching "bad_test" (a naive `+=` would omit the needed spaces).
    keyword = config.option.keyword
    config.option.keyword = f"{keyword} and not bad_test" if keyword else "not bad_test"

@pytest.fixture(scope="function")
def current_test_name(request: pytest.FixtureRequest) -> str:
    return request.node.name.replace("[", "_").replace("]", "_").lower()
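
# Illustrative usage (the tmp_path-based file name is just an example):
#
#     def test_export(current_test_name, tmp_path):
#         output = tmp_path / f"{current_test_name}.json"  # e.g. "test_export_3_.json" for test_export[3]
#         ...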

def os_system(cmd: str, **extra_env_vars) -> int:
    """Run a shell command with extra environment variables; returns the exit code."""
    return subprocess.call(
        cmd,
        shell=True,
        env={**os.environ, **{k: str(v) for k, v in extra_env_vars.items()}},
    )


def pytest_exception_interact(node, call, report) -> None:
    # On a test failure (e.g. in CI), echo `docker ps` and dump every container's logs into the CI artifacts dir.
    for command in (
        "docker ps -a",
        "echo ${CI_PROJECT_DIR}/public/${CI_PROJECT_NAME}",
        "mkdir ${CI_PROJECT_DIR}/public/${CI_PROJECT_NAME}/docker_logs",
        "docker ps -a | awk '{print $1, $NF}' | sed 1,1d",
        "docker ps -a | awk '{print $1, $NF}' | sed 1,1d | "
        "while read -r containerid containername; do "
        "docker logs ${containerid} > "
        "${CI_PROJECT_DIR}/public/${CI_PROJECT_NAME}/docker_logs/${containername}_logs.log; "
        "done",
    ):
        escaped = command.replace("'", "”")
        os.system(f"echo =============== '{escaped}' =============")
        os.system(command)

def assert_sub_dict(superdict, subdict, *, superdict_name: str | None = None, subdict_name: str | None = None, breadcrumbs: str = ""):
    """
    Asserts that every key-value pair in `subdict` is present and equal in `superdict`.
    Accumulates all differences and raises an exception with all of them at the end.
    Also accepts pydantic models as input.
    >>> assert_sub_dict({"a": 1, "b": 2}, {"a": 1})
    >>> class MyModel(BaseModel):
    ...     a: int
    >>> assert_sub_dict({"a": 1, "b": 2}, MyModel(a=1))
    """
    superdict = superdict.dict() if isinstance(superdict, BaseModel) else superdict
    subdict = subdict.dict() if isinstance(subdict, BaseModel) else subdict
    superdict_name = superdict_name if superdict_name else "Superdict"
    subdict_name = subdict_name if subdict_name else "Subdict"
    differences = []

    def _compare_dicts(super_dict, sub_dict, current_breadcrumbs=""):
        for sub_key, sub_value in sub_dict.items():
            if sub_key not in super_dict:
                differences.append(
                    f"Key '{current_breadcrumbs}.{sub_key}' not found in {superdict_name}. "
                    f"{subdict_name}{current_breadcrumbs}: {sub_value!r}"
                )
                continue
            super_value = super_dict[sub_key]
            if isinstance(sub_value, (dict, BaseModel)):
                _compare_dicts(
                    super_value,
                    sub_value.dict() if isinstance(sub_value, BaseModel) else sub_value,
                    current_breadcrumbs=f"{current_breadcrumbs}.{sub_key}",
                )
            elif super_value != sub_value:
                differences.append(
                    f"Values of key '{current_breadcrumbs}.{sub_key}' do not match. "
                    f"{superdict_name}: {super_value!r} != {subdict_name}: {sub_value!r}"
                )

    _compare_dicts(superdict, subdict, breadcrumbs)
    if differences:
        raise AssertionError("\n· " + "\n· ".join(differences))


def assert_dicts_equal(dict1: dict, dict2: dict, *, dict1_name: str = "Dict1", dict2_name: str = "Dict2"):
    """
    Asserts that every key-value pair in `dict1` is present and equal in `dict2`, and vice versa.
    """
    assert_sub_dict(dict1, dict2, superdict_name=dict1_name, subdict_name=dict2_name)
    assert_sub_dict(dict2, dict1, superdict_name=dict2_name, subdict_name=dict1_name)
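
# Illustrative usage (the client fixture, endpoint, and EXPECTED_PERMISSIONS constant are made up):
#
#     def test_create_user(client):
#         response = client.post("/users", json={"name": "Ada", "role": "admin"})
#         assert_sub_dict(response.json(), {"name": "Ada"}, superdict_name="response", subdict_name="expected")
#         assert_dicts_equal(response.json()["permissions"], EXPECTED_PERMISSIONS)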