@ditwoo
Last active July 14, 2020 10:04
import collections
from typing import List

import torch


def checkpoints_weights_avg(inputs: List[str]) -> dict:
    """Loads checkpoints from `inputs` and returns a checkpoint with averaged weights.

    Args:
        inputs: An iterable of string paths of checkpoints to load from.

    Returns:
        A dict with the same structure as the first checkpoint, where the
        'model_state_dict' key maps to an OrderedDict of string parameter
        names to torch Tensors holding the element-wise average of the
        corresponding parameters across all checkpoints.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    for f in inputs:
        # Load every checkpoint on CPU so tensors from different devices can be summed.
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copy over the settings (optimizer state, metadata, etc.) from the first checkpoint.
        if new_state is None:
            new_state = state
        model_params = state['model_state_dict']
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )
        for k in params_keys:
            p = model_params[k]
            # Accumulate in FP32 to avoid precision loss when summing half-precision weights.
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                # NOTE: clone() is needed in case p is a shared parameter.
                params_dict[k] = p.clone()
            else:
                params_dict[k] += p
    # Divide the accumulated sums by the number of checkpoints to get the average.
    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        averaged_params[k].div_(num_models)
    new_state['model_state_dict'] = averaged_params
    return new_state
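A minimal usage sketch: average a few checkpoints saved during training and load the result into a model before evaluation or export. The checkpoint paths and the MyModel class below are hypothetical placeholders, not part of the gist.

# Usage sketch; paths and MyModel are placeholders.
paths = [
    'checkpoints/epoch_8.pth',
    'checkpoints/epoch_9.pth',
    'checkpoints/epoch_10.pth',
]
avg_checkpoint = checkpoints_weights_avg(paths)

model = MyModel()
model.load_state_dict(avg_checkpoint['model_state_dict'])
model.eval()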