from typing import ( | |
Any, | |
Dict, | |
NamedTuple, | |
) | |
import pytest | |
# One parametrized test case.
#   name   -- the id pytest shows for this case (``parametrize`` passes it via ``ids=``).
#   kwargs -- test-function argument name -> value to pass for that argument.
Case = NamedTuple(
    "Case",
    [
        ("name", str),
        ("kwargs", Dict[str, Any]),
    ],
)
def case(test_case_name=None, **kwargs):
    """Bundle one set of test arguments into a ``Case``.

    Args:
        test_case_name: Id for this case. When omitted (``None``), an id is
            generated from the keyword arguments as space-separated
            ``key=value`` pairs ordered by key, e.g. ``"a=10 b=11 result=21"``.
        **kwargs: Value for each argument of the parametrized test, keyed by
            argument name.

    Returns:
        A ``Case`` holding the id and the keyword arguments.
    """
    if test_case_name is not None:
        return Case(test_case_name, kwargs)
    # No explicit id: derive a deterministic one from the sorted arguments.
    pairs = ("{}={}".format(key, kwargs[key]) for key in sorted(kwargs))
    return Case(" ".join(pairs), kwargs)
def parametrize(*cases: Case):
    """Turn ``case(...)`` objects into a ``pytest.mark.parametrize`` marker.

    The test's argument names are the union of every case's keyword names,
    in sorted order, and each case's ``name`` becomes its pytest id.

    Args:
        *cases: The cases to run, typically built with ``case``.

    Returns:
        The ``pytest.mark.parametrize`` marker, ready to decorate a test.

    Raises:
        ValueError: If a case does not supply a value for every argument
            name used by the other cases.
    """
    arg_names = sorted({name for c in cases for name in c.kwargs})
    rows = []
    for c in cases:
        missing = [a for a in arg_names if a not in c.kwargs]
        if missing:
            # Report the offending case instead of failing with a bare KeyError.
            raise ValueError(
                f"case {c.name!r} is missing value(s) for: {', '.join(missing)}"
            )
        rows.append(tuple(c.kwargs[a] for a in arg_names))
    return pytest.mark.parametrize(
        # Pass the names as a list, not a comma-joined string. When there is
        # a single name in string form, pytest uses each row as the value
        # itself, so the test would receive the 1-tuple ``(v,)`` instead of
        # ``v``. With a list, each row is always unpacked as a tuple.
        arg_names,
        rows,
        ids=[c.name for c in cases],
    )
# Usage: each ``case`` call supplies one value per test argument. The first
# case sets its id explicitly; the others get ids generated from their
# arguments, e.g. "a=10 b=11 result=21".
@parametrize(
    case(test_case_name="some test case", a=1, b=2, result=3),
    case(a=10, b=11, result=21),
    case(a=88, b=99, result=88 + 99),
    case(a=5, b=7, result=5 + 7),
)
def test_sum(a, b, result):
    """Adding ``a`` and ``b`` gives the expected ``result``."""
    total = a + b
    assert total == result
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.