Last active
November 30, 2020 18:20
-
-
Save jaekookang/a52c6c69f384969575dab918e01f1b05 to your computer and use it in GitHub Desktop.
Python Import Hyper-parameters as a Dotted Dictionary
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Make a dotted dictionary from a hyper-parameter file (.yaml)
The `HParams` class lets you import a hyper-parameter file in .yaml format
into Python as a dotted dictionary. There are different approaches to implementing
a dotted dictionary, but I was not satisfied with most of them.
For example, the `DotMap` module doesn't seem to be stable and can fail silently.
This was the main motivation for developing my own version of a dotted dictionary.
Purpose: | |
- Import parameters as a dotted dictionary | |
- Edit and save as .yaml format | |
How to use: | |
Basic | |
``` | |
hp = HParams(PARAM_FILE, CONFIG)
# PARAM_FILE: hyper-parameter file in .yaml format
# PARAM_FILE can also be given as a Python dictionary
# CONFIG: if you have multiple sets of parameters, you can select one by its name,
# e.g., 'default', 'deep_layers', 'shallow_layers', etc.
``` | |
Steps | |
- 1. Install ruamel.yaml (e.g. `pip install ruamel.yaml`).
- 2. Create a yaml file and provide hyper-parameters. For example:
```
# hparams.yaml
default:
  batch_size: 100
  epoch: 1000
  optimizer: Adam
```
- 3. Load `hparams.yaml` using `HParams`. | |
``` | |
# Load from yaml file | |
hp = HParams('hparams.yaml', 'default') | |
# Load from python dictionary | |
# hp = HParams(param_dict, 'default') | |
print(hp.batch_size) | |
# out: 100 | |
``` | |
References: | |
- https://hanxiao.io/2017/12/21/Use-HParams-and-YAML-to-Better-Manage-Hyperparameters-in-Tensorflow/ | |
- https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary | |
- https://pypi.org/project/dotmap/ | |
2020-04-30 jkang, created | |
''' | |
import ruamel.yaml | |
YAML = ruamel.yaml.YAML | |
class Map(dict):
    '''Dictionary whose items can also be read and written as attributes.

    Example:
        d = Map({'layer': 3, 'batch': 100, 'pooling': 'average'})
        d.layer          # -> 3, same as d['layer']
        d.batch = 200    # same as d['batch'] = 200
        d.missing        # -> None, same as d.get('missing')

    String keys are mirrored into the instance ``__dict__`` (so they show up
    in ``dir(d)``). Every dict mutator is overridden below so that the mirror
    always matches the dictionary contents.

    Retrieved and edited from: https://stackoverflow.com/a/32107024
    '''

    def __init__(self, *args, **kwargs):
        # Accepts the same arguments as dict(); the items go through update()
        # so that they are mirrored as attributes as well.
        super(Map, self).__init__()
        self.update(*args, **kwargs)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails, i.e. for names that
        # are not keys: those read as None. Dunder names raise AttributeError
        # instead, because copy/pickle probe optional hooks such as
        # __setstate__ with hasattr()/getattr() and would otherwise get None
        # back and try to call it.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        # super() skips the instance mirror, so a key named 'get' can't shadow it.
        return super(Map, self).get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __delattr__(self, item):
        self.__delitem__(item)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        # Only string keys can be attribute names; mirroring other keys would
        # also break dir(), which sorts the names found in __dict__.
        if isinstance(key, str):
            self.__dict__[key] = value

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        self.__dict__.pop(key, None)

    # dict's own mutators bypass __setitem__/__delitem__; without these
    # overrides attribute access would keep returning stale values.
    def update(self, *args, **kwargs):
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, *default):
        value = super(Map, self).pop(key, *default)
        self.__dict__.pop(key, None)
        return value

    def popitem(self):
        key, value = super(Map, self).popitem()
        self.__dict__.pop(key, None)
        return key, value

    def clear(self):
        super(Map, self).clear()
        self.__dict__.clear()

    def __ior__(self, other):
        # Called through the class so that a key named 'update' cannot shadow it.
        Map.update(self, other)
        return self
class HParams(Map):
    '''Hyper-parameter container built from a .yaml file or a dictionary.

    Nested mappings are converted to nested ``Map`` objects at every depth,
    so ``hp.model.encoder.layers`` works. Top-level keys keep their order.

    ## Load
    >> hp = HParams('hparams.yaml')  # or
    >> hp = HParams('hparams.yaml', DEFAULT_CONFIG)  # or
    >> hp = HParams({'default': {'batch_size': 100}}, 'default')
    ## Save
    >> with open('hparams_new.yaml', 'w', encoding='utf-8') as f:
           YAML().dump(dict(hp), f)
    NOTE: for nested parameter sets, the nested ``Map`` values may have to be
    converted to plain dicts first; the YAML dumper dispatches on exact types.

    Args:
        yaml_file: path of a .yaml file, or an already loaded dictionary.
        config_name: optional top-level key selecting one parameter set. It
            is applied to dictionary input as well as to file input.

    Raises:
        KeyError: if ``config_name`` is not a top-level key.
        OSError: if ``yaml_file`` is a path that cannot be opened.
    '''

    def __init__(self, yaml_file, config_name=None):
        super().__init__()
        if isinstance(yaml_file, dict):
            # Read dictionary directly
            dictionary = yaml_file
        else:
            # YAML files are UTF-8 by spec; don't depend on the platform's
            # default text encoding.
            with open(yaml_file, encoding='utf-8') as f:
                dictionary = YAML().load(f)

        if config_name is not None:
            # An empty file loads as None; report the missing config as KeyError.
            dictionary = (dictionary or {})[config_name]

        def to_map(value):
            # Recursively turn nested dicts into Map so that dotted access
            # works at any depth; all other values are kept as they are.
            if isinstance(value, dict):
                return Map({k: to_map(v) for k, v in value.items()})
            return value

        # Nested keys stay inside their own Map instead of being copied to
        # the top level. An empty file or empty section yields no parameters.
        for key, val in (dictionary or {}).items():
            self[key] = to_map(val)
if __name__ == '__main__':
    # Demo: load the 'default' parameter set, edit one value, save a copy.
    hp = HParams('hparams.yaml', 'default')
    print(hp.optimizer)

    def to_plain(value):
        # The YAML dumper picks representers by exact type and may reject
        # dict subclasses such as Map, so convert (nested) values back to
        # plain dicts before saving.
        if isinstance(value, dict):
            return {k: to_plain(v) for k, v in value.items()}
        return value

    # Edit & Save parameters
    hp.batch_size = 10000
    with open('hparams_edited.yaml', 'w', encoding='utf-8') as f:
        YAML().dump(to_plain(hp), f)
    print('Done')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment