Skip to content

Instantly share code, notes, and snippets.

@supposedly
Last active December 19, 2018 22:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save supposedly/467e737a16bb96a09940d084a8ac2102 to your computer and use it in GitHub Desktop.
Save supposedly/467e737a16bb96a09940d084a8ac2102 to your computer and use it in GitHub Desktop.
import inspect
from functools import wraps
from itertools import islice, starmap
def _callable(obj):
return callable(obj) and obj is not inspect._empty
def convert(hint, val):
    """
    Apply *hint* to *val* when *hint* is a usable converter.

    Non-callable hints (strings, None, inspect's empty sentinel) leave
    *val* as it is.
    """
    if not _callable(hint):
        return val
    return hint(val)
def typecast(func):
    """
    Wrap func such that arguments passed to it are converted
    according to its type hints.

    More specifically, calls func's annotations on the arguments
    passed to it; non-callable annotations (strings, missing
    annotations) leave their arguments untouched. If a callable
    annotates a variadic parameter (*args, **kwargs), the annotation
    is called on each value therein. Default values are never
    converted, and the return annotation is not applied.

    Arguments are matched to parameters with inspect.Signature.bind,
    so a call that func itself would reject (missing, unexpected or
    duplicated arguments) raises TypeError before any conversion
    happens and without func's body running.
    """
    sig = inspect.signature(func)
    parameters = sig.parameters
    VAR_POSITIONAL = inspect.Parameter.VAR_POSITIONAL
    VAR_KEYWORD = inspect.Parameter.VAR_KEYWORD

    @wraps(func)
    def wrapper(*args, **kwargs):
        # bind() performs the same argument matching func would, raising
        # TypeError on a bad call instead of invoking func to provoke it.
        bound = sig.bind(*args, **kwargs)
        converted = {}
        for name, value in bound.arguments.items():
            param = parameters[name]
            hint = param.annotation
            if param.kind == VAR_POSITIONAL:
                # value is the tuple collected into *args
                converted[name] = tuple(convert(hint, item) for item in value)
            elif param.kind == VAR_KEYWORD:
                # value is the dict collected into **kwargs; keep its keys
                converted[name] = {key: convert(hint, item) for key, item in value.items()}
            else:
                converted[name] = convert(hint, value)
        # Update after the loop so the mapping isn't modified mid-iteration.
        bound.arguments.update(converted)
        # Parameters that were not passed are absent from bound.arguments,
        # so func applies its own (unconverted) defaults.
        return func(*bound.args, **bound.kwargs)

    return wrapper
@typecast
def demo_fixed_params(thing: str, other: 'not callable', *, one, two: int):
    """
    Demonstrate typecast on positional and keyword-only parameters.

    Renamed from ``test``: later definitions in this module rebound that
    name, which hid this function and kept its doctest from running.

    ``thing`` is converted with str, ``two`` with int; ``other`` (a string
    annotation) and ``one`` (unannotated) are passed through unchanged.

    >>> demo_fixed_params(1, [], one='1', two='2')
    ('1', <class 'str'>, [], <class 'list'>, '1', <class 'str'>, 2, <class 'int'>)
    """
    return thing, type(thing), other, type(other), one, type(one), two, type(two)
#------------------------------------------------------------#
@typecast
def demo_variadic_params(one: float, *args: str, two: str = None, **kwargs: 'not callable') -> list:
    """
    Demonstrate typecast on annotated ``*args`` and a defaulted keyword.

    Renamed from ``test``: later definitions in this module rebound that
    name, which hid this function and kept its doctest from running.

    Each extra positional is converted with str. The default of ``two`` is
    not converted. ``**kwargs`` has a non-callable annotation, so its values
    pass through. typecast does not apply the ``-> list`` return annotation,
    so the result is the tuple built below.

    >>> demo_variadic_params(1, 2, 3, three='5')
    (1.0, <class 'float'>, ('2', '3'), [<class 'str'>, <class 'str'>], None, <class 'NoneType'>, {'three': ('5', <class 'str'>)})
    """
    return one, type(one), args, [type(i) for i in args], two, type(two), {i: (kwargs[i], type(kwargs[i])) for i in kwargs}
#------------------------------------------------------------#
@typecast
def demo_unannotated_params(one, *args, two, **kwargs) -> str:
    """
    Demonstrate that unannotated parameters, including ``*args`` and
    ``**kwargs``, reach the function unchanged.

    Renamed from ``test``: later definitions in this module rebound that
    name, which hid this function and kept its doctest from running.

    typecast does not apply the ``-> str`` return annotation, so the
    result is the tuple built below, not its string form.

    >>> demo_unannotated_params(1, 2, two=3, three='5')
    (1, <class 'int'>, (2,), [<class 'int'>], 3, <class 'int'>, {'three': ('5', <class 'str'>)})
    """
    return one, type(one), args, [type(i) for i in args], two, type(two), {i: (kwargs[i], type(kwargs[i])) for i in kwargs}
#------------------------------------------------------------#
@typecast
def test(one: float, *, two: float, **kwargs: float):
    """
    Demonstrate typecast converting a positional, a keyword-only and
    every ``**kwargs`` value, all annotated with float.

    >>> test(1, two=3, three='5')
    (1.0, <class 'float'>, 3.0, <class 'float'>, {'three': (5.0, <class 'float'>)})
    """
    extras = {key: (value, type(value)) for key, value in kwargs.items()}
    return one, type(one), two, type(two), extras
#------------------------------------------------------------#
### DON'T USE THESE ONES IN ANY SERIOUS MANNER EVER ###
def TRUNC(num):
    """Return a converter that keeps only the first *num* items of a sequence."""
    def truncate(seq):
        return seq[:num]
    return truncate
@typecast
def truncate_args(one: TRUNC(2), two: TRUNC(3)):
    """
    Return both arguments after typecast has cut *one* down to its first
    two items and *two* down to its first three.

    >>> truncate_args('abcd', 'efgh')
    ('ab', 'efg')
    """
    pair = (one, two)
    return pair
#------------------------------------------------------------#
def EQ_LEN(seq, *, l=[]):
    """
    Converter for typecast that insists a pair of arguments share a length.

    Records len(seq) in the shared list *l*. Once two lengths have been
    recorded, it raises ValueError if they differ, and it empties *l*
    either way so the next pair starts fresh. When no error is raised,
    *seq* is returned unchanged.

    Raises:
        ValueError: if the recorded lengths are not all equal.
    """
    # The mutable default is deliberate: it carries state from the call
    # on the first argument to the call on the second.
    l.append(len(seq))
    if len(l) < 2:  # 2 == number of arguments being compared
        return seq
    try:
        # An explicit check rather than `assert`, which `python -O` strips
        # (that would silently disable the validation).
        if not all(length == l[0] for length in l):
            raise ValueError('args must be of equal length')
        return seq
    finally:
        l.clear()
@typecast
def what(arg1: EQ_LEN, arg2: EQ_LEN):
    """
    Return the lengths of two sequences that EQ_LEN has required to be
    equally long.

    >>> what('abcd', 'abcd')
    (4, 4)
    >>> what('abcd', 'abc')
    Traceback (most recent call last):
        ...
    ValueError: args must be of equal length
    """
    return tuple(map(len, (arg1, arg2)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment