Haiku Merge Model Parameters
Gist cjlovering/25b72af90db47aaa1a72caed45a17c62, created December 8, 2020 17:31
import haiku as hk


def merge_pretrained_params(new_params: hk.Params, pre_params: hk.Params) -> hk.Params:
    """Merges the pre-trained parameters `pre_params` into the new parameters `new_params`.

    Assumes that modules shared by both models have the same names, either
    because (a) the names were chosen intentionally, or (b) the reused modules
    are called before any new modules, so Haiku assigns them the same names.
    """
    # Filter out pre-trained parameters that the new model doesn't use, because
    # the optimizer expects the structure of `new_params`: adding extra entries
    # to the parameter mapping would cause errors during SGD updates.
    new_param_keys = set(new_params.keys())
    used_only = lambda module_name, name, value: module_name in new_param_keys
    used_pre_params = hk.data_structures.filter(used_only, pre_params)
    # Replace the (untrained) parameters in `new_params` with the pre-trained
    # ones; on duplicate keys, `merge` keeps the value from the last argument.
    return hk.data_structures.merge(new_params, used_pre_params)
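To make the filter-then-merge behaviour concrete without needing JAX or Haiku installed, here is a minimal sketch on plain nested dicts. It assumes `hk.Params` can be treated as a two-level mapping (module name, then parameter name, then array) and that `filter` and `merge` operate at the (module, name) level. `filter_params`, `merge_params`, and `merge_pretrained_params_sketch` are hypothetical helpers written for illustration, not Haiku APIs, and the strings stand in for parameter arrays.

```python
from typing import Any, Callable, Dict, Mapping

# Hypothetical stand-in for hk.Params: module name -> parameter name -> value.
Params = Mapping[str, Mapping[str, Any]]
Predicate = Callable[[str, str, Any], bool]


def filter_params(predicate: Predicate, params: Params) -> Dict[str, Dict[str, Any]]:
    """Keeps (module_name, name, value) triples for which `predicate` is True.

    Modules left with no parameters are dropped entirely.
    """
    out: Dict[str, Dict[str, Any]] = {}
    for module_name, module in params.items():
        for name, value in module.items():
            if predicate(module_name, name, value):
                out.setdefault(module_name, {})[name] = value
    return out


def merge_params(*structures: Params) -> Dict[str, Dict[str, Any]]:
    """Merges structures; later ones win on duplicate (module, name) keys."""
    out: Dict[str, Dict[str, Any]] = {}
    for params in structures:
        for module_name, module in params.items():
            for name, value in module.items():
                out.setdefault(module_name, {})[name] = value
    return out


def merge_pretrained_params_sketch(new_params: Params, pre_params: Params):
    """Same logic as merge_pretrained_params, on plain dicts."""
    new_param_keys = set(new_params.keys())
    used_pre_params = filter_params(
        lambda module_name, name, value: module_name in new_param_keys,
        pre_params,
    )
    return merge_params(new_params, used_pre_params)


# Hypothetical example: the encoder is reused, the head is new.
new = {
    "encoder/linear": {"w": "fresh_w", "b": "fresh_b"},
    "head/linear": {"w": "fresh_head_w", "b": "fresh_head_b"},
}
pre = {
    "encoder/linear": {"w": "pre_w", "b": "pre_b"},
    "old_head/linear": {"w": "old_w", "b": "old_b"},
}
merged = merge_pretrained_params_sketch(new, pre)
print(merged)
```

The encoder picks up the pre-trained values, the new head keeps its fresh initialization, and `old_head/linear` is dropped, so the result has the same module names as `new`. Because the filter checks only module names, a shared module whose pre-trained copy has an extra parameter name would still add that parameter to the result.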