Skip to content

Instantly share code, notes, and snippets.

@evjeny
Created August 10, 2022 16:23
Show Gist options
  • Save evjeny/6110a97289b2ae90004a6a95b8b294ce to your computer and use it in GitHub Desktop.
Save evjeny/6110a97289b2ae90004a6a95b8b294ce to your computer and use it in GitHub Desktop.
Parameter parsers
import argparse
import inspect
import typing
from updaters import UpdaterA, UpdaterB, UpdaterC, UpdaterD
# Type alias for a mapping of parameter name -> value. Used below both for the
# shared default parameters and for the parsed command-line arguments.
params_type = dict[str, typing.Any]
def add_default_params(parser: argparse.ArgumentParser, default_params: params_type):
    """Register one optional ``--<name>`` flag per shared default parameter.

    Every flag defaults to ``None``, so an omitted flag can be told apart from
    one the user actually passed. The value converter is derived from the
    default value:

    * ``bool`` defaults accept true/false style strings (``bool("False")`` would
      otherwise be ``True``);
    * ``None`` defaults accept the raw string (``type(None)`` converts nothing);
    * anything else uses ``type(default)``.

    Args:
        parser: parser the flags are added to (mutated in place).
        default_params: mapping of parameter name to its default value.
    """

    def _to_bool(text: str) -> bool:
        # Explicit parser because bool() is true for every non-empty string.
        lowered = text.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {text!r}")

    for param_name, param_value in default_params.items():
        value_type = type(param_value)
        if param_value is None:
            # type(None) rejects every argument; keep the raw string instead.
            converter = None
            value_type = str
        elif isinstance(param_value, bool):
            converter = _to_bool
        else:
            converter = value_type
        # argparse %-formats help strings, so a literal "%" must be doubled.
        help_text = f"default value = {param_value}, type = {value_type}".replace("%", "%%")
        parser.add_argument(
            f"--{param_name}",
            default=None,
            type=converter,
            help=help_text,
        )
def add_instance_params(parser: argparse.ArgumentParser, instance: type, default_params: params_type):
    """Register one ``--<ClassName>__<param>`` flag per constructor parameter.

    Parameters come from ``inspect.signature(instance.__init__)``. ``self`` and
    ``*args`` / ``**kwargs`` catch-alls are skipped, because they cannot be
    passed back as single keyword arguments.

    A flag's default is taken from ``default_params`` when the name appears
    there, otherwise from the signature default, otherwise ``None``. Its type
    is the annotation when present, otherwise the type of that default. When
    neither is known, the raw string is kept; ``type(None)`` would reject
    every value, making the parameter impossible to set. ``bool`` parameters
    accept true/false style strings.

    Args:
        parser: parser the flags are added to (mutated in place).
        instance: class whose constructor parameters become flags.
        default_params: shared defaults that override signature defaults.
    """

    def _to_bool(text: str) -> bool:
        # Explicit parser because bool() is true for every non-empty string.
        lowered = text.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {text!r}")

    name = instance.__name__
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    for param_name, param_desc in inspect.signature(instance.__init__).parameters.items():
        if param_name == "self" or param_desc.kind in variadic_kinds:
            continue
        param_value = None
        param_type = None
        if param_desc.default is not inspect.Parameter.empty:
            param_value = param_desc.default
        if param_desc.annotation is not inspect.Parameter.empty:
            param_type = param_desc.annotation
        if param_name in default_params:
            param_value = default_params[param_name]
        if param_type is None and param_value is not None:
            param_type = type(param_value)
        # param_type may still be None here: argparse then keeps the string.
        shown_type = str if param_type is None else param_type
        converter = _to_bool if param_type is bool else param_type
        parser.add_argument(
            f"--{name}__{param_name}",
            type=converter,
            default=param_value,
            # argparse %-formats help strings, so a literal "%" must be doubled.
            help=f"default value = {param_value}, type = {shown_type}".replace("%", "%%"),
        )
def check_all_parameters(parsed_params: params_type, instances: list[type]):
    """Raise if any ``<ClassName>__<param>`` entry is still ``None``.

    Only keys prefixed with one of the given class names are checked; shared
    (unprefixed) parameters are ignored.

    Args:
        parsed_params: parsed arguments as a name -> value mapping.
        instances: classes whose prefixed parameters must all have a value.

    Raises:
        ValueError: naming every missing parameter. It is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    # str.startswith accepts a tuple, checking every class prefix in one call.
    prefixes = tuple(f"{instance.__name__}__" for instance in instances)
    missing = [
        param_name
        for param_name, param_value in parsed_params.items()
        if param_name.startswith(prefixes) and param_value is None
    ]
    if missing:
        # Report every missing parameter at once rather than only the first.
        raise ValueError(f"{', '.join(missing)} should have been provided!")
def set_default_params_for_instances(
    parsed_params: params_type,
    instances: list[type],
    default_params: params_type
) -> params_type:
    """Copy explicitly given shared values onto every matching per-class key.

    A shared parameter counts as "given" when its parsed value is not ``None``.
    For each such parameter and each class, an existing ``<ClassName>__<param>``
    key is overwritten with the shared value. Keys that do not already exist
    are never created. The input mapping is left untouched; a new dict is
    returned.
    """
    explicitly_given = {
        key: parsed_params[key]
        for key in default_params
        if parsed_params.get(key) is not None
    }
    merged: params_type = {**parsed_params}
    for cls in instances:
        prefix = cls.__name__
        for key, value in explicitly_given.items():
            scoped_key = f"{prefix}__{key}"
            if scoped_key not in parsed_params:
                continue
            merged[scoped_key] = value
    return merged
def get_params(
    instances: list[type],
    default_params: params_type,
    args: typing.Optional[list[str]] = None,
) -> dict[type, dict[str, typing.Any]]:
    """Build constructor keyword arguments for each class from the command line.

    The parser exposes one ``--<name>`` flag per shared default and one
    ``--<ClassName>__<param>`` flag per constructor parameter. An explicitly
    given shared flag overrides the matching per-class value.

    Args:
        instances: classes to collect constructor arguments for.
        default_params: shared defaults, also exposed as global flags.
        args: argument list to parse; ``None`` (the default) lets argparse
            read ``sys.argv[1:]``.

    Returns:
        Mapping from each class to the keyword arguments for its constructor,
        with the ``<ClassName>__`` prefix stripped from each name.

    Raises:
        Whatever ``check_all_parameters`` raises when a per-class parameter is
        still unset. argparse itself exits via ``SystemExit`` on bad input.
    """
    # ArgumentParser's first positional argument is ``prog`` (the program name
    # shown in usage); this text is meant as a description.
    parser = argparse.ArgumentParser(description="Dummy parser")
    add_default_params(parser, default_params)
    for instance in instances:
        add_instance_params(parser, instance, default_params)
    dict_args = vars(parser.parse_args(args))
    # Apply explicitly given shared values before validating, so a per-class
    # parameter filled through its shared flag is not reported as missing.
    dict_args = set_default_params_for_instances(dict_args, instances, default_params)
    check_all_parameters(dict_args, instances)
    instance_kwargs: dict[type, dict[str, typing.Any]] = {}
    for instance in instances:
        prefix = f"{instance.__name__}__"
        instance_kwargs[instance] = {
            param_name.removeprefix(prefix): param_value
            for param_name, param_value in dict_args.items()
            if param_name.startswith(prefix)
        }
    return instance_kwargs
def main():
    """Parse CLI options for the four updater classes, then build and print each one."""
    updater_classes = [UpdaterA, UpdaterB, UpdaterC, UpdaterD]
    shared_defaults = {
        "max_tries": 3,
        "interval": 6.28
    }
    kwargs_by_class = get_params(updater_classes, default_params=shared_defaults)
    for updater_cls in updater_classes:
        print(f"Construct `{updater_cls.__name__}`")
        # Construction happens before the print, exactly as when bound to a name.
        print(updater_cls(**kwargs_by_class[updater_cls]))
        print()
# Run the demo only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()
class BaseUpdater:
    """Base updater whose constructor echoes every keyword argument it receives."""

    def __init__(self, **kwargs):
        # One "name = value" line per keyword argument, in the order passed.
        for key in kwargs:
            print(f"{key} = {kwargs[key]}")
class UpdaterA(BaseUpdater):
    """Updater taking ``interval``, ``table_name`` and ``dummy_size``.

    All three values are forwarded to ``BaseUpdater`` as keyword arguments.
    The constructor signature (names, annotations, defaults) is read via
    ``inspect.signature`` by the CLI parameter builder, so keep it unchanged.
    """

    def __init__(
        self,
        interval: float,
        table_name: str = "hello",
        dummy_size: int = 12
    ):
        settings = dict(interval=interval, table_name=table_name, dummy_size=dummy_size)
        super().__init__(**settings)
class UpdaterB(BaseUpdater):
    """Updater taking ``max_tries``, ``query`` and ``asd``.

    All three values are forwarded to ``BaseUpdater`` as keyword arguments.
    The constructor signature (names, annotations, defaults) is read via
    ``inspect.signature`` by the CLI parameter builder, so keep it unchanged.
    """

    def __init__(
        self,
        max_tries: int,
        query: str = "qwek",
        asd: float = 123.12
    ):
        settings = dict(max_tries=max_tries, query=query, asd=asd)
        super().__init__(**settings)
class UpdaterC(BaseUpdater):
    """Updater taking ``interval``, ``max_tries``, ``table_name`` and ``another``.

    All four values are forwarded to ``BaseUpdater`` as keyword arguments.
    The constructor signature (names, annotations, defaults) is read via
    ``inspect.signature`` by the CLI parameter builder, so keep it unchanged.
    """

    def __init__(
        self,
        interval: float,
        max_tries: int,
        table_name: str = "yolo",
        another: int = 15
    ):
        settings = dict(
            interval=interval,
            max_tries=max_tries,
            table_name=table_name,
            another=another,
        )
        super().__init__(**settings)
class UpdaterD(BaseUpdater):
    """Updater whose single constructor parameter has neither annotation nor default.

    The value is forwarded to ``BaseUpdater`` as a keyword argument. The
    signature is read via ``inspect.signature`` by the CLI parameter builder,
    so leave ``not_annotated`` exactly as it is.
    """

    def __init__(
        self,
        not_annotated
    ):
        forwarded = {"not_annotated": not_annotated}
        super().__init__(**forwarded)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment