Skip to content

Instantly share code, notes, and snippets.

@kurtbrose
Created July 21, 2011 18:01
Show Gist options
  • Save kurtbrose/1097771 to your computer and use it in GitHub Desktop.
Save kurtbrose/1097771 to your computer and use it in GitHub Desktop.
simple parameter type decorator
import functools
import inspect
def accepts(*a, **kw):
    """
    Decorates a function to add type checking to calls.

    Raises TypeError if the decorated function is called with parameters that
    violate the type constraints.

    @accepts can be called in the same way as the function it wraps, with
    positional and keyword arguments replaced by types.
    - A positional type applies, in order, to the wrapped function's leading
      named parameters.
    - A keyword type applies to the parameter with that name.
    - A tuple of types means "any of these".

    Any arguments of the decorated function whose type is not specified will
    not be type checked. Any arguments of the decorated function with default
    values will have those default values checked immediately, when the
    decorator is applied.

    Runs on Python 2 and Python 3 (inspect.getargspec was removed in 3.11, so
    inspect.getfullargspec is used when available).
    """
    def decorator(f):
        # Both ArgSpec (Py2) and FullArgSpec (Py3) expose .args and .defaults.
        getspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
        spec = getspec(f)
        f_args = spec.args
        f_defaults = spec.defaults
        # Defaults line up with the *last* len(f_defaults) positional names.
        if f_defaults:
            default_dict = dict(zip(f_args[-len(f_defaults):], f_defaults))
        else:
            default_dict = {}
        # Python 3 keyword-only parameters store their defaults separately.
        default_dict.update(getattr(spec, "kwonlydefaults", None) or {})
        # Positional types bind to the leading parameter names; keyword types
        # bind by name and win over a positional type for the same name.
        type_dict = dict(zip(f_args[:len(a)], a))
        type_dict.update(kw)
        for k in type_dict:
            # Normalize single types to 1-tuples so the error message can
            # iterate over them uniformly.
            if not isinstance(type_dict[k], tuple):
                type_dict[k] = (type_dict[k],)

        def check(args, kwargs):
            # Layer values: defaults first, then positional arguments, then
            # keyword arguments -- later layers overwrite earlier ones.
            arg_dict = dict(default_dict)
            arg_dict.update(zip(f_args, args))
            arg_dict.update(kwargs)
            wrong_types = [
                (k, arg_dict[k], type_dict[k])
                for k in type_dict
                if k in arg_dict and not isinstance(arg_dict[k], type_dict[k])
            ]
            if wrong_types:
                msg = f.__name__ + "() has bad parameters:\n" + "\n".join(
                    pname + ": got " + type(pval).__name__ + "; expected "
                    + ", or ".join(t.__name__ for t in ptypes)
                    for pname, pval, ptypes in wrong_types
                )
                raise TypeError(msg)

        # Raise a TypeError at definition time if any of the default values
        # violate the type constraints.
        check((), {})

        @functools.wraps(f)
        def g(*args, **kwargs):
            check(args, kwargs)
            return f(*args, **kwargs)

        return g

    return decorator
def accepts27(*a, **kw):
    '''
    Decorates a function so that calls raise TypeError when an argument
    violates its declared type (same contract as accepts()).

    Types are given the way the wrapped function is called:
    - positional types match the leading named parameters, in order;
    - keyword types match parameters by name;
    - a tuple of types means "any of these".

    Parameters with no declared type are not checked. Default values are
    checked immediately, when the decorator is applied.

    Unlike accepts(), each call is bound to parameter names with
    inspect.getcallargs (added in Python 2.7). This makes argument binding
    follow Python's own rules.
    '''
    def decorator(f):
        # Both ArgSpec (Py2) and FullArgSpec (Py3) expose .args and .defaults;
        # getargspec itself was removed in Python 3.11.
        getspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec
        spec = getspec(f)
        f_args, f_defaults = spec.args, spec.defaults
        type_dict = dict(zip(f_args[:len(a)], a))
        type_dict.update(kw)
        for k in type_dict:
            # Normalize single types to 1-tuples so the error message can
            # iterate over them (iterating a bare type would raise).
            if not isinstance(type_dict[k], tuple):
                type_dict[k] = (type_dict[k],)

        def check(val_dict):
            # val_dict maps parameter names to the values being checked.
            wrong_types = [
                (k, val_dict[k], type_dict[k])
                for k in type_dict
                if k in val_dict and not isinstance(val_dict[k], type_dict[k])
            ]
            if wrong_types:
                msg = f.__name__ + "() has bad parameters:\n" + "\n".join(
                    pname + ": got " + type(pval).__name__ + "; expected "
                    + ", or ".join(t.__name__ for t in ptypes)
                    for pname, pval, ptypes in wrong_types
                )
                raise TypeError(msg)

        # Defaults line up with the *last* len(f_defaults) positional names;
        # Python 3 keyword-only defaults are stored separately.
        if f_defaults:
            defaults = dict(zip(f_args[-len(f_defaults):], f_defaults))
        else:
            defaults = {}
        defaults.update(getattr(spec, "kwonlydefaults", None) or {})
        if defaults:
            check(defaults)

        @functools.wraps(f)
        def g(*args, **kwargs):
            # getcallargs raises TypeError itself if the call cannot bind.
            check(inspect.getcallargs(f, *args, **kwargs))
            return f(*args, **kwargs)

        return g

    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment