Skip to content

Instantly share code, notes, and snippets.

@samuelcolvin
Last active May 10, 2023 17:45
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samuelcolvin/625a1655c2a73e02469fc3c27285ca42 to your computer and use it in GitHub Desktop.
Save samuelcolvin/625a1655c2a73e02469fc3c27285ca42 to your computer and use it in GitHub Desktop.
auto-generate assert statements in pytest
"""
License: MIT
Copyright (c) 2022 Samuel Colvin.
See https://twitter.com/adriangb01/status/1573708407479189505
## Usage
Once installed just add
```py
insert_assert(the_value)
# or
insert_assert(function_call())
```
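to a test and run pytest. When pytest finishes, each call is replaced in the test file by
`assert <argument-code> == <value>` (preceded by a `# insert_assert(...)` comment), formatted with black.
Set the `TRY_RUN` environment variable to print the generated code instead of modifying any files.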
## Installation
To use this (until it's a proper package and pytest plugin), install its dependencies with
`pip install executing black`, add this file to `tests`, exclude it from git, and add the
following to your `conftest.py`:
```py
try:
    from .insert_assert import *
except ImportError:
    pass
```
## Example usage
```py
def test_string():
    thing = 'foobar'
    insert_assert(thing)


def test_list_callable():
    def foobar():
        return ['foo', 1, b'bytes']

    insert_assert(foobar())


def test_comprehension():
    insert_assert([f'x{i}' for i in range(10)])
```
"""
import ast
import os
import sys
import textwrap
from dataclasses import dataclass
from enum import Enum
from itertools import groupby
from pathlib import Path
from types import FrameType
from typing import Any
import pytest
from black import InvalidInput, Mode, TargetVersion, format_file_contents
# requires pip install executing black
from executing import Source
# Names picked up by `from .insert_assert import *` in conftest.py: the autouse fixture and the
# terminal-summary hook, plus insert_assert itself for explicit imports.
__all__ = (
    'add_insert_assert_to_builtins',
    'pytest_terminal_summary',
    'insert_assert',
)
@dataclass
class ToReplace:
    """One pending substitution: lines ``start_line``..``end_line`` of ``file`` become ``code``.

    Line numbers are 1-based and inclusive, matching the AST ``lineno``/``end_lineno`` of the
    recorded ``insert_assert(...)`` call.
    """

    file: Path  # source file containing the insert_assert() call
    start_line: int  # first line of the call (1-based)
    end_line: int  # last line of the call (1-based, inclusive)
    code: str  # generated replacement code, already indented to the call's column
# insert_assert() calls recorded during the test run; applied (or printed) by pytest_terminal_summary().
# The annotation is quoted so it isn't evaluated at import time: subscripting `list` at runtime needs
# Python 3.9+, and the resulting TypeError would escape conftest's `except ImportError`.
to_replace: 'list[ToReplace]' = []
@pytest.fixture(scope='session', autouse=True)
def add_insert_assert_to_builtins():
    """Make ``insert_assert`` callable from any test without importing it.

    Session-scoped and autouse, so pytest sets it up once for the whole test session.
    """
    import builtins

    # `__builtins__` is a CPython implementation detail that may be either the `builtins` module
    # or its `__dict__`; item assignment only works in the dict case, whereas setting an attribute
    # on the `builtins` module itself always works.
    builtins.insert_assert = insert_assert
def pytest_terminal_summary():
    """Apply the recorded ``insert_assert()`` replacements to the test files.

    pytest calls this hook after the run. Each recorded call is replaced by its generated code.
    If the ``TRY_RUN`` environment variable is set, the code is printed instead and no file is
    modified. When the same call site was recorded more than once (e.g. a parametrized test or a
    loop), only the first recorded call is applied.
    """
    if not to_replace:
        return
    # TODO replace with a pytest argument
    try_run = bool(os.getenv('TRY_RUN'))
    hr = '-' * 80
    file_count = 0
    replaced_count = 0
    duplicate_count = 0

    def file_key(tr):
        return tr.file

    # groupby() only merges *adjacent* items: without sorting, a file whose calls are interleaved
    # with another file's would be rewritten twice, the second time with stale line numbers.
    for file, group in groupby(sorted(to_replace, key=file_key), key=file_key):
        # split('\n') rather than splitlines(): it matches Python's own line numbering (which does
        # not break on e.g. form feed or '\x85') and round-trips the trailing newline on join.
        lines = file.read_text(encoding='utf-8').split('\n')
        replaced_lines = set()
        # substitute in reverse order so earlier replacements don't shift later line numbers;
        # the sort is stable, so among calls from the same line the first recorded one wins
        for tr in sorted(group, key=lambda x: x.start_line, reverse=True):
            if tr.start_line in replaced_lines:
                # replacing the same lines again would splice into the code inserted above
                duplicate_count += 1
                continue
            replaced_lines.add(tr.start_line)
            replaced_count += 1
            if try_run:
                print(f'{file} - {tr.start_line}:{tr.end_line}:\n{hr}\n{tr.code}{hr}\n')
            else:
                lines[tr.start_line - 1 : tr.end_line] = tr.code.splitlines()
        if not try_run:
            file.write_text('\n'.join(lines), encoding='utf-8')
        file_count += 1
    verb = 'would replace' if try_run else 'replaced'
    print(f'{verb} {replaced_count} insert_assert() calls in {file_count} files')
    if duplicate_count:
        print(
            f'skipped {duplicate_count} repeated insert_assert() calls '
            'from lines that were already replaced'
        )
def insert_assert(value):
    """Record that this call should be replaced by ``assert <arg> == <value>`` after the run.

    The argument's source code is recovered from the calling frame with ``executing`` and the
    value is rendered with ``custom_repr``. The resulting statement, preceded by a
    ``# insert_assert(...)`` comment, is formatted with black and indented to the call's column.
    Nothing is written here; the replacement is appended to ``to_replace``.

    Args:
        value: the value to assert against; the call site's argument expression becomes the
            left-hand side of the generated assert.

    Raises:
        RuntimeError: if the ``insert_assert(...)`` call can't be located in the caller's source.
    """
    call_frame: FrameType = sys._getframe(1)
    source = Source.for_frame(call_frame)
    ex = source.executing(call_frame)
    if ex.node is None:
        # without the call's AST node we know neither the argument code nor which lines to replace
        raise RuntimeError('insert_assert() was unable to find the source of the call that invoked it')
    call = ex.node
    # the value may be passed positionally or as `insert_assert(value=...)`
    ast_arg = call.args[0] if call.args else call.keywords[0].value
    if isinstance(ast_arg, ast.Name):
        arg = ast_arg.id
    else:
        # collapse a multi-line argument onto one line so it also fits the `# insert_assert(...)`
        # comment; black re-wraps the assert below.
        # NOTE(review): a `#` comment inside a multi-line argument would swallow the rest of the joined line
        arg = ' '.join(map(str.strip, ex.source.asttokens().get_text(ast_arg).splitlines()))
    python_code = f'# insert_assert({arg})\nassert {arg} == {custom_repr(value)}'
    mode = Mode(
        line_length=120,
        string_normalization=False,
        magic_trailing_comma=False,
        target_versions={TargetVersion.PY37, TargetVersion.PY38, TargetVersion.PY39, TargetVersion.PY310},
    )
    try:
        python_code = format_file_contents(python_code, fast=False, mode=mode)
    except InvalidInput as exc:
        # best-effort: keep the unformatted code and let the user fix it and run black themselves
        print(f'insert_assert(): black could not format the generated code, leaving it unformatted: {exc}')
        # black's output ends with a newline; keep the unformatted fallback consistent with it
        python_code += '\n'
    python_code = textwrap.indent(python_code, call.col_offset * ' ')
    to_replace.append(
        ToReplace(Path(call_frame.f_code.co_filename), call.lineno, call.end_lineno, python_code)
    )
def custom_repr(value):
    """Convert *value* into an object whose ``repr()`` is Python source that recreates it.

    Lists, tuples (including namedtuples), sets and frozensets are rebuilt with their items
    converted recursively. Dicts (including subclasses) become a plain ``dict`` with converted
    keys and values. Enum members render as ``ClassName.MEMBER``. Every other value is wrapped in
    ``PlainRepr`` so its ordinary ``repr()`` is emitted verbatim.
    """
    if isinstance(value, tuple) and hasattr(value, '_fields'):
        # namedtuple: the constructor takes one positional argument per field, not an iterable
        return value.__class__(*map(custom_repr, value))
    if isinstance(value, (list, tuple, set, frozenset)):
        return value.__class__(map(custom_repr, value))
    if isinstance(value, dict):
        # Build a plain dict: subclasses can't be reliably rebuilt from (key, value) pairs
        # (defaultdict() rejects them, Counter() counts them as elements), and a dict subclass
        # compares equal to a plain dict holding the same items.
        return {custom_repr(k): custom_repr(v) for k, v in value.items()}
    if isinstance(value, Enum):
        return PlainRepr(f'{value.__class__.__name__}.{value.name}')
    return PlainRepr(repr(value))
class PlainRepr:
    """Wrapper whose ``repr()`` is the wrapped string itself, without added quotes.

    ``custom_repr`` uses it to place ready-made source text (an existing ``repr()`` or an
    ``Enum.MEMBER`` reference) inside containers, so the container's ``repr()`` emits it verbatim.
    """

    __slots__ = 's',

    def __init__(self, s: str):
        self.s = s  # source text returned unchanged by repr()

    def __repr__(self):
        text = self.s
        return text
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment