Last active
October 2, 2020 07:09
-
-
Save dnlcrl/118870bd60a08232465b67f7f9676419 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Average the parameters of several saved PyTorch state dicts.
#
# Usage: python <this script> MODEL_A MODEL_B [MODEL_C ...]
#
# Each argument is a file readable by ``torch.load`` that contains a mapping
# from parameter name to tensor. The element-wise mean of all inputs is saved
# next to the inputs under a name derived from the input file names.
import os
import sys

import torch


def build_merged_name(model_names):
    """Return the output file name for the averaged model.

    The name is the common prefix of all inputs, the literal ``averaged_``,
    the distinguishing middle part of every input joined with ``_``, and
    finally the common suffix of all inputs.

    Args:
        model_names: Sequence of input file names (strings).

    Returns:
        The merged file name as a string.
    """
    prefix = os.path.commonprefix(model_names)
    # commonprefix on the reversed names yields the (reversed) common suffix.
    suffix = os.path.commonprefix([n[::-1] for n in model_names])[::-1]
    prefix_len = len(prefix)
    suffix_len = len(suffix)
    # Use an explicit end index: ``n[a:-0]`` is ``n[a:0]`` (empty), which
    # would drop the distinguishing part whenever there is no common suffix.
    parts = [n[prefix_len:len(n) - suffix_len] for n in model_names]
    return prefix + 'averaged_' + '_'.join(parts) + suffix


def average_state_dicts(state_dicts):
    """Return the element-wise mean of an iterable of state dicts.

    The first state dict is used as the accumulator and is modified in
    place; the remaining ones are consumed one at a time, so a generator
    can be passed to keep only two models in memory at once.

    Args:
        state_dicts: Iterable of mappings from parameter name to tensor.

    Returns:
        The first state dict, now holding the averaged values.

    Raises:
        ValueError: If the iterable is empty or the state dicts do not
            all have the same set of keys.
    """
    iterator = iter(state_dicts)
    try:
        avg = next(iterator)
    except StopIteration:
        raise ValueError("need at least one state dict to average") from None

    keys = set(avg.keys())
    count = 1
    for sd in iterator:
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if set(sd.keys()) != keys:
            raise ValueError("model have different parameters")
        for k in keys:
            avg[k] += sd[k]
        count += 1

    # NOTE(review): in-place true division is expected to fail for integer
    # tensors (e.g. BatchNorm's ``num_batches_tracked``) on recent PyTorch,
    # which is presumably why the script warns about BatchNorm -- confirm.
    for k in keys:
        avg[k] /= count
    return avg


def main(argv):
    """Average the models named in *argv* and save the result.

    Args:
        argv: Command-line arguments without the program name; each entry
            is the path of a saved state dict.

    Returns:
        Process exit status: 0 on success, 1 if fewer than two models
        were given.
    """
    print("Warning: Do not use this for BatchNorm-using models!")
    model_names = list(argv)
    if len(model_names) < 2:
        print("need at least two models to average")
        # Abort here; continuing would fail on model_names[0] (no models)
        # or write a pointless copy (one model).
        return 1

    merged_name = build_merged_name(model_names)

    # SECURITY: torch.load unpickles its input, which can execute arbitrary
    # code. Only run this script on checkpoint files you trust.
    avg = average_state_dicts(
        torch.load(n, map_location='cpu') for n in model_names
    )

    print("Merged model:", merged_name)
    torch.save(avg, merged_name)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment