Last active
July 14, 2020 10:04
-
-
Save ditwoo/5de19670d9946c80916dee75e93ef545 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import collections | |
from typing import List | |
def checkpoints_weights_avg(inputs: List[str]):
    """Load checkpoints from ``inputs`` and average their model weights.

    Every checkpoint is loaded onto the CPU. The first checkpoint's top-level
    dict is reused as the result (so its other entries are carried over), and
    its ``'model_state_dict'`` entry is replaced by the element-wise average
    over all checkpoints.

    Args:
        inputs: An iterable of string paths of checkpoints to load from.
            Each checkpoint must be a dict with a ``'model_state_dict'``
            entry mapping parameter names to torch Tensors.

    Returns:
        The first checkpoint's dict, with ``'model_state_dict'`` replaced by
        an OrderedDict mapping parameter names to the averaged tensors.
        Half-precision tensors are averaged (and returned) as float32;
        integer tensors are averaged with floor division.

    Raises:
        ValueError: If ``inputs`` is empty.
        KeyError: If a checkpoint's parameter names (or their order) differ
            from those of the first checkpoint.
    """
    # Materialize so that any iterable (not only sequences) can be counted.
    inputs = list(inputs)
    if not inputs:
        raise ValueError('At least one checkpoint path is required to average weights')
    num_models = len(inputs)

    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None

    for f in inputs:
        # NOTE: torch.load unpickles data; only load checkpoints you trust.
        state = torch.load(f, map_location='cpu')
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model_state_dict']
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            p = model_params[k]
            # Accumulate fp16 weights in fp32 to limit rounding/overflow.
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                # NOTE: clone() is needed in case of p is a shared parameter
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p

    integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        if v.dtype in integer_dtypes:
            # In-place true division is rejected for integer tensors (e.g.
            # BatchNorm's num_batches_tracked), so use floor division instead.
            v //= num_models
        else:
            v.div_(num_models)
        averaged_params[k] = v

    new_state['model_state_dict'] = averaged_params
    return new_state
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment