Skip to content

Instantly share code, notes, and snippets.

@visionNoob
Created August 18, 2020 05:04
Show Gist options
  • Save visionNoob/830729f4269bd060dc6bb7901c301d1a to your computer and use it in GitHub Desktop.
Save visionNoob/830729f4269bd060dc6bb7901c301d1a to your computer and use it in GitHub Desktop.
import argparse
import collections
from functools import reduce, partial
from operator import getitem
from logger import setup_logging
from omegaconf import OmegaConf
class ConfigParser:
    """Wrap a parsed configuration and build objects/functions from it.

    The wrapped config can be any mapping that supports ``[]`` access, e.g. the
    OmegaConf ``DictConfig`` produced by :meth:`from_args`. Sections consumed by
    :meth:`init_obj` / :meth:`init_ftn` must provide a ``'type'`` entry (the
    attribute name to look up in a module) and an ``'args'`` mapping of keyword
    arguments.
    """

    def __init__(self, config, modification=None):
        """
        :param config: Mapping with the configuration (hyperparameters for
            training, module settings, ...), e.g. the contents of a YAML file.
        :param modification: Dict mapping a ';'-separated key path (e.g.
            ``'data_loader;args;batch_size'``) to the value to store there.
            Entries whose value is ``None`` are skipped. ``config`` is updated
            in place.
        """
        # load config file and apply modification
        self._config = _update_config(config, modification)

    @classmethod
    def from_args(cls, args, options=(), *, config_path="./test_config.yaml"):
        """
        Initialize this class from CLI arguments. Used in train, test.

        :param args: An ``argparse.ArgumentParser``. One argument is added per
            entry of ``options``, then ``parse_args()`` is called on it (unless
            ``args`` is a tuple).
        :param options: Iterable of objects with ``flags``, ``type`` and
            ``target`` attributes (e.g. the ``CustomArgs`` namedtuple). Each
            parsed value that is not ``None`` overrides ``target`` (a
            ';'-separated key path) in the loaded config.
        :param config_path: YAML file loaded with ``OmegaConf.load``; defaults
            to the previously hard-coded ``./test_config.yaml``.
        :return: A new ``ConfigParser`` instance.
        """
        # Materialize once: `options` is iterated twice below, and a one-shot
        # iterator would otherwise silently drop every CLI override.
        options = tuple(options)
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        if not isinstance(args, tuple):
            args = args.parse_args()
        config = OmegaConf.load(config_path)
        modification = {
            opt.target: getattr(args, _get_opt_name(opt.flags))
            for opt in options
        }
        return cls(config, modification=modification)

    def _resolve(self, name, module, kwargs):
        """Return ``(handle, merged_kwargs)`` for config section ``name``.

        ``handle`` is ``getattr(module, self[name]['type'])``; ``merged_kwargs``
        is a copy of ``self[name]['args']`` extended with ``kwargs``.

        :raises AssertionError: if ``kwargs`` repeats a key from the config.
        """
        module_name = self[name]['type']
        module_args = dict(self[name]['args'])
        # Explicit raise instead of `assert` so the guard survives `python -O`;
        # AssertionError is kept so existing callers see the same exception type.
        if any(k in module_args for k in kwargs):
            raise AssertionError(
                'Overwriting kwargs given in config file is not allowed')
        module_args.update(kwargs)
        return getattr(module, module_name), module_args

    def init_obj(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding arguments given.
        `object = config.init_obj('name', module, a, b=1)`
        is equivalent to
        `object = module.name(a, b=1)`

        :raises AssertionError: if ``kwargs`` repeats a key from the config's 'args'.
        """
        handle, module_args = self._resolve(name, module, kwargs)
        return handle(*args, **module_args)

    def init_ftn(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        function with given arguments fixed with functools.partial.
        `function = config.init_ftn('name', module, a, b=1)`
        is equivalent to
        `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.

        :raises AssertionError: if ``kwargs`` repeats a key from the config's 'args'.
        """
        handle, module_args = self._resolve(name, module, kwargs)
        return partial(handle, *args, **module_args)

    def __getitem__(self, name):
        """Access items like ordinary dict."""
        return self.config[name]

    # setting read-only attributes
    @property
    def config(self):
        """The wrapped configuration (after CLI modifications were applied)."""
        return self._config
# helper functions to update config dict with custom cli options
def _update_config(config, modification):
if modification is None:
return config
for k, v in modification.items():
if v is not None:
_set_by_path(config, k, v)
return config
def _get_opt_name(flags):
for flg in flags:
if flg.startswith('--'):
return flg.replace('--', '')
return flags[0].replace('--', '')
def _set_by_path(tree, keys, value):
    """Store *value* at the ';'-separated key path *keys* inside *tree*.

    ``_set_by_path(cfg, 'data_loader;args;batch_size', 64)`` performs
    ``cfg['data_loader']['args']['batch_size'] = 64``. Intermediate keys are
    only looked up, never created (for plain dicts a missing one raises
    ``KeyError``).
    """
    *parent_keys, leaf_key = keys.split(';')
    parent = _get_by_path(tree, parent_keys)
    parent[leaf_key] = value
def _get_by_path(tree, keys):
"""Access a nested object in tree by sequence of keys."""
return reduce(getitem, keys, tree)
def main(config):
    """Print the ``'data_loader'`` section of *config*."""
    data_loader_section = config['data_loader']
    print(data_loader_section)
if __name__ == '__main__':
    args = argparse.ArgumentParser(description='Model Train Phase')
    # Custom CLI options that override values from the YAML config file
    # (loaded by ConfigParser.from_args). `target` is a ';'-separated key path
    # into the config; options left unset on the command line change nothing.
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--bs', '--batch_size'], type=int,
                   target='data_loader;args;batch_size'),
    ]
    config = ConfigParser.from_args(args, options)
    main(config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment