Skip to content

Instantly share code, notes, and snippets.

@maxsei
Created October 2, 2019 18:12
Show Gist options
  • Save maxsei/6e26d474421a43e79023f73ecb03915c to your computer and use it in GitHub Desktop.
Save maxsei/6e26d474421a43e79023f73ecb03915c to your computer and use it in GitHub Desktop.
import argparse
import inspect
def describe(thing):
    """Print the attribute names of *thing* that do not start with an underscore."""
    public_names = [name for name in dir(thing) if name[0] != "_"]
    print(public_names)
def argify(*types):
    """
    Decorator factory that fills a function's arguments from the command line.

    When the decorated function is defined in the module being run as a script
    (its ``__module__`` is ``"__main__"``), the decorator does the following:

    - It registers one ``--<param>`` option per parameter of the function.
    - It parses the command line.
    - It calls the function immediately with the parsed values.

    An option that was not supplied falls back to the parameter's default, or
    to ``None`` if the parameter has no default. After that call, the
    decorated name is rebound to a no-op.

    When the function is defined in any other module, it is returned unchanged
    and behaves like a regular function.

    Parameters
    ----------
    *types : callable
        One converter per parameter of the decorated function, in declaration
        order (e.g. ``int``, ``str``). Each is passed to
        ``argparse.ArgumentParser.add_argument`` as ``type=``.

    Returns
    -------
    callable
        The decorator.

    Raises
    ------
    TypeError
        At decoration time, if the number of ``types`` does not match the
        number of parameters of the decorated function.
    """
    def add_args(func):
        # Only take over the function when it lives in the script being run.
        # Checking the function's own module (not this module's __name__)
        # keeps this correct when argify is imported from another module.
        if getattr(func, "__module__", None) != "__main__":
            return func

        params = inspect.signature(func).parameters
        # Explicit check instead of assert: asserts are stripped under -O.
        if len(types) != len(params):
            raise TypeError(
                f"argify got {len(types)} type(s) but {func.__name__} "
                f"takes {len(params)} parameter(s)"
            )

        parser = argparse.ArgumentParser()
        for name, type_ in zip(params, types):
            parser.add_argument("--" + name, type=type_)
        parsed = parser.parse_args()

        args = []
        kwargs = {}
        for name, param in params.items():
            # argparse has already applied type_; do not convert a second time.
            value = getattr(parsed, name)
            # Omitted options come back as None. Compare against None
            # explicitly so falsy values such as 0 or "" are kept.
            if value is None:
                value = (
                    None if param.default is inspect.Parameter.empty
                    else param.default
                )
            if param.kind is inspect.Parameter.KEYWORD_ONLY:
                kwargs[name] = value
            else:
                args.append(value)
        func(*args, **kwargs)

        # The function has already run with the CLI values; rebind the
        # decorated name to a no-op so later calls do nothing.
        def none(*args, **kwargs):
            pass
        return none
    return add_args
@argify(int, str, str, str)
def data_validation(
    raw_training_data_dir, raw_eval_data_dir, utils_filename, project_path="default"
):
    """Print the received arguments on one line and their runtime types on the next.

    Serves as a demonstration target for ``argify``: when this file is run as
    a script, the arguments come from ``--<param>`` command-line options.
    """
    received = (raw_training_data_dir, raw_eval_data_dir, utils_filename, project_path)
    print(*received)
    print(*[type(value) for value in received])
# data_validation(1, "foo", "bar")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment