@arose13
Created April 2, 2024 20:01
Checks whether the state_dicts of two PyTorch models are identical.
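import torch


def is_state_dict_equal(dict1, dict2):
    """Return True if two state_dicts have the same keys and identical tensors."""
    # A key present in only one of the dicts means the models differ.
    missing = set(dict1) ^ set(dict2)
    if missing:
        print(f"Keys not in both dicts: {sorted(missing)}")
        return False
    for key in dict1:
        # torch.equal requires matching shapes as well as values, so a
        # shape mismatch is reported instead of being silently broadcast.
        if not torch.equal(dict1[key], dict2[key]):
            print(f"Difference in values for key {key}")
            return False
    return True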
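A minimal usage sketch, assuming PyTorch is installed. `state_dict_diff` is a hypothetical variant, not part of the original gist: instead of stopping at the first mismatch it collects every key that is missing from one dict or whose tensors differ, and it moves tensors to the CPU before comparing so that models on different devices can be checked. The `nn.Linear(4, 2)` models are illustrative only.

```python
import torch
import torch.nn as nn


def state_dict_diff(sd1, sd2):
    """Hypothetical helper: list every key that is missing or whose tensors differ."""
    diffs = []
    for key in sorted(set(sd1) | set(sd2)):
        if key not in sd1 or key not in sd2:
            diffs.append(key)  # present in only one state_dict
        elif not torch.equal(sd1[key].cpu(), sd2[key].cpu()):
            diffs.append(key)  # shape or values differ
    return diffs


torch.manual_seed(0)
model_a = nn.Linear(4, 2)
model_b = nn.Linear(4, 2)

# Copying model_a's weights into model_b makes the two state_dicts identical.
model_b.load_state_dict(model_a.state_dict())
print(state_dict_diff(model_a.state_dict(), model_b.state_dict()))  # []

# Changing one parameter in place is reported under its key.
with torch.no_grad():
    model_b.bias.add_(1.0)
print(state_dict_diff(model_a.state_dict(), model_b.state_dict()))  # ['bias']
```

Returning the full list of differing keys is handy when debugging a checkpoint that only partially loaded, since the boolean check above reports only the first mismatch it finds.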
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment