Skip to content

Instantly share code, notes, and snippets.

@Erotemic
Last active February 19, 2023 02:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Erotemic/7bd7363237cb93524b87dda978d69cb4 to your computer and use it in GitHub Desktop.
Save Erotemic/7bd7363237cb93524b87dda978d69cb4 to your computer and use it in GitHub Desktop.
jsonargparse_scriptconfig_poc.py
import jsonargparse
import inspect
from jsonargparse.parameter_resolvers import ParamData
from jsonargparse.signatures import get_signature_parameters
from jsonargparse.signatures import get_doc_short_description
from typing import List, Set, Union, Optional, Tuple, Type, Any
from jsonargparse.util import get_import_path, iter_to_set_str
from argparse import SUPPRESS
from jsonargparse.typing import is_final_class
import scriptconfig as scfg
from jsonargparse.actions import _ActionConfigLoad # NOQA
from jsonargparse.optionals import get_doc_short_description # NOQA
from jsonargparse.parameter_resolvers import (ParamData, get_parameter_origins, get_signature_parameters,) # NOQA
from jsonargparse.typehints import ActionTypeHint, LazyInitBaseClass, is_optional # NOQA
from jsonargparse.typing import is_final_class # NOQA
from jsonargparse.util import LoggerProperty, get_import_path, is_subclass, iter_to_set_str # NOQA
from jsonargparse.signatures import is_factory_class, is_pure_dataclass
# Enum class holding the parameter kinds (VAR_POSITIONAL, VAR_KEYWORD, ...);
# reached through a public member instead of the private ``inspect`` name.
kinds = type(inspect.Parameter.POSITIONAL_OR_KEYWORD)
# Sentinel that ``inspect`` uses for "no default" / "no annotation".
inspect_empty = inspect.Parameter.empty
class ScriptConfigArgumentParser(jsonargparse.ArgumentParser):
    """
    A jsonargparse parser that also understands ``__scriptconfig__`` classes.

    When the class handed to the signature machinery has a
    ``__scriptconfig__`` attribute, every entry of that config's defaults is
    added as an extra keyword-only parameter, and the aliases declared on the
    scriptconfig value become additional option strings.

    Keep in sync with ~/code/watch/watch/utils/lightning_ext/lightning_cli_ext.py
    See if we can do something to land this functionality upstream
    """

    def _add_signature_arguments(
        self,
        function_or_class,
        method_name,
        nested_key,
        as_group: bool = True,
        as_positional: bool = False,
        skip=None,
        fail_untyped: bool = True,
        sub_configs: bool = False,
        instantiate: bool = True,
        linked_targets=None
    ) -> List[str]:
        """Adds arguments from parameters of objects based on signatures and docstrings.

        In addition to the parameters resolved from the signature, entries of
        ``function_or_class.__scriptconfig__`` (when present) are added as
        keyword-only parameters.

        Args:
            function_or_class: Object from which to add arguments.
            method_name: Class method from which to add arguments.
            nested_key: Key for nested namespace.
            as_group: Whether arguments should be added to a new argument group.
            as_positional: Whether to add required parameters as positional arguments.
            skip: Names of parameters that should be skipped.
            fail_untyped: Whether to raise exception if a required parameter does not have a type.
            sub_configs: Whether subclass type hints should be loadable from inner config file.
            instantiate: Whether the class group should be instantiated by :code:`instantiate_classes`.
            linked_targets: Names of parameters that are link targets; forwarded
                unchanged to :meth:`_add_signature_parameter`.

        Returns:
            The list of arguments added.

        Raises:
            ValueError: When there are required parameters without at least one valid type.
        """
        ## Create group if requested ##
        doc_group = get_doc_short_description(function_or_class, method_name, self.logger)
        component = getattr(function_or_class, method_name) if method_name else function_or_class
        group = self._create_group_if_requested(component, nested_key, as_group, doc_group, instantiate=instantiate)

        # Work on a copy so the scriptconfig entries appended below never end
        # up in the list object handed out by the resolver.
        params = list(get_signature_parameters(function_or_class, method_name, logger=self.logger))

        if hasattr(function_or_class, '__scriptconfig__'):
            # Hack to insert our method for explicit parameterization: pretend
            # that the things in the scriptconfig are part of the signature.
            config_cls = function_or_class.__scriptconfig__
            if hasattr(config_cls, '__default__'):
                default = config_cls.__default__
            else:
                default = config_cls.default

            for key, value in default.items():
                # TODO can we make this compatibility better?
                # Can we actually use the scriptconfig argparsing action?
                #
                # Only a real class can serve as an annotation. The local is
                # deliberately NOT called ``type``: shadowing the builtin made
                # the isinstance check compare the value against itself, which
                # discarded every declared class and raised TypeError for
                # non-class callables.
                value_type = value.parsekw['type']
                if isinstance(value_type, type):
                    annotation = value_type
                else:
                    annotation = inspect_empty

                param = ParamData(
                    name=key,
                    annotation=annotation,
                    kind=kinds.KEYWORD_ONLY,
                    default=value.value,
                    doc=value.parsekw['help'],
                    component=function_or_class.__init__,
                    parent=function_or_class,
                )
                # Remember the scriptconfig value so that aliases can be
                # resolved when the argument is actually registered.
                param._scfg_value = value
                params.append(param)

        ## Add parameter arguments ##
        added_args: List[str] = []
        for param in params:
            self._add_signature_parameter(
                group,
                nested_key,
                param,
                added_args,
                skip,
                fail_untyped=fail_untyped,
                sub_configs=sub_configs,
                linked_targets=linked_targets,
                as_positional=as_positional,
            )
        return added_args

    @staticmethod
    def _scfg_option_strings(name, value, nested_key=None, fuzzy_hyphens=False):
        """Build the option strings for a parameter backed by a scriptconfig value.

        Args:
            name: Primary name of the parameter.
            value: Object providing ``alias`` and ``short_alias`` attributes
                (each a string, a sequence of strings or None), or None when
                there are no aliases.
            nested_key: Optional prefix; when truthy every name is prefixed
                with ``nested_key + '.'``.
            fuzzy_hyphens: When true, also accept each long name with
                underscores replaced by hyphens.

        Returns:
            Short option strings (single dash) followed by long option strings
            (double dash).
        """
        if value is None:
            aliases = None
            short_aliases = None
        else:
            aliases = value.alias
            short_aliases = value.short_alias

        # A single alias may be given as a bare string.
        if isinstance(aliases, str):
            aliases = [aliases]
        if isinstance(short_aliases, str):
            short_aliases = [short_aliases]

        long_names = [name] + list(aliases or [])
        short_names = list(short_aliases or [])

        if fuzzy_hyphens:
            # Do we want to allow for people to use hyphens on the CLI?
            # Maybe, we can make it optional.
            unique_long_names = set(long_names)
            hyphenated = {n.replace('_', '-') for n in unique_long_names}
            long_names += sorted(hyphenated - unique_long_names)

        nest_prefix = nested_key + '.' if nested_key else ''
        short_option_strings = ['-' + nest_prefix + n for n in short_names]
        long_option_strings = ['--' + nest_prefix + n for n in long_names]
        return short_option_strings + long_option_strings

    def _add_signature_parameter(
        self,
        group,
        nested_key: Optional[str],
        param,
        added_args: List[str],
        skip: Optional[Set[str]] = None,
        fail_untyped: bool = True,
        as_positional: bool = False,
        sub_configs: bool = False,
        instantiate: bool = True,
        linked_targets: Optional[Set[str]] = None,
        default: Any = inspect_empty,
        **kwargs
    ):
        """Adds a single resolved parameter as an argument of ``group``.

        Args:
            group: Argument group (or the conditional-origin group derived
                from it) that receives the argument.
            nested_key: Key for nested namespace; prefixes the destination.
            param: Parameter description providing ``name``, ``kind``,
                ``annotation``, ``default``, ``doc``, ``origin``,
                ``component`` and ``parent``. If it carries a ``_scfg_value``
                attribute, the option strings are built from the scriptconfig
                aliases instead.
            added_args: List extended in place with the destination of the
                argument when one is added.
            skip: Names of parameters that should be skipped.
            fail_untyped: Whether to raise for a required parameter without a type.
            as_positional: Whether required parameters are added as positionals.
            sub_configs: Whether subclass type hints should be loadable from
                inner config file.
            instantiate: Stored in the action's ``sub_add_kwargs``.
            linked_targets: A required parameter named here gets a default of
                None and is treated as not required.
            default: Overrides ``param.default`` unless left at the sentinel.
            **kwargs: Extra keyword arguments passed on to ``add_argument``.

        Raises:
            ValueError: When a required parameter has no usable type and
                ``fail_untyped`` is true.
        """
        name = param.name
        kind = param.kind
        annotation = param.annotation
        if default == inspect_empty:
            default = param.default
        is_required = default == inspect_empty
        src = get_parameter_origins(param.component, param.parent)
        skip_message = f'Skipping parameter "{name}" from "{src}" because of: '

        if not fail_untyped and annotation == inspect_empty:
            annotation = Any
            default = None if is_required else default
            is_required = False
        if is_required and linked_targets is not None and name in linked_targets:
            default = None
            is_required = False

        if kind in {kinds.VAR_POSITIONAL, kinds.VAR_KEYWORD} or \
           (not is_required and name[0] == '_') or \
           (annotation == inspect_empty and not is_required and default is None):
            return
        elif skip and name in skip:
            self.logger.debug(skip_message + 'Parameter requested to be skipped.')
            return

        if is_factory_class(default):
            default = param.parent.__dataclass_fields__[name].default_factory()
        if annotation == inspect_empty and not is_required:
            annotation = type(default)
        if 'help' not in kwargs:
            kwargs['help'] = param.doc
        if not is_required:
            kwargs['default'] = default
            if default is None and not is_optional(annotation, object):
                annotation = Optional[annotation]
        elif not as_positional:
            kwargs['required'] = True

        is_subclass_typehint = False
        # Bound up front so the later check never depends on which branch ran.
        subclass_skip: Set[str] = set()
        is_final_class_typehint = is_final_class(annotation)
        dest = (nested_key + '.' if nested_key else '') + name
        args = [dest if is_required and as_positional else '--' + dest]

        if param.origin:
            # Parameters that only exist for some origins get their own group.
            group_name = '; '.join(str(o) for o in param.origin)
            if group_name in group.parser.groups:
                group = group.parser.groups[group_name]
            else:
                group = group.parser.add_argument_group(
                    f'Conditional arguments [origins: {group_name}]',
                    name=group_name,
                )

        if annotation in {str, int, float, bool} or \
           is_subclass(annotation, (str, int, float)) or \
           is_final_class_typehint or \
           is_pure_dataclass(annotation):
            kwargs['type'] = annotation
        elif annotation != inspect_empty:
            try:
                is_subclass_typehint = ActionTypeHint.is_subclass_typehint(annotation, all_subtypes=False)
                kwargs['type'] = annotation
                sub_add_kwargs: dict = {'fail_untyped': fail_untyped, 'sub_configs': sub_configs}
                if is_subclass_typehint:
                    prefix = name + '.init_args.'
                    subclass_skip = {s[len(prefix):] for s in skip or [] if s.startswith(prefix)}
                    sub_add_kwargs['skip'] = subclass_skip
                args = ActionTypeHint.prepare_add_argument(
                    args=args,
                    kwargs=kwargs,
                    enable_path=is_subclass_typehint and sub_configs,
                    container=group,
                    logger=self.logger,
                    sub_add_kwargs=sub_add_kwargs,
                )
            except ValueError as ex:
                self.logger.debug(skip_message + str(ex))

        if 'type' in kwargs or 'action' in kwargs:
            sub_add_kwargs = {
                'fail_untyped': fail_untyped,
                'sub_configs': sub_configs,
                'instantiate': instantiate,
            }
            if is_final_class_typehint:
                kwargs.update(sub_add_kwargs)

            if hasattr(param, '_scfg_value'):
                # Parameter came from a scriptconfig: replace the option
                # strings so that its declared aliases are accepted too.
                args = self._scfg_option_strings(
                    name, param._scfg_value, nested_key, fuzzy_hyphens=False)

            action = group.add_argument(*args, **kwargs)
            action.sub_add_kwargs = sub_add_kwargs
            if is_subclass_typehint and len(subclass_skip) > 0:
                action.sub_add_kwargs['skip'] = subclass_skip
            added_args.append(dest)
        elif is_required and fail_untyped:
            raise ValueError(f'Required parameter without a type for "{src}" parameter "{name}".')
# Monkey patch jsonargparse so its subcommands use our extended functionality.
# NOTE: this rebinds the module attribute as soon as this file is executed, so
# every later lookup of ``jsonargparse.ArgumentParser`` resolves to
# ScriptConfigArgumentParser.
jsonargparse.ArgumentParser = ScriptConfigArgumentParser
class MyClassConfig(scfg.DataConfig):
    # Demo configuration with three entries: an integer default reachable
    # through the extra alias ``key_one``, a None default, and a boolean flag.
    key1 = scfg.Value(
        1,
        help='description1',
        alias=['key_one'],
    )
    key2 = scfg.Value(
        None,
        help='description2',
    )
    key3 = scfg.Value(
        False,
        help='description3',
        isflag=True,
    )
class MyClass:
    # Demo class whose constructor arguments are described by a scriptconfig
    # rather than by its signature; the parser finds them via this attribute.
    __scriptconfig__ = MyClassConfig

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded to the config object.
        config = MyClassConfig(**kwargs)
        self.config = config
def main():
    """Build the demo parser, parse the command line and print the instance state."""
    cli = ScriptConfigArgumentParser()
    # Arguments of MyClass live under the ``my_class`` namespace.
    cli.add_class_arguments(
        MyClass,
        nested_key='my_class',
        fail_untyped=False,
        sub_configs=True,
    )
    # Two ordinary arguments, one of them with several option strings.
    cli.add_argument('--foo', default='bar')
    cli.add_argument('-b', '--baz', '--buzz', default='bar')

    parsed = cli.parse_args()
    instances = cli.instantiate_classes(parsed)
    print(f'{instances.my_class.__dict__=}')
if __name__ == '__main__':
    # The bare string below is a no-op expression kept as documentation: it
    # lists example command lines for running this script.
    """
CommandLine:
python jsonargparse_scriptconfig_poc.py --my_class.key1=foo
python jsonargparse_scriptconfig_poc.py --my_class.key_one=foo
python jsonargparse_scriptconfig_poc.py --buzz=1
python jsonargparse_scriptconfig_poc.py --help
"""
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment