Skip to content

Instantly share code, notes, and snippets.

@trhodeos
Last active June 13, 2020 02:10
Show Gist options
  • Save trhodeos/5a20b438480c880f7e15f08987bd9c0f to your computer and use it in GitHub Desktop.
Save trhodeos/5a20b438480c880f7e15f08987bd9c0f to your computer and use it in GitHub Desktop.
Helper methods to enforce strict optional arguments for Python Fire.
"""Helper methods for interacting with Python Fire."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import inspect
def only_allow_defined_args(function_to_decorate):
    """Decorator which rejects keyword arguments the function does not define.

    Fire allows method chaining, which means that extra keyword arguments a
    command does not consume are kept around and passed to future methods
    that are called. We don't need this, and should fail early if it happens.

    Args:
      function_to_decorate: Function (or bound method) which to decorate.

    Returns:
      Wrapped function. Calling it with a keyword argument that is not a
      parameter of `function_to_decorate` (or is named `self`) raises
      ValueError before `function_to_decorate` runs.
    """

    @functools.wraps(function_to_decorate)
    def _return_wrapped(*args, **kwargs):
        """Internal wrapper function."""
        # inspect.getargspec was removed in Python 3.11 and, before that,
        # raised ValueError for functions with keyword-only parameters or
        # annotations. getfullargspec is the replacement; keyword-only
        # parameters are legitimate keyword arguments, so accept them too.
        spec = inspect.getfullargspec(function_to_decorate)
        valid_names = spec.args + spec.kwonlyargs
        # getfullargspec still lists the already-bound 'self' for bound
        # methods; it must never be accepted as a keyword argument.
        if "self" in valid_names:
            valid_names.remove("self")
        for arg_name in kwargs:
            if arg_name not in valid_names:
                raise ValueError("Unknown argument seen '%s', expected: [%s]" %
                                 (arg_name, ", ".join(valid_names)))
        return function_to_decorate(*args, **kwargs)

    return _return_wrapped
"""Tests for fire_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import fire_utils
import unittest
def TestFunction(arg1=1, arg2="test"):
    """Fixture: hand back both keyword arguments as an (arg1, arg2) tuple."""
    pair = (arg1, arg2)
    return pair
class OnlyAllowDefinedArgsTest(unittest.TestCase):
    """Unit tests for fire_utils.only_allow_defined_args."""

    def InstanceTestFunction(self, arg1=1):
        """Method fixture: echoes the receiving instance alongside arg1."""
        return self, arg1

    def _decorate(self, func):
        """Apply the decorator under test to `func`."""
        return fire_utils.only_allow_defined_args(func)

    def testAllowsForDefinedArgs(self):
        wrapped = self._decorate(TestFunction)
        outcome = wrapped(arg1="hey", arg2="there")
        self.assertEqual(outcome, ("hey", "there"))

    def testAllowsInstanceFunctions(self):
        # Looking the method up through `self` yields a bound method, so the
        # instance is already supplied and only arg1 is passed by keyword.
        wrapped = self._decorate(self.InstanceTestFunction)
        outcome = wrapped(arg1="hey")
        self.assertEqual(outcome, (self, "hey"))

    def testRaisesExceptionForSelfNameArg(self):
        wrapped = self._decorate(self.InstanceTestFunction)
        bad_kwargs = {"arg1": 1, "self": 2}
        with self.assertRaisesRegex(ValueError, "Unknown argument seen 'self'"):
            wrapped(**bad_kwargs)

    def testRaisesExceptionForUndefinedArgs(self):
        wrapped = self._decorate(TestFunction)
        with self.assertRaisesRegex(ValueError, "Unknown argument seen 'arg3'"):
            wrapped(arg1="hey", arg2="there", arg3="world")

    def testDoesNotRunWrappedFunction(self):
        def _explode():
            raise Exception("oh no!")

        wrapped = self._decorate(_explode)
        # With no keyword arguments the wrapped body does execute...
        with self.assertRaisesRegex(Exception, "oh no!"):
            wrapped()
        # ...but an unknown keyword is rejected before the body ever runs.
        with self.assertRaisesRegex(ValueError, "Unknown argument seen 'arg1'"):
            wrapped(arg1="hey")
if __name__ == "__main__":
    # Discover and run the TestCase classes in this module when executed
    # directly as a script.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment