Skip to content

Instantly share code, notes, and snippets.

@JACKHAHA363
Last active February 22, 2019 06:23
Show Gist options
  • Save JACKHAHA363/46831312f26b8d1df87c6a4baaeeb3c8 to your computer and use it in GitHub Desktop.
Save JACKHAHA363/46831312f26b8d1df87c6a4baaeeb3c8 to your computer and use it in GitHub Desktop.
A principled way to build a config dictionary that can be saved and restored, supports predefined types and default values, and supports command-line parsing.
import yaml
import argparse
class MissingFieldsError(ValueError, AssertionError):
    """Raised when a config dict lacks one or more declared fields.

    Inherits from ``AssertionError`` as well as ``ValueError`` so code that
    caught the failure of the earlier ``assert``-based check keeps working.
    """


class _ConfigDict(dict):
    """A ``dict`` whose keys must include every declared field.

    Subclasses (typically created with ``config_dict``) set ``fields`` to a
    sequence of ``(name, type, default, help)`` tuples. Instances can be
    built from keyword arguments, loaded from / saved to YAML, or filled from
    command-line arguments parsed with the parser from ``get_cmd_parser``.
    """

    # Change this to customize it: a sequence of (name, type, default, help)
    # tuples. Subclasses override it; the base class declares no fields.
    fields = []

    @classmethod
    def get_cmd_parser(cls, parser=None):
        """Return an ``argparse.ArgumentParser`` with one option per field.

        If ``parser`` is given, the options are added to it and that same
        parser is returned; otherwise a new parser is created. Each field
        becomes a single-dash option ``-<name>`` whose ``type``, ``default``
        and ``help`` come from positions 1, 2 and 3 of the field tuple.
        """
        if parser is None:
            parser = argparse.ArgumentParser()
        for field in cls.fields:
            parser.add_argument('-' + field[0], type=field[1], default=field[2],
                                help=field[3])
        return parser

    def __init__(self, **kwargs):
        """Validate ``kwargs`` against ``fields`` and store them as items.

        Values are stored as given; no type conversion is applied.

        Raises:
            MissingFieldsError: if any declared field is absent.
        """
        self.validate_config_dict(kwargs)
        super(_ConfigDict, self).__init__(**kwargs)

    @classmethod
    def from_yaml(cls, yaml_file):
        """Load a config from the YAML file ``yaml_file`` and return an instance.

        Raises:
            yaml.YAMLError: if the file is not valid YAML.
            TypeError: if the top-level YAML value is not a mapping.
            MissingFieldsError: if a declared field is absent.
        """
        with open(yaml_file, 'r') as stream:
            # safe_load only constructs plain YAML types. A bare yaml.load()
            # is unsafe on untrusted files and requires an explicit Loader
            # on PyYAML >= 6.
            data = yaml.safe_load(stream)
        if data is None:
            # An empty document parses to None; treat it as an empty mapping
            # so validation reports exactly which fields are missing.
            data = {}
        if not isinstance(data, dict):
            raise TypeError('Expected a YAML mapping in {}, got {}'.format(
                yaml_file, type(data).__name__))
        return cls(**data)

    def to_yaml(self, yaml_file):
        """Write this config's items to ``yaml_file`` in block (non-flow) style.

        Uses ``yaml.safe_dump`` so only plain YAML types are written, which
        guarantees the file can be read back by ``from_yaml``. Values that
        the safe dumper cannot represent raise a ``yaml.YAMLError`` subclass
        at save time instead of producing an unloadable file.
        """
        with open(yaml_file, 'w') as out:
            yaml.safe_dump(dict(self), out, default_flow_style=False)

    @classmethod
    def validate_config_dict(cls, config_dict):
        """Check that the mapping ``config_dict`` contains every declared field.

        Extra keys are allowed. An explicit exception is raised (rather than
        an ``assert``, which ``python -O`` strips) so validation always runs.

        Raises:
            MissingFieldsError: listing the missing field names, sorted.
        """
        required = set(field[0] for field in cls.fields)
        missing = required.difference(config_dict)
        if missing:
            raise MissingFieldsError(
                'Missing fields for config: ' + str(sorted(missing)))

    def display(self):
        """Print each item as ``key:value`` on its own line."""
        for k, v in self.items():
            print('{}:{}'.format(k, v))
def config_dict(name, fields):
    """Create and return a new ``_ConfigDict`` subclass called ``name``.

    Args:
        name: the ``__name__`` given to the generated class.
        fields: sequence of ``(name, type, default, help)`` tuples, stored
            as-is (the same object, not a copy) as the class's ``fields``.

    Returns:
        The newly created class.
    """
    class_attrs = {'fields': fields}
    bases = (_ConfigDict,)
    return type(name, bases, class_attrs)
if __name__ == '__main__':
    # Declare two float hyper-parameters as (name, type, default, help).
    demo_fields = [
        ('lr', float, 0.01, 'learning rate'),
        ('mom', float, 0.9, 'sgd momentum'),
    ]
    MyConfigDict = config_dict('MyConfig', demo_fields)

    # Fill a config from the command line; options not given keep defaults.
    cli_args = MyConfigDict.get_cmd_parser().parse_args()
    config = MyConfigDict(**vars(cli_args))
    print('old config')
    config.display()

    # Save to config.yaml in the working directory (opened with 'w', so an
    # existing file of that name is overwritten).
    config.to_yaml('config.yaml')

    # from_yaml is a classmethod: it can be called on the class...
    config_new = MyConfigDict.from_yaml('config.yaml')
    print('new config')
    config_new.display()

    # ...or through an existing instance.
    config_new_2 = config.from_yaml('config.yaml')
    print('new config 2')
    config_new_2.display()

    # Build a config straight from keyword arguments; values are stored as
    # given, so lr stays the int 5.
    config_by_args = MyConfigDict(lr=5, mom=0.9)
    print('create config by directly input args')
    config_by_args.display()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment