@cjlovering
Created December 8, 2020 17:31
Haiku Merge Model Parameters
import haiku as hk


def merge_pretrained_params(new_params: hk.Params, pre_params: hk.Params) -> hk.Params:
    """Merges the pre-trained parameters `pre_params` into `new_params`.

    Parameters are matched by module name, so the names in `pre_params` and
    `new_params` must line up: either (a) name the modules explicitly, or
    (b) call the reused modules before any new modules so that they end up
    with the same names.
    """
    # Drop the pre-trained modules that the new model does not use. The
    # optimizer expects the structure of `new_params`, so adding extra
    # entries to the parameter tree would cause errors during SGD.
    new_param_keys = set(new_params.keys())
    used_only = lambda module_name, name, value: module_name in new_param_keys
    used_pre_params = hk.data_structures.filter(used_only, pre_params)
    # Replace the (untrained) values in `new_params` with the pre-trained ones.
    return hk.data_structures.merge(new_params, used_pre_params)
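
To make the filter-then-merge behaviour concrete, here is a minimal sketch that mirrors the same two steps on plain nested dicts instead of Haiku parameter trees. It is only an illustration: `merge_pretrained_params_dicts`, the module names, and the float values (standing in for arrays) are all hypothetical and not part of the gist or of Haiku.

```python
from typing import Dict

# Hypothetical stand-in for hk.Params: module_name -> {param_name: value}.
# Floats stand in for the arrays a real model would hold.
Params = Dict[str, Dict[str, float]]


def merge_pretrained_params_dicts(new_params: Params, pre_params: Params) -> Params:
    """Plain-dict illustration of the filter-then-merge logic."""
    # Step 1: keep only the pre-trained modules whose names appear in new_params.
    used_pre_params = {
        module_name: dict(module_params)
        for module_name, module_params in pre_params.items()
        if module_name in new_params
    }
    # Step 2: start from a copy of new_params and let pre-trained values override.
    merged = {name: dict(params) for name, params in new_params.items()}
    for module_name, module_params in used_pre_params.items():
        merged[module_name].update(module_params)
    return merged


# Assumed example: the encoder is reused, the head is new.
new_params = {
    "encoder/linear": {"w": 0.0, "b": 0.0},
    "new_head/linear": {"w": 0.0, "b": 0.0},
}
pre_params = {
    "encoder/linear": {"w": 1.5, "b": -0.5},
    "old_head/linear": {"w": 9.0, "b": 9.0},
}

merged = merge_pretrained_params_dicts(new_params, pre_params)
```

In this sketch the reused `encoder/linear` module picks up the pre-trained values, `new_head/linear` keeps its fresh initialization, and `old_head/linear` is dropped, so the result has exactly the module names of `new_params`.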