|
from collections import namedtuple |
|
from functools import ( |
|
partial, |
|
reduce, |
|
wraps, |
|
) |
|
import inspect |
|
from inspect import _ParameterKind as PKEnum |
|
|
|
import pytest |
|
|
|
|
|
class _Missing: |
|
def __repr__(self): |
|
return '<Missing>' |
|
|
|
def __bool__(self): |
|
return False |
|
|
|
Missing = _Missing() |
|
provided = lambda **kwargs: \ |
|
{k: v for k, v in kwargs.items() if v is not Missing} |
|
|
|
# Description of a single parameter, consumed by ``Signature``:
#   constructor -- callable taking the parameter name and returning an
#                  ``inspect.Parameter``
#   converter   -- optional callable applied to the bound value
#   validator   -- optional callable (or tuple/list of callables)
ParamTuple = namedtuple('ParamTuple', 'constructor converter validator')
|
|
|
|
|
# Module-wide switch consulted by Signature-decorated functions before they
# run any validators.
_run_validators = True


def get_run_validators():
    """Report whether validators are currently being run."""
    return _run_validators


def set_run_validators(run):
    """Turn validator execution on (``True``) or off (``False``).

    Validators run by default. Anything other than a ``bool`` is rejected
    with ``TypeError``.
    """
    global _run_validators
    if isinstance(run, bool):
        _run_validators = run
        return
    raise TypeError("'run' must be bool.")
|
|
|
|
|
class Signature:
    """Decorator giving a ``*args, **kwargs`` function an explicit signature.

    Each keyword passed to the constructor names a parameter and maps to a
    ``ParamTuple`` (as built by ``arg``, ``kwarg``, ``args`` or ``kwargs``).
    Calls to the decorated function are bound against the generated
    ``inspect.Signature``. The bound values are then passed through any
    converters and, when ``get_run_validators()`` is true, any validators.
    Finally they are forwarded to the wrapped function, with
    ``Missing``-valued keyword arguments dropped.
    """

    def __init__(self, returns=Missing, **params):
        # pylint: disable=W0212, protected-access
        # Compare with the sentinel rather than relying on truthiness, so that
        # falsy annotations such as ``None`` (``-> None``) are preserved.
        self.returns = returns if returns is not Missing else inspect._empty

        params_ = {k: v.constructor(k) for k, v in params.items()}
        # inspect.Signature requires parameters grouped by kind in ascending
        # order. sorted() is stable, so declaration order within each kind is
        # kept.
        self.params = tuple(sorted(params_.values(), key=lambda p: p.kind))

        self.converters = {
            k: v.converter for k, v in params.items() if v.converter
        }
        self.validators = {
            k: v.validator for k, v in params.items() if v.validator
        }
        # Last function decorated by this instance (kept for introspection).
        self.func = None

    def _make(self):
        """Build the ``inspect.Signature`` advertised by decorated functions."""
        return inspect.Signature(self.params, return_annotation=self.returns)

    def convert(self, bound):
        """Replace each bound value with its converter's result, in place.

        ``BoundArguments`` does not fill in defaults, so parameters that the
        caller did not supply are absent from ``bound.arguments``. They are
        skipped rather than raising ``KeyError``.
        """
        arguments = bound.arguments
        for k, converter in self.converters.items():
            if k in arguments:
                arguments[k] = converter(arguments[k])

    def validate(self, bound):
        """Call each validator as ``validator(name, value)``.

        A validator may be a single callable or a tuple/list of callables.
        Parameters the caller did not supply are skipped.
        """
        arguments = bound.arguments
        for k, validator in self.validators.items():
            if k not in arguments:
                continue
            value = arguments[k]
            if isinstance(validator, (tuple, list)):
                for v in validator:
                    v(k, value)
                continue
            validator(k, value)

    def __call__(self, func):
        """Wrap *func* so calls are bound, converted and validated first."""
        self.func = func

        @wraps(func)
        def inner(*_args, **_kwargs):
            try:
                bound = inner.__signature__.bind(*_args, **_kwargs)
            except TypeError as exc:
                # Prefix the function name, as Python does for its own
                # argument errors.
                raise TypeError(f'{inner.__name__}() {exc.args[0]}') from exc

            if self.converters:
                self.convert(bound)
            if get_run_validators() and self.validators:
                self.validate(bound)

            # Use the closed-over ``func``, not ``self.func``: reusing one
            # Signature instance for several functions must not redirect
            # earlier wrappers to the most recently decorated function.
            return func(*bound.args, **provided(**bound.kwargs))

        inner.__signature__ = self._make()
        return inner
|
|
|
|
|
def arg(
        default=Missing,
        positional_only=False,
        required=False,
        type=Missing,
        converter=None,
        validator=None,
):
    """Describe a positional-only or positional-or-keyword parameter.

    Returns a ``ParamTuple`` whose ``constructor`` takes the parameter name
    and builds the corresponding ``inspect.Parameter``.

    :param default: default value. If omitted (and ``required`` is false),
        the parameter has no default.
    :param positional_only: use ``POSITIONAL_ONLY`` instead of
        ``POSITIONAL_OR_KEYWORD``.
    :param required: when true and no ``default`` is given, the default is
        the ``Missing`` sentinel instead of ``inspect.Parameter.empty``.
        NOTE(review): a non-empty default makes ``Signature.bind`` accept
        calls that omit the parameter, which is the opposite of what the
        name suggests. Confirm the intended semantics.
    :param type: annotation for the parameter. If omitted, no annotation.
    :param converter: callable applied by ``Signature`` to the bound value.
    :param validator: callable ``(name, value)``, or a tuple/list of them.
    """
    # pylint: disable=W0212, protected-access
    # pylint: disable=W0622, redefined-builtin
    kind = (
        PKEnum.POSITIONAL_ONLY if positional_only
        else PKEnum.POSITIONAL_OR_KEYWORD
    )
    if default is Missing and not required:
        default = inspect._empty

    return ParamTuple(
        constructor=partial(
            inspect.Parameter,
            kind=kind,
            default=default,
            # Compare with the sentinel so that falsy annotations (e.g. None)
            # are kept rather than silently dropped.
            annotation=type if type is not Missing else inspect._empty,
        ),
        converter=converter,
        validator=validator,
    )
|
|
|
|
|
def kwarg(
        default=Missing,
        required=False,
        type=Missing,
        converter=None,
        validator=None,
):
    """Describe a keyword-only parameter.

    Returns a ``ParamTuple`` whose ``constructor`` takes the parameter name
    and builds a ``KEYWORD_ONLY`` ``inspect.Parameter``.

    :param default: default value. If omitted (and ``required`` is false),
        the parameter has no default.
    :param required: when true and no ``default`` is given, the default is
        the ``Missing`` sentinel instead of ``inspect.Parameter.empty``.
        NOTE(review): a non-empty default makes ``Signature.bind`` accept
        calls that omit the parameter, which is the opposite of what the
        name suggests. Confirm the intended semantics.
    :param type: annotation for the parameter. If omitted, no annotation.
    :param converter: callable applied by ``Signature`` to the bound value.
    :param validator: callable ``(name, value)``, or a tuple/list of them.
    """
    # pylint: disable=W0212, protected-access
    # pylint: disable=W0622, redefined-builtin
    if default is Missing and not required:
        default = inspect._empty

    return ParamTuple(
        constructor=partial(
            inspect.Parameter,
            kind=PKEnum.KEYWORD_ONLY,
            default=default,
            # Compare with the sentinel so that falsy annotations (e.g. None)
            # are kept rather than silently dropped.
            annotation=type if type is not Missing else inspect._empty,
        ),
        converter=converter,
        validator=validator,
    )
|
|
|
|
|
# ParamTuples for the variadic parameters (*args and **kwargs). Neither has a
# converter or a validator; Signature supplies the parameter name when it
# calls the constructor.
args, kwargs = (
    ParamTuple(partial(inspect.Parameter, kind=variadic_kind), None, None)
    for variadic_kind in (PKEnum.VAR_POSITIONAL, PKEnum.VAR_KEYWORD)
)
|
|
|
|
|
# Fixture that exercises every parameter kind the helpers can produce.
# The keyword order below matters: Signature keeps declaration order within
# each parameter kind, so reordering these lines changes the signature.
# NOTE(review): po4/pok4 use required=True, which gives them the Missing
# sentinel as a default. pok1 (no default) therefore follows the defaulted
# po4. Some versions of inspect.Signature reject a non-default positional
# parameter after any defaulted positional one, even across the
# positional-only / positional-or-keyword boundary. Confirm this definition
# still imports on every supported Python version.
@Signature(
    po1=arg(positional_only=True),
    po2=arg(positional_only=True, type=int),
    po3=arg(positional_only=True, default=7),
    po4=arg(positional_only=True, required=True),
    pok1=arg(),
    pok2=arg(type=int),
    pok3=arg(default=7),
    pok4=arg(required=True),
    vp=args,
    ko1=kwarg(),
    ko2=kwarg(required=True),
    ko3=kwarg(default=7),
    vk=kwargs,
)
def myfunc(*args, **kwargs):
    # Echo what the Signature wrapper forwarded, so tests can inspect it.
    return args, kwargs
|
|
|
|
|
# Fixture: a single positional-or-keyword parameter ``xyz`` with no default.
# Its bound value is incremented before the wrapped function receives it.
@Signature(
    xyz=arg(converter=lambda value: value + 1),
)
def mf2(*positional, **named):
    return positional, named
|
|
|
|
|
class TestArg:
    """Tests for the converter and validator hooks of ``arg``."""
    # pylint: disable=R0201, no-self-use

    def test_converter(self):
        # A converter replaces the bound value before the wrapped call.
        @Signature(xyz=arg(converter=lambda x: -1 * x))
        def func(*args, **kwargs):
            return args, kwargs

        args, kwargs = (1,), {}
        assert func(*args, **kwargs) == ((-1,), {})

    @pytest.mark.parametrize(('args', 'raises'), [
        pytest.param((1,), False),
        pytest.param((0,), True),
    ])
    def test_validator(self, args, raises):
        # A validator rejects a bound value by raising; the exception must
        # propagate to the caller unchanged.
        message = 'xyz must be greater than 0'

        def validate_xyz(_name, val):
            if val <= 0:
                raise ValueError(message)

        @Signature(xyz=arg(validator=validate_xyz))
        def func(*args, **kwargs):
            return args, kwargs

        if not raises:
            assert func(*args) == (args, {})
            return

        with pytest.raises(ValueError) as excinfo:
            func(*args)
        assert excinfo.value.args == (message,)