Created January 12, 2024 08:59
-
-
Save iqiancheng/edbb0c7b409287af2e7dcadc3293dcfa to your computer and use it in GitHub Desktop.
对比两个模型权重是否一致 torch.equal (Compare whether two models' weights are identical using torch.equal)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

# Paths to the two checkpoints whose weights are compared below.
# NOTE(review): both paths are currently identical, so the comparison is
# trivially true; point MODEL2_PATH at the checkpoint you want to compare.
MODEL1_PATH = 'output/exp-step00000100/unet/diffusion_pytorch_model.bin'
MODEL2_PATH = 'output/exp-step00000100/unet/diffusion_pytorch_model.bin'

# map_location='cpu' places every loaded tensor on the CPU. The script then
# runs on machines without the device the checkpoint was saved from, and
# torch.equal never has to compare tensors living on different devices.
# torch.load unpickles the file: only load checkpoints from trusted sources.
model1 = torch.load(MODEL1_PATH, map_location='cpu')
print('load model1 success')
# The second checkpoint must be bound to `model2`; the comparison step below
# reads that name.
model2 = torch.load(MODEL2_PATH, map_location='cpu')
print('load model2 success')
def compare_models(dict1, dict2):
    """Return True if two state dicts hold exactly the same entries.

    Keys are matched by name, so the insertion order of the two dicts does
    not matter. Every difference found is printed: keys present in only one
    dict, and keys whose values differ.

    Args:
        dict1: First mapping of parameter name -> value (e.g. the state dict
            returned by ``torch.load``).
        dict2: Second mapping to compare against ``dict1``.

    Returns:
        True when both dicts have the same key set and every pair of values
        is equal (tensors via ``torch.equal``: same shape and same elements);
        False otherwise.
    """
    models_differ = False

    # Walk dict1 by key instead of zipping the two item lists. zip() would
    # stop at the shorter dict, hiding extra keys. It would also report
    # spurious mismatches when the same keys appear in a different order.
    for key in dict1:
        if key not in dict2:
            models_differ = True
            print('Key missing from second model:', key)
        elif not _values_equal(dict1[key], dict2[key]):
            models_differ = True
            print('Mismatch found in item:', key)

    # Keys that exist only in dict2 are not visited by the loop above.
    for key in dict2:
        if key not in dict1:
            models_differ = True
            print('Key missing from first model:', key)

    return not models_differ


def _values_equal(value1, value2):
    """Compare two state-dict values; tensors use exact elementwise equality."""
    if isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
        # torch.equal returns False (rather than raising) for differing shapes.
        return torch.equal(value1, value2)
    if isinstance(value1, torch.Tensor) or isinstance(value2, torch.Tensor):
        # A tensor never equals a non-tensor entry; avoids ambiguous `==`.
        return False
    return value1 == value2
# Run the key-by-key comparison and print the overall verdict.
is_same: bool = compare_models(model1, model2)
print("Are the models the same?", is_same)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment