Skip to content

Instantly share code, notes, and snippets.

@louisswarren
Last active December 28, 2022 05:40
Show Gist options
  • Star 0 (you must be signed in to star a gist)
  • Fork 0 (you must be signed in to fork a gist)
  • Save louisswarren/b10073be844e15e2aa0b480bfb6a20cf to your computer and use it in GitHub Desktop.
Testing using decorators
import io
import sys
class wrap_stdio:
    """Context manager that temporarily replaces sys.stdin and sys.stdout.

    While active, reads from stdin consume *io_in* and writes to stdout are
    collected in an in-memory buffer.  ``__enter__`` returns that buffer so
    the caller can inspect it with ``getvalue()``, also after the block ends.

    Args:
        io_in: text made available on the replacement stdin.
    """

    def __init__(self, io_in=''):
        self.io_in = io.StringIO(io_in)
        self.io_out = io.StringIO()
        # Streams to restore on exit.  They are re-captured in __enter__ so
        # that the streams active at *entry* are restored, even if
        # sys.stdin / sys.stdout changed between construction and entry.
        self.stdin = sys.stdin
        self.stdout = sys.stdout

    def __enter__(self):
        self.stdin = sys.stdin
        self.stdout = sys.stdout
        sys.stdin = self.io_in
        sys.stdout = self.io_out
        return self.io_out

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdin = self.stdin
        sys.stdout = self.stdout
        # Falsy return: any exception raised inside the block propagates.
        return None
class Example:
    """Expected behaviour of calling ``func``, checked by ``test``.

    Expectations are set through ``update`` (normally via the ``returns``,
    ``sending``, ``yields``, ``inputting`` and ``prints`` decorators).  Every
    name in ``attribs`` has a companion ``check_<name>`` flag that becomes
    True once a value is supplied; only flagged aspects are checked.
    """

    attribs = 'returns', 'sending', 'yields', 'inputting', 'prints'

    def __init__(self, func, **kwargs):
        """Wrap *func*; keyword arguments are forwarded to ``update``."""
        self.func = func
        for attrib in self.attribs:
            setattr(self, attrib, None)
            setattr(self, 'check_' + attrib, False)
        # Default stdin contents.  This must come after the loop above,
        # which would otherwise reset it to None.
        self.inputting = ''
        self.update(**kwargs)

    def update(self, **kwargs):
        """Set expectations and enable their checks.

        Raises:
            TypeError: if any keyword is not listed in ``attribs``.  Nothing
                is changed in that case.
        """
        # Validate first so a bad keyword leaves the example untouched.
        for arg in kwargs:
            if arg not in self.attribs:
                raise TypeError(f"Attribute not one of {repr(self.attribs)}")
        for arg, value in kwargs.items():
            setattr(self, arg, value)
            setattr(self, 'check_' + arg, True)

    def __call__(self, *args, **kwargs):
        # Reached when check decorators are stacked without an @example on
        # top: the decorated name is then bound to this Example rather than
        # to the original function.
        raise TypeError("Example is not callable (did you forget @example?)")

    def _check_returns(self, val):
        """Raise AssertionError unless *val* equals the expected return."""
        if val != self.returns:
            raise AssertionError(
                f"Returned {repr(val)}, expecting {repr(self.returns)}")

    def _check_prints(self, msg):
        """Raise AssertionError unless *msg* equals the expected stdout."""
        if msg != self.prints:
            raise AssertionError(
                f"Printed {repr(msg)}, expecting {repr(self.prints)}")

    def _check_yield(self, val, pos):
        """Check the value yielded at position *pos*.

        A yield past the end of ``yields`` is an error, because the generator
        was expected to return at that point.
        """
        if pos >= len(self.yields):
            raise AssertionError(f"Yielded {repr(val)}, expecting to return")
        if val != self.yields[pos]:
            raise AssertionError(
                f"Yielded {repr(val)}, expecting {repr(self.yields[pos])}")

    def _check_yields(self, gen):
        """Drive generator *gen*, checking each yield; return its return value.

        After the k-th yield, ``sending[k]`` is sent back in.  None is sent
        instead if ``sending`` was not given or is too short.  The generator
        must produce exactly ``len(yields)`` values and then return.

        Raises:
            AssertionError: on a wrong value, an extra yield, or an early return.
        """
        sent = self.sending if self.check_sending else ()
        pos = 0
        try:
            val = next(gen)
            while True:
                self._check_yield(val, pos)
                reply = sent[pos] if pos < len(sent) else None
                pos += 1
                val = gen.send(reply)
        except StopIteration as stop:
            if pos < len(self.yields):
                raise AssertionError(
                    f"Returned {repr(stop.value)}, "
                    f"expecting to yield {repr(self.yields[pos])}") from None
            return stop.value

    def test(self, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` and check every enabled expectation.

        The call runs with stdin fed from ``inputting`` and stdout captured.
        For generators, the whole iteration runs the same way.  Prints
        "Success!" to the restored stdout when every check passes.

        Raises:
            AssertionError: describing the first failed check.
        """
        with wrap_stdio(self.inputting) as out:
            result = self.func(*args, **kwargs)
            if self.check_yields:
                # A generator body only runs while it is iterated, so it is
                # driven while stdio is still redirected.
                result = self._check_yields(result)
        if self.check_returns:
            self._check_returns(result)
        if self.check_prints:
            self._check_prints(out.getvalue())
        print("Success!")
def example_decorator(**kwargs):
    """Return a decorator that records the expectations given in *kwargs*.

    Applied to a plain function, the decorator wraps it in a new Example.
    Applied to an existing Example (check decorators stacked on each other),
    it updates that Example in place and returns it.
    """
    def decorator(target):
        if not isinstance(target, Example):
            return Example(target, **kwargs)
        target.update(**kwargs)
        return target
    return decorator
def example(*args, **kwargs):
    """Return a decorator that checks ``func(*args, **kwargs)`` immediately.

    The checks run at decoration time.  The decorator returns the underlying
    plain function, so the decorated name stays callable and any check
    decorators applied above it start a fresh Example.
    """
    def decorator(target):
        # A plain function becomes an Example with no checks enabled.
        spec = example_decorator()(target)
        spec.test(*args, **kwargs)
        return spec.func
    return decorator
def returns(val):
    """Expect the decorated function to return *val*."""
    return example_decorator(returns=val)
def sending(*seq):
    """Values to send into the generator, one after each of its yields."""
    return example_decorator(sending=seq)
def yields(*seq):
    """Expect the decorated generator to yield exactly the values in *seq*."""
    return example_decorator(yields=seq)
def inputting(msg):
    """Feed *msg* to stdin while the decorated function is tested."""
    return example_decorator(inputting=msg)
def prints(msg):
    """Expect the decorated function to write exactly *msg* to stdout."""
    return example_decorator(prints=msg)
@example(1, 2)
@returns(3)
@example(3, 4)
@prints("Got 3 4\n")
def foo(x, y):
    """Demo: print the arguments and return their sum.

    The decorators (applied bottom-up) check at definition time that
    foo(3, 4) prints ``Got 3 4`` plus a newline, and that foo(1, 2) returns 3.
    """
    print("Got", x, y)
    return x + y


if __name__ == "__main__":
    # Only run the demo call when executed as a script, not on import.
    foo(10, 20)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment