Skip to content

Instantly share code, notes, and snippets.

@dnlcrl
Last active October 2, 2020 07:09
Show Gist options
  • Save dnlcrl/118870bd60a08232465b67f7f9676419 to your computer and use it in GitHub Desktop.
Save dnlcrl/118870bd60a08232465b67f7f9676419 to your computer and use it in GitHub Desktop.
import torch
import sys
import os
def make_merged_name(model_names):
    """Build the output file name for the averaged model.

    The result is ``<common prefix>averaged_<part1>_<part2>...<common suffix>``,
    where each ``part`` is what remains of an input name after removing the
    prefix and suffix shared by *all* names, e.g.
    ``["model_a.pt", "model_b.pt"] -> "model_averaged_a_b.pt"``.

    NOTE(review): if the inputs live in different directories, the differing
    path segments end up inside the merged name, which may then point into a
    directory that does not exist.

    Args:
        model_names: non-empty list of checkpoint paths.

    Returns:
        The merged file name as a string.
    """
    prefix = os.path.commonprefix(model_names)
    suffix = os.path.commonprefix([n[::-1] for n in model_names])[::-1]
    # The common prefix and suffix can overlap on the shortest name (e.g.
    # "ckpt_1.pt" / "ckpt_11.pt" share prefix "ckpt_1" and suffix "1.pt");
    # cap the suffix so both never claim the same characters.
    max_suffix_len = min(len(n) for n in model_names) - len(prefix)
    if len(suffix) > max_suffix_len:
        suffix = suffix[len(suffix) - max_suffix_len:]
    # Use an explicit end index: ``n[p:-s]`` is empty when s == 0, which would
    # erase every distinguishing part when the names share no suffix.
    parts = [n[len(prefix):len(n) - len(suffix)] for n in model_names]
    return prefix + "averaged_" + "_".join(parts) + suffix


def average_state_dicts(state_dicts):
    """Average floating-point tensors across an iterable of state dicts.

    Dicts are consumed one at a time, so passing a generator keeps only the
    running sum and the current incoming dict in memory.

    Args:
        state_dicts: iterable of mappings that all have the same keys.

    Returns:
        The first state dict, updated in place. Every floating-point tensor
        entry is replaced by the element-wise mean over all dicts. The mean is
        accumulated in at least float32 and cast back to the entry's original
        dtype. All other entries keep the first dict's value, because integer
        tensors cannot be true-divided in place. This covers integer/bool
        tensors (e.g. BatchNorm's ``num_batches_tracked``) and non-tensor values.

    Raises:
        ValueError: if *state_dicts* is empty, or a later dict has different
            keys or a differently shaped tensor than the first one.
    """
    it = iter(state_dicts)
    avg = next(it, None)
    if avg is None:
        raise ValueError("no state dicts to average")
    keys = set(avg.keys())
    float_keys = [
        k for k, v in avg.items() if torch.is_tensor(v) and v.is_floating_point()
    ]
    orig_dtypes = {k: avg[k].dtype for k in float_keys}
    for k in float_keys:
        # copy=True gives each key its own storage. torch.load preserves
        # storage sharing between tied weights, and in-place accumulation into
        # a shared tensor would add the same values once per tied key.
        # Accumulate half/bfloat16 in float32 to avoid overflow and rounding drift.
        acc_dtype = torch.promote_types(orig_dtypes[k], torch.float32)
        avg[k] = avg[k].to(acc_dtype, copy=True)
    count = 1
    for index, sd in enumerate(it, start=1):
        if set(sd.keys()) != keys:
            raise ValueError(
                f"state dict {index} has different parameters than state dict 0"
            )
        for k in float_keys:
            if sd[k].shape != avg[k].shape:
                raise ValueError(
                    f"parameter {k!r} has shape {tuple(sd[k].shape)} in state "
                    f"dict {index} but {tuple(avg[k].shape)} in state dict 0"
                )
            avg[k] += sd[k].to(avg[k].dtype)
        count += 1
    for k in float_keys:
        avg[k] = (avg[k] / count).to(orig_dtypes[k])
    return avg


def main(argv=None):
    """Average the checkpoints named on the command line and save the result.

    Args:
        argv: checkpoint paths; defaults to ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if fewer than two paths are given.
    """
    model_names = list(sys.argv[1:] if argv is None else argv)
    # Presumably BatchNorm running statistics averaged this way do not match
    # the averaged weights -- hence the warning; confirm for your model.
    print("Warning: Do not use this for BatchNorm-using models!")
    if len(model_names) < 2:
        # Stop here: there is nothing to average, and the naming logic below
        # needs at least two names to produce anything meaningful.
        print("need at least two models to average")
        return 1
    merged_name = make_merged_name(model_names)
    # NOTE: torch.load unpickles arbitrary objects -- only pass trusted files.
    # A generator keeps at most two checkpoints loaded at once.
    state_dicts = (torch.load(n, map_location="cpu") for n in model_names)
    avg = average_state_dicts(state_dicts)
    print("Merged model:", merged_name)
    torch.save(avg, merged_name)
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment