Skip to content

Instantly share code, notes, and snippets.

@dingran
Created October 28, 2019 02:00
Show Gist options
  • Save dingran/9a7a44600022d91950a0dc19617a8312 to your computer and use it in GitHub Desktop.
Save dingran/9a7a44600022d91950a0dc19617a8312 to your computer and use it in GitHub Desktop.
Function to expand a master hyperparameter dict into a list of hp dicts
import itertools
import copy
### Example usage ###
# In [13]: import pprint
# ...: master_hp_dict = dict(anneal_KLD=Expand(True, False), bs=1024, n_epochs=Expand(2, 10, 20), mlp=[10,30])
# ...: hp_list = dict_expand(master_hp_dict)
# ...: pprint.pprint(hp_list)
# [{'anneal_KLD': True, 'bs': 1024, 'mlp': [10, 30], 'n_epochs': 2},
# {'anneal_KLD': True, 'bs': 1024, 'mlp': [10, 30], 'n_epochs': 10},
# {'anneal_KLD': True, 'bs': 1024, 'mlp': [10, 30], 'n_epochs': 20},
# {'anneal_KLD': False, 'bs': 1024, 'mlp': [10, 30], 'n_epochs': 2},
# {'anneal_KLD': False, 'bs': 1024, 'mlp': [10, 30], 'n_epochs': 10},
# {'anneal_KLD': False, 'bs': 1024, 'mlp': [10, 30], 'n_epochs': 20}]
class Expand:
    """Marker holding the alternative values of one hyperparameter.

    ``dict_expand`` checks dict values with ``isinstance(value, Expand)`` and
    produces one output dict per candidate stored here.

    Attributes:
        values: Tuple of the positional arguments, in the order given.
            The order is kept as-is (no sorting), so it determines the
            order in which candidates are enumerated.
    """

    def __init__(self, *values):
        self.values = values

    def __repr__(self):
        # Rendered exactly like the underlying tuple, e.g. "(True, False)".
        # No separate __str__ is defined: str() falls back to __repr__, so
        # both conversions yield the same text.
        return str(self.values)
def dict_expand(master_d):
    """Expand a master hyperparameter dict into a list of concrete dicts.

    Every value of ``master_d`` that is an ``Expand`` instance is treated as
    a set of alternatives. One output dict is produced for each element of
    the Cartesian product of those alternatives; all other entries are
    carried over unchanged.

    Args:
        master_d: Mapping from hyperparameter name to either a plain value
            or an ``Expand`` holding the candidate values for that name.
            It is not modified.

    Returns:
        A list of dicts with the same keys (and key order) as ``master_d``.
        The last expanded key, in ``master_d`` iteration order, varies
        fastest. Each dict is an independent deep copy, so mutating one
        result never affects another result or ``master_d``. If no value is
        an ``Expand`` the list holds a single copy of ``master_d``; if any
        ``Expand`` has no values the list is empty.
    """
    expand_keys = [k for k, v in master_d.items() if isinstance(v, Expand)]
    candidates = [master_d[k].values for k in expand_keys]

    d_list = []
    for combo in itertools.product(*candidates):
        # Shallow copy first so the Expand placeholders can be swapped for
        # the chosen values without touching master_d; this keeps the
        # mapping type and the position of every key.
        resolved = copy.copy(master_d)
        resolved.update(zip(expand_keys, combo))
        # One deep copy of the resolved dict: the chosen values are copied
        # too (mutable candidates are no longer shared between results),
        # and the Expand objects themselves are never deep-copied.
        d_list.append(copy.deepcopy(resolved))
    return d_list
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment