Skip to content

Instantly share code, notes, and snippets.

@iscgar
Last active July 18, 2021 22:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save iscgar/731dfae7a6fbc26c9375624af1e3712a to your computer and use it in GitHub Desktop.
Save iscgar/731dfae7a6fbc26c9375624af1e3712a to your computer and use it in GitHub Desktop.
Helper decorator to enforce strict optional arguments for Python Fire
# Originally by Tyler Rhodes at https://gist.github.com/trhodeos/5a20b438480c880f7e15f08987bd9c0f
# The original broke in later Fire versions: Fire unwraps decorated
# functions, ignores the wrapper's signature, and so never passes the
# parameters the wrapper needs. This version is reimplemented without
# functools.wraps() and adds a bit more validation.
# Unfortunately, this approach adds unneeded entries to the help output
# (the extra varargs and kwargs parameters), and I know of no way around
# that.
import fire
import inspect
import operator
def only_allow_defined_args(wrapped):
    """Decorator which only allows arguments defined by `wrapped` to be used.

    Note, we need to specify this, as Fire allows method chaining. This means
    that extra kwargs are kept around and passed to future methods that are
    called. We don't need this, and should fail early if this happens.

    The returned wrapper accepts arbitrary ``*_`` / ``**kwargs``, checks them
    against the parameters of `wrapped`, and then forwards the call unchanged.

    Args:
        wrapped: Function which to decorate.

    Returns:
        Wrapped function. Its ``__signature__`` is the signature of `wrapped`
        extended with a ``*_`` parameter (unless `wrapped` already takes
        ``*args``) and a ``**kwargs`` parameter (unless `wrapped` already
        takes ``**kwargs``).

    Raises:
        fire.core.FireError: raised by the wrapper at call time when a keyword
            argument matches no parameter of `wrapped` (and `wrapped` has no
            ``**kwargs``), or when more positional values are given than
            `wrapped` has positional parameters (and it has no ``*args``).
    """
    # The wrapped function's parameters never change, so inspect them once at
    # decoration time rather than on every call.
    argspec = inspect.getfullargspec(wrapped)
    positional_names = argspec.args
    known_names = set(argspec.args) | set(argspec.kwonlyargs)
    num_defaults = len(argspec.defaults or ())
    defaulted_names = set(
        positional_names[len(positional_names) - num_defaults:])

    def wrapper(*args, **kwargs):
        # Positional values bind to the named positional parameters in order,
        # exactly as they will when `wrapped` is called below.
        bound_positionally = set(positional_names[:len(args)])

        if not argspec.varkw:
            unknown_args = [name for name in kwargs if name not in known_names]
            if unknown_args:
                # Names that could still be given as keywords, in signature
                # order. As before, defaulted positional parameters are listed
                # unless passed by keyword, even if bound positionally.
                possible_kwargs = [
                    name for name in positional_names
                    if name not in kwargs and (
                        name in defaulted_names
                        or name not in bound_positionally)
                ]
                possible_kwargs += [
                    name for name in argspec.kwonlyargs if name not in kwargs]
                msg = 'Unknown arguments {}'.format(unknown_args)
                if possible_kwargs:
                    msg += ', expected: {}'.format(possible_kwargs)
                raise fire.core.FireError(msg)

        # Anything past the named positional parameters would only be
        # accepted by a *args parameter.
        extraneous = list(args[len(positional_names):])
        if extraneous and not argspec.varargs:
            raise fire.core.FireError(
                'Extraneous arguments specified: {}'.format(extraneous))
        return wrapped(*args, **kwargs)

    # Copy metadata by hand: functools.wraps() is deliberately not used
    # because it would also set __wrapped__ on the wrapper.
    for attr in ('__module__', '__name__', '__qualname__', '__doc__',
                 '__annotations__'):
        try:
            setattr(wrapper, attr, getattr(wrapped, attr))
        except AttributeError:
            pass
    # Carry over any attributes stored on `wrapped` itself.
    wrapper.__dict__.update(getattr(wrapped, '__dict__', {}))

    # Advertise wrapped's signature plus catch-all *_ / **kwargs parameters.
    # The extra Parameters are built explicitly: introspecting `wrapper` here
    # could follow a __wrapped__ copied in above and return the wrong ones.
    sig = inspect.signature(wrapped)
    parameters = list(sig.parameters.values())
    kinds = {p.kind for p in parameters}
    positional_kinds = {
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    }
    if inspect.Parameter.VAR_POSITIONAL not in kinds:
        # *_ must sit right after the positional parameters and before any
        # keyword-only / **kwargs parameter, otherwise Signature() rejects
        # the parameter order.
        insert_at = sum(1 for p in parameters if p.kind in positional_kinds)
        parameters.insert(
            insert_at,
            inspect.Parameter('_', inspect.Parameter.VAR_POSITIONAL))
    if inspect.Parameter.VAR_KEYWORD not in kinds:
        parameters.append(
            inspect.Parameter('kwargs', inspect.Parameter.VAR_KEYWORD))
    wrapper.__signature__ = sig.replace(parameters=parameters)
    return wrapper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment