Created
August 10, 2022 16:23
-
-
Save evjeny/6110a97289b2ae90004a6a95b8b294ce to your computer and use it in GitHub Desktop.
Parameter parsers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import inspect | |
import typing | |
from updaters import UpdaterA, UpdaterB, UpdaterC, UpdaterD | |
# Parameter mapping used throughout this module: parameter name -> value of any type.
params_type = dict[str, typing.Any]
def add_default_params(parser: argparse.ArgumentParser, default_params: "params_type"):
    """Register one ``--<name>`` option per global default parameter.

    The options are registered with ``default=None``, so after parsing a value
    of ``None`` means "not given on the command line". The value stored in
    ``default_params`` is only shown in the help text and used to choose the
    option's type.

    Args:
        parser: parser to add the options to (mutated in place).
        default_params: mapping of parameter name to its default value.
    """

    def _parse_bool(text: str) -> bool:
        # bool("False") is True, so bool itself cannot serve as an argparse type.
        lowered = text.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {text!r}")

    for param_name, param_value in default_params.items():
        value_type = type(param_value)
        if value_type is bool:
            converter = _parse_bool
        elif param_value is None:
            # NoneType("x") raises TypeError; fall back to argparse's plain strings.
            converter = None
        else:
            converter = value_type
        help_text = f"default value = {param_value}, type = {value_type}"
        parser.add_argument(
            f"--{param_name}",
            default=None,
            type=converter,
            # argparse %-formats help strings; escape '%' so --help cannot crash.
            help=help_text.replace("%", "%%"),
        )
def add_instance_params(parser: argparse.ArgumentParser, instance: type, default_params: "params_type"):
    """Register one ``--<ClassName>__<param>`` option per constructor parameter.

    ``self`` and ``*args``/``**kwargs`` catch-alls are skipped, because they cannot
    be expressed as a single option. For every other parameter of ``instance.__init__``:

    * the option default is the value from ``default_params`` when the name is
      listed there, otherwise the parameter's own default, otherwise ``None``;
    * the option type is the parameter's annotation when present, otherwise the
      type of that default. Parameters with neither an annotation nor a
      non-``None`` default are parsed as plain strings. ``bool`` options accept
      words such as ``true``/``false`` (``bool("False")`` would be ``True``).

    Args:
        parser: parser to add the options to (mutated in place).
        instance: class whose ``__init__`` signature is inspected.
        default_params: global defaults that take precedence over per-parameter defaults.
    """

    def _parse_bool(text: str) -> bool:
        lowered = text.strip().lower()
        if lowered in ("1", "true", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"invalid boolean value: {text!r}")

    name = instance.__name__
    for param_name, param_desc in inspect.signature(instance.__init__).parameters.items():
        if param_name == "self":
            continue
        if param_desc.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue

        # Identity test against the sentinel: `!=` could call an arbitrary __eq__
        # on the default value (e.g. one that cannot be truth-tested).
        param_value = None if param_desc.default is inspect.Parameter.empty else param_desc.default
        param_type = None if param_desc.annotation is inspect.Parameter.empty else param_desc.annotation
        if param_name in default_params:
            param_value = default_params[param_name]
        if param_type is None:
            # type(None) is NoneType, which cannot convert any string; use str instead.
            param_type = str if param_value is None else type(param_value)

        help_text = f"default value = {param_value}, type = {param_type}"
        parser.add_argument(
            f"--{name}__{param_name}",
            type=_parse_bool if param_type is bool else param_type,
            default=param_value,
            # argparse %-formats help strings; escape '%' so --help cannot crash.
            help=help_text.replace("%", "%%"),
        )
def check_all_parameters(parsed_params: "params_type", instances: list[type]):
    """Ensure every instance-specific parameter received a value.

    Only keys of the form ``<ClassName>__<param>`` for a class in ``instances``
    are checked. A value of ``None`` is treated as "not provided", so a
    parameter whose legitimate value is ``None`` is reported as well.

    Args:
        parsed_params: parsed command-line values keyed by argparse destination.
        instances: classes whose prefixed parameters must be present.

    Raises:
        ValueError: if any checked parameter is ``None``; every missing name is
            listed in one message. ``ValueError`` subclasses ``Exception``, so
            existing ``except Exception`` handlers still catch it.
    """
    # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing.
    prefixes = tuple(f"{instance.__name__}__" for instance in instances)
    missing = [
        param_name
        for param_name, param_value in parsed_params.items()
        if param_name.startswith(prefixes) and param_value is None
    ]
    if missing:
        raise ValueError(f"{', '.join(missing)} should have been provided!")
def set_default_params_for_instances(
    parsed_params: "params_type",
    instances: list[type],
    default_params: "params_type"
) -> "params_type":
    """Copy explicitly given global parameters onto every matching instance key.

    A global parameter counts as "given" when its name is a key of
    ``default_params`` and its parsed value is not ``None``. Each given value
    unconditionally overwrites the ``<ClassName>__<name>`` entry of every class
    in ``instances`` whose prefixed key exists in ``parsed_params``, even when
    that entry already holds a value.

    Returns:
        A new dict; ``parsed_params`` itself is not modified.
    """
    given_globals = {
        key: parsed_params[key]
        for key in default_params
        if parsed_params.get(key, None) is not None
    }

    merged: "params_type" = dict(parsed_params)
    for cls in instances:
        class_name = cls.__name__
        for key, value in given_globals.items():
            scoped_key = f"{class_name}__{key}"
            if scoped_key in parsed_params:
                merged[scoped_key] = value
    return merged
def get_params(
    instances: list[type],
    default_params: "params_type",
    *,
    argv: typing.Optional[typing.Sequence[str]] = None,
) -> dict[type, dict[str, typing.Any]]:
    """Build a CLI for ``instances``, parse it and return per-class constructor kwargs.

    The parser gets one ``--<name>`` option per entry of ``default_params`` and
    one ``--<ClassName>__<param>`` option per constructor parameter of each
    class. After parsing, missing instance parameters are checked, and
    explicitly given global values are copied onto the instance options.

    Args:
        instances: classes to configure.
        default_params: global parameter defaults shared by all classes.
        argv: argument list to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, exactly as before; pass a list for tests or embedding.

    Returns:
        Mapping from each class in ``instances`` to the keyword arguments for its
        constructor, with the ``<ClassName>__`` prefix stripped from every key.

    Raises:
        SystemExit: raised by argparse for invalid arguments or ``--help``.
        Whatever ``check_all_parameters`` raises for an instance parameter left as ``None``.
    """
    parser = argparse.ArgumentParser("Dummy parser")
    add_default_params(parser, default_params)
    for instance in instances:
        add_instance_params(parser, instance, default_params)

    dict_args = vars(parser.parse_args(argv))
    check_all_parameters(dict_args, instances)
    dict_args = set_default_params_for_instances(dict_args, instances, default_params)

    instance_kwargs: dict[type, dict[str, typing.Any]] = {}
    for instance in instances:
        prefix = f"{instance.__name__}__"
        instance_kwargs[instance] = {
            param_name.removeprefix(prefix): param_value
            for param_name, param_value in dict_args.items()
            if param_name.startswith(prefix)
        }
    return instance_kwargs
def main():
    """Parse CLI parameters for the four demo updaters and construct each one.

    ``max_tries`` (3) and ``interval`` (6.28) are the shared global defaults.
    For each class, a "Construct ..." line, the constructed object and a blank
    line are printed.
    """
    instances = [UpdaterA, UpdaterB, UpdaterC, UpdaterD]
    instance_params = get_params(
        instances,
        default_params={
            "max_tries": 3,
            "interval": 6.28,
        },
    )
    for instance in instances:
        print(f"Construct `{instance.__name__}`")
        instance_object = instance(**instance_params[instance])
        print(instance_object)
        print()


# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BaseUpdater:
    """Base class that echoes and records its keyword configuration.

    On construction, each keyword argument is printed as ``name = value`` in the
    order given. The arguments are kept, so ``repr()`` shows how the object was
    configured instead of a bare memory address.
    """

    def __init__(self, **kwargs):
        self._params = kwargs
        for arg_name, arg_value in kwargs.items():
            print(f"{arg_name} = {arg_value}")

    def __repr__(self) -> str:
        args = ", ".join(f"{key}={value!r}" for key, value in self._params.items())
        return f"{type(self).__name__}({args})"
class UpdaterA(BaseUpdater):
    """Updater configured by ``interval``, ``table_name`` and ``dummy_size``.

    All three constructor arguments are forwarded unchanged, as keywords, to
    ``BaseUpdater``.
    """

    def __init__(
        self,
        interval: float,
        table_name: str = "hello",
        dummy_size: int = 12
    ):
        forwarded = {
            "interval": interval,
            "table_name": table_name,
            "dummy_size": dummy_size,
        }
        super().__init__(**forwarded)
class UpdaterB(BaseUpdater):
    """Updater configured by ``max_tries``, ``query`` and ``asd``.

    All three constructor arguments are forwarded unchanged, as keywords, to
    ``BaseUpdater``.
    """

    def __init__(
        self,
        max_tries: int,
        query: str = "qwek",
        asd: float = 123.12
    ):
        forwarded = {
            "max_tries": max_tries,
            "query": query,
            "asd": asd,
        }
        super().__init__(**forwarded)
class UpdaterC(BaseUpdater):
    """Updater configured by ``interval``, ``max_tries``, ``table_name`` and ``another``.

    All four constructor arguments are forwarded unchanged, as keywords, to
    ``BaseUpdater``.
    """

    def __init__(
        self,
        interval: float,
        max_tries: int,
        table_name: str = "yolo",
        another: int = 15
    ):
        forwarded = {
            "interval": interval,
            "max_tries": max_tries,
            "table_name": table_name,
            "another": another,
        }
        super().__init__(**forwarded)
class UpdaterD(BaseUpdater):
    """Updater with a single constructor argument, ``not_annotated``.

    ``not_annotated`` has neither a type annotation nor a default value. It is
    forwarded unchanged, as a keyword, to ``BaseUpdater``.
    """

    def __init__(
        self,
        not_annotated
    ):
        forwarded = {"not_annotated": not_annotated}
        super().__init__(**forwarded)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment