Skip to content

Instantly share code, notes, and snippets.

@pltrdy
Created September 30, 2019 16:15
Show Gist options
  • Save pltrdy/4454dd09183bb4734cc83f02ef133566 to your computer and use it in GitHub Desktop.
Save pltrdy/4454dd09183bb4734cc83f02ef133566 to your computer and use it in GitHub Desktop.
Comparing ONMT checkpoint files
#!/usr/bin/env python3
import torch
def model_equals(model1, model2):
    """Return True if two parameter mappings hold identical tensors.

    Both arguments are expected to be dict-like objects mapping names to
    tensors (presumably the ``'model'`` entry of a checkpoint; confirm
    against the caller).

    The mappings are considered equal only when:
      * they contain exactly the same set of keys,
      * tensors stored under the same key have the same shape, and
      * those tensors are element-wise equal.

    Args:
        model1: first name -> tensor mapping.
        model2: second name -> tensor mapping.

    Returns:
        bool: True when every entry matches, False otherwise.
    """
    # A plain zip() over the values would silently stop at the shorter
    # mapping and would pair entries by position instead of by name.
    if model1.keys() != model2.keys():
        return False

    for name in model1:
        t1 = model1[name].data
        t2 = model2[name].data
        # Check shapes first: `ne` would otherwise broadcast (hiding a
        # mismatch) or raise on incompatible shapes.
        if t1.shape != t2.shape:
            return False
        if t1.ne(t2).any():
            return False
    return True
def cmp_models(model1, model2):
    """Load two checkpoint files and print whether their models match.

    Each file is loaded onto the CPU with ``torch.load`` and its
    ``'model'`` entry is extracted; the two entries are then passed to
    ``model_equals`` and the boolean result is printed.

    Args:
        model1: path (or file-like object) of the first checkpoint.
        model2: path (or file-like object) of the second checkpoint.

    NOTE(review): ``torch.load`` unpickles the file, so only run this on
    checkpoints from a trusted source.
    """
    print("Loading model1...")
    first_checkpoint = torch.load(model1, map_location='cpu')
    first_state = first_checkpoint['model']

    print("Loading model2...")
    second_checkpoint = torch.load(model2, map_location='cpu')
    second_state = second_checkpoint['model']

    print("Calculating differences...")
    identical = model_equals(first_state, second_state)
    print(identical)
if __name__ == "__main__":
    # Script entry point: takes two positional checkpoint paths and
    # reports whether the models they contain are identical.
    import argparse

    cli = argparse.ArgumentParser()
    for positional in ("model1", "model2"):
        cli.add_argument(positional)
    opts = cli.parse_args()

    cmp_models(opts.model1, opts.model2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment