Skip to content

Instantly share code, notes, and snippets.

@rohan-varma
Created April 24, 2020 04:51
Show Gist options
  • Save rohan-varma/a0a75e9a0fbe9ccc7420b04bff4a7212 to your computer and use it in GitHub Desktop.
Save rohan-varma/a0a75e9a0fbe9ccc7420b04bff4a7212 to your computer and use it in GitHub Desktop.
PyTorch check if 2 models have the same state_dict
def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading ``"module."`` removed from its keys.

    Stripping only happens when the first key starts with ``"module."``.
    Keys that do not carry the prefix are left untouched, and an empty
    mapping is returned as-is.
    """
    prefix = "module."
    first_key = next(iter(state_dict), None)
    if not isinstance(first_key, str) or not first_key.startswith(prefix):
        return state_dict
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def validate_state_dicts(model_state_dict_1, model_state_dict_2):
    """Check whether two model state dicts hold the same keys and values.

    Entries are compared pairwise in iteration order, so both the key
    names and their ordering must agree. Tensors are compared with
    ``torch.allclose`` (default tolerances) after being moved to a common
    device. The first difference found is reported through the
    module-level ``logger``; ``logger`` and ``torch`` are expected to be
    in scope where this function is defined.

    Args:
        model_state_dict_1: Mapping of parameter/buffer name to tensor.
        model_state_dict_2: Mapping of parameter/buffer name to tensor.

    Returns:
        True if the lengths, keys, shapes, dtypes and values all match,
        False otherwise.
    """
    if len(model_state_dict_1) != len(model_state_dict_2):
        logger.info(
            f"Length mismatch: {len(model_state_dict_1)}, {len(model_state_dict_2)}"
        )
        return False

    # Replicate modules have "module" attached to their keys, so strip these
    # off when comparing to a local model.
    model_state_dict_1 = _strip_module_prefix(model_state_dict_1)
    model_state_dict_2 = _strip_module_prefix(model_state_dict_2)

    # Loop-invariant: every tensor is compared on this one device.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    for (k_1, v_1), (k_2, v_2) in zip(
        model_state_dict_1.items(), model_state_dict_2.items()
    ):
        if k_1 != k_2:
            logger.info(f"Key mismatch: {k_1} vs {k_2}")
            return False
        # allclose raises on dtype / non-broadcastable shape mismatches and
        # would silently broadcast compatible shapes, so check these first.
        if v_1.shape != v_2.shape or v_1.dtype != v_2.dtype:
            logger.info(
                f"Tensor shape/dtype mismatch for {k_1}: "
                f"{tuple(v_1.shape)}/{v_1.dtype} vs {tuple(v_2.shape)}/{v_2.dtype}"
            )
            return False
        # .to() is a no-op for tensors that already live on the target device.
        v_1 = v_1.to(device)
        v_2 = v_2.to(device)
        if not torch.allclose(v_1, v_2):
            logger.info(f"Tensor mismatch: {v_1} vs {v_2}")
            return False

    # Every entry matched (also covers two empty state dicts).
    return True
@Xi-HHHM
Copy link

Xi-HHHM commented Feb 29, 2024

There is no True return in this function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment