Last active
June 13, 2020 02:10
-
-
Save trhodeos/5a20b438480c880f7e15f08987bd9c0f to your computer and use it in GitHub Desktop.
Helper methods to enforce strict optional arguments for Python Fire.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Helpers methods for interacting with python fire.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import functools | |
import inspect | |
def only_allow_defined_args(function_to_decorate):
    """Decorator which only allows declared keyword arguments to be used.

    Note, we need to specify this, as Fire allows method chaining. This means
    that extra kwargs are kept around and passed to future methods that are
    called. We don't need this, and should fail early if this happens.

    Only keyword arguments are checked; positional arguments are forwarded
    untouched. A parameter named "self" is never accepted by keyword, so it
    cannot be overridden by name (this also applies to bound methods, whose
    introspected signature still lists "self").

    The signature is inspected once, when the decorator is applied, so a
    callable whose signature cannot be introspected fails at decoration time.

    Args:
      function_to_decorate: Function (or bound method) which to decorate.

    Returns:
      Wrapped function. Calling it with an undeclared keyword argument raises
      ValueError before `function_to_decorate` runs; otherwise all arguments
      are forwarded to it and its result is returned.
    """
    # inspect.getargspec is deprecated on Python 3 and was removed in 3.11;
    # it also refuses functions with keyword-only arguments or annotations.
    # Prefer getfullargspec and keep getargspec only as the Python 2 fallback
    # (this module still carries __future__ imports for 2/3 compatibility).
    getfullargspec = getattr(inspect, "getfullargspec", None)
    if getfullargspec is not None:
        spec = getfullargspec(function_to_decorate)
        # Keyword-only parameters are legitimately passable by name, too.
        declared = list(spec.args) + list(spec.kwonlyargs)
    else:
        declared = list(inspect.getargspec(function_to_decorate).args)
    # "self" is supplied implicitly (bound method) or positionally; it must
    # never be accepted as a keyword argument.
    valid_names = tuple(name for name in declared if name != "self")

    @functools.wraps(function_to_decorate)
    def _return_wrapped(*args, **kwargs):
        """Internal wrapper: reject unknown kwargs, then call through."""
        for arg_name in kwargs:
            if arg_name not in valid_names:
                raise ValueError("Unknown argument seen '%s', expected: [%s]" %
                                 (arg_name, ", ".join(valid_names)))
        return function_to_decorate(*args, **kwargs)

    return _return_wrapped
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Tests for fire_utils.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import fire_utils | |
import unittest | |
def TestFunction(arg1=1, arg2="test"):
    """Module-level fixture: echo both keyword arguments back as a tuple."""
    pair = (arg1, arg2)
    return pair
class OnlyAllowDefinedArgsTest(unittest.TestCase):
    """Tests for fire_utils.only_allow_defined_args."""

    def InstanceTestFunction(self, arg1=1):
        """Bound-method fixture: echo the receiving instance and arg1."""
        result = (self, arg1)
        return result

    def testAllowsForDefinedArgs(self):
        """Declared keyword arguments are forwarded unchanged."""
        guarded = fire_utils.only_allow_defined_args(TestFunction)
        outcome = guarded(arg1="hey", arg2="there")
        self.assertEqual(outcome, ("hey", "there"))

    def testAllowsInstanceFunctions(self):
        """Bound methods work; the instance stays implicitly bound."""
        # Looking the method up through `self` produces a bound method, so the
        # instance is supplied implicitly rather than as a keyword argument.
        guarded = fire_utils.only_allow_defined_args(self.InstanceTestFunction)
        outcome = guarded(arg1="hey")
        self.assertEqual(outcome, (self, "hey"))

    def testRaisesExceptionForSelfNameArg(self):
        """Passing 'self' by keyword is rejected."""
        guarded = fire_utils.only_allow_defined_args(self.InstanceTestFunction)
        bad_kwargs = {"arg1": 1, "self": 2}
        with self.assertRaisesRegex(ValueError, "Unknown argument seen 'self'"):
            guarded(**bad_kwargs)

    def testRaisesExceptionForUndefinedArgs(self):
        """An undeclared keyword argument triggers a ValueError."""
        guarded = fire_utils.only_allow_defined_args(TestFunction)
        with self.assertRaisesRegex(ValueError, "Unknown argument seen 'arg3'"):
            guarded(arg1="hey", arg2="there", arg3="world")

    def testDoesNotRunWrappedFunction(self):
        """The wrapped body runs for valid calls but not for rejected ones."""
        def always_fails():
            raise Exception("oh no!")

        guarded = fire_utils.only_allow_defined_args(always_fails)
        # A valid (argument-free) call reaches the body, which raises.
        with self.assertRaisesRegex(Exception, "oh no!"):
            guarded()
        # An unknown keyword is rejected before the body gets a chance to run.
        with self.assertRaisesRegex(ValueError, "Unknown argument seen 'arg1'"):
            guarded(arg1="hey")
if __name__ == "__main__":
    # Run the test cases defined in this module when it is executed directly.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment