@iqiancheng
Created January 12, 2024 08:59
Compare whether two models' weights are identical using torch.equal
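import torch

# Load the two checkpoints (state dicts) to compare.
model1 = torch.load('output/exp-step00000100/unet/diffusion_pytorch_model.bin')
print('load model1 success')
# Bound to model2 here (the original rebound model1). The path is unchanged from
# the original; point it at the second checkpoint you want to compare.
model2 = torch.load('output/exp-step00000100/unet/diffusion_pytorch_model.bin')
print('load model2 success')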
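def compare_models(dict1, dict2):
    """Return True if both state dicts hold the same keys, in the same order, with equal tensors."""
    models_differ = False
    # zip() stops at the shorter dict, so compare the sizes first.
    if len(dict1) != len(dict2):
        models_differ = True
        print('Different number of entries:', len(dict1), 'vs', len(dict2))
    for (key1, item1), (key2, item2) in zip(dict1.items(), dict2.items()):
        if key1 == key2 and torch.equal(item1, item2):
            pass
        else:
            models_differ = True
            print('Mismatch found in item:', key1)
    return not models_differ

is_same = compare_models(model1, model2)
print('Are the models the same?', is_same)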
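
The script above compares entries pairwise in iteration order, so two checkpoints that hold the same tensors under differently ordered keys would be reported as different. A minimal sketch of an order-independent comparison follows. The helper names `diff_state_dicts` and `same_weights` and the optional `rtol`/`atol` tolerances are hypothetical additions for illustration, not part of the original gist, and the sketch assumes every value in both dicts is a tensor.

```python
import torch


def diff_state_dicts(dict1, dict2, rtol=0.0, atol=0.0):
    """Hypothetical helper: describe how two state dicts differ, matching entries by key."""
    keys1, keys2 = set(dict1), set(dict2)
    report = {
        "only_in_first": sorted(keys1 - keys2),
        "only_in_second": sorted(keys2 - keys1),
        "shape_mismatch": [],
        "value_mismatch": [],
    }
    for key in sorted(keys1 & keys2):
        a, b = dict1[key], dict2[key]
        if a.shape != b.shape:
            report["shape_mismatch"].append(key)
        elif rtol == 0.0 and atol == 0.0:
            # Exact comparison, as in the original script.
            if not torch.equal(a, b):
                report["value_mismatch"].append(key)
        elif not torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol):
            # Tolerant comparison, only used when a tolerance is given.
            report["value_mismatch"].append(key)
    return report


def same_weights(dict1, dict2, **tolerances):
    """Hypothetical helper: True when every list in the report is empty."""
    return not any(diff_state_dicts(dict1, dict2, **tolerances).values())
```

With the checkpoints loaded as above, `same_weights(model1, model2)` plays the role of `compare_models`, and `diff_state_dicts(model1, model2)` lists which keys are missing or differ. Passing something like `atol=1e-6` may be useful when the two checkpoints are expected to differ only by floating-point noise.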