@vgoklani
Forked from jxmorris12/torch_ddp_verify.py
Created April 17, 2024 22:21
verify parameter weights & gradients in PyTorch

import torch
from torch.distributed import get_world_size

# `gather` and `print0` are helpers expected from the parent gist
# (torch_ddp_verify.py): `gather` collects a tensor from every rank,
# `print0` prints on rank 0 only. A sketch of both is given below.


def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    # Unwrap the DistributedDataParallel container, if present.
    if hasattr(model, "module"):
        model = model.module

    world_size = get_world_size()
    for name, param in model.named_parameters():
        # Gather this parameter from all ranks and compare each copy against rank 0's.
        gathered_param = gather(param).reshape((world_size, -1))
        absolute_diffs = (gathered_param[None, 0, :] - gathered_param).abs()
        rank_params_eq = (absolute_diffs < atol).all()
        assert rank_params_eq, f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}"

        # Same check for the parameter's gradient.
        gathered_param_grad = gather(param.grad).reshape((world_size, -1))
        absolute_grad_diffs = (gathered_param_grad[None, 0, :] - gathered_param_grad).abs()
        rank_grad_params_eq = (absolute_grad_diffs < atol).all()
        assert rank_grad_params_eq, f"❌ param [{name}] grad not equal - got max_absolute_diff={absolute_grad_diffs.max()}"

    print0("Verified DDP parameter correctness ✅")