@jaekookang
Last active November 30, 2020 18:20
Python Import Hyper-parameters as a Dotted Dictionary
'''
Make a dotted dictionary from a hyper-parameter file (.yaml)
This `HParams` class lets you import a hyper-parameter file in .yaml format
into Python as a dotted dictionary. There are several ways to implement a
dotted dictionary, but there are not many that I was satisfied with.
For example, the `DotMap` module did not seem stable to me and could fail silently.
This was the main motivation for developing my own version of a dotted dictionary.
Purpose:
- Import parameters as a dotted dictionary
- Edit and save in .yaml format
How to use:
Basic
```
hp = HParams(PARAM_FILE, CONFIG)
# PARAM_FILE: hyper parameter file in .yaml format
# PARAM_FILE can be given as a python dictionary too
# CONFIG: if you have multiple sets of parameters, you can specify them by their name
# eg, 'default', 'deep_layers', 'shallow_layers' etc.
```
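To illustrate what CONFIG selects, here is a minimal sketch that uses a plain
dictionary in place of the parsed .yaml file. It does not use `HParams` itself:
`select_config` is a hypothetical helper, and the config names and values are
made up for illustration.

```python
from types import SimpleNamespace

# Hypothetical parameter sets, keyed by config name (values are illustrative)
param_sets = {
    'default': {'batch_size': 100, 'n_layers': 4},
    'deep_layers': {'batch_size': 100, 'n_layers': 16},
}


def select_config(params, config_name):
    # Pick one named parameter set and expose its keys as attributes
    return SimpleNamespace(**params[config_name])


hp_demo = select_config(param_sets, 'deep_layers')
print(hp_demo.n_layers)
# out: 16
```

Keeping each parameter set under its own top-level name lets one file hold
several experiment settings side by side.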
Steps
- 1. Install ruamel.yaml (eg. `pip install ruamel.yaml`).
- 2. Create yaml file and provide hyper-parameters. For example:
```
# hparams.yaml
default:
  batch_size: 100
  epoch: 1000
  optimizer: Adam
```
- 3. Load `hparams.yaml` using `HParams`.
```
# Load from yaml file
hp = HParams('hparams.yaml', 'default')
# Load from python dictionary (the config name is not applied to dictionaries)
# hp = HParams(param_dict)
print(hp.batch_size)
# out: 100
```
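Parameter files often contain nested sections. The sketch below shows one way
a nested dictionary can be turned into nested dotted levels. It is a
self-contained illustration, not the `HParams` implementation: `DotDict` and
`to_dotted` are hypothetical names, and the keys are made up.

```python
class DotDict(dict):
    # Minimal stand-in for a dotted dictionary: attribute access reads dict keys
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)


def to_dotted(value):
    # Recursively wrap nested dictionaries so every level supports dot access
    if isinstance(value, dict):
        return DotDict({k: to_dotted(v) for k, v in value.items()})
    return value


params = to_dotted({'batch_size': 100, 'model': {'encoder': {'n_layers': 4}}})
print(params.model.encoder.n_layers)
# out: 4
```

Unlike a lookup that returns None for unknown names, this stand-in raises
AttributeError on a missing key, which makes typos in parameter names visible.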
References:
- https://hanxiao.io/2017/12/21/Use-HParams-and-YAML-to-Better-Manage-Hyperparameters-in-Tensorflow/
- https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
- https://pypi.org/project/dotmap/
2020-04-30 jkang, created
'''
import ruamel.yaml
YAML = ruamel.yaml.YAML


class Map(dict):
    '''Dot dictionary
    Example:
        d = Map({'layer': 3, 'batch': 100, 'pooling': 'average'})
    Retrieved and edited from: https://stackoverflow.com/a/32107024
    '''

    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v
        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]


class HParams(Map):
    '''Hyperparameter initializer
    ## Load
    >> hp = HParams('hparams.yaml')  # or
    >> hp = HParams('hparams.yaml', DEFAULT_CONFIG)
    ## Save
    >> with open('hparams_new.yaml', 'w') as f:
           YAML().dump(dict(hp), f)
    '''

    def __init__(self, yaml_file, config_name=None):
        super().__init__()
        if isinstance(yaml_file, dict):
            # Read dictionary directly (config_name is not applied here)
            dictionary = yaml_file
        else:
            # Read yaml file
            with open(yaml_file) as f:
                dictionary = YAML().load(f)
            if config_name is not None:
                dictionary = dictionary[config_name]
        # Make dotted dictionary. Nested dictionaries become nested Maps at
        # every depth and stay under their parent key, so a nested key can no
        # longer overwrite a top-level key that has the same name.
        for key, val in dictionary.items():
            self[key] = self._to_map(val)

    @staticmethod
    def _to_map(value):
        '''Recursively convert nested dictionaries into Map objects'''
        if isinstance(value, dict):
            return Map({k: HParams._to_map(v) for k, v in value.items()})
        return value


if __name__ == '__main__':
    # Load parameters
    hp = HParams('hparams.yaml', 'default')
    print(hp.optimizer)
    # Edit & Save parameters
    hp.batch_size = 10000
    with open('hparams_edited.yaml', 'w') as f:
        YAML().dump(dict(hp), f)
    print('Done')