Skip to content

Instantly share code, notes, and snippets.

@Interpause
Created April 1, 2023 07:25
Show Gist options
  • Save Interpause/342bf354507b6f8ae4cd85c7a89b99df to your computer and use it in GitHub Desktop.
Save Interpause/342bf354507b6f8ae4cd85c7a89b99df to your computer and use it in GitHub Desktop.
PyTorch Tensor and Model Weights Deterministic Hexdigest
"""Utilities to hash `torch.Tensor` and `nn.Module.state_dict()`."""
import numpy as np
import torch
import torch.nn as nn
from xxhash import xxh3_64_hexdigest as hexdigest
# Public names exported by `from <this module> import *`.
__all__ = ["hash_tensor", "hash_model"]
def hash_tensor(x: torch.Tensor) -> str:
    """Return a deterministic xxh3-64 hexdigest of a tensor's contents.

    Non-floating tensors (ints, bools, ...) are hashed from their raw
    element bytes. Floating tensors are first min-max rescaled onto
    ``[0, 255]`` and truncated to ``uint8``, so small floating-point
    discrepancies do not change the digest.

    Trade-offs of this scheme, which follow from the code below:

    * Only element bytes are hashed, so the tensor's shape is not part of
      the digest. Floating dtypes are also erased by the quantization.
    * Floating tensors that differ only by a positive scale/offset (e.g.
      ``x`` and ``2 * x + 1``) typically quantize to the same bytes and
      therefore collide.

    Args:
        x: Tensor to hash. It may live on any device and may require grad.
            Empty tensors and NumPy-incompatible float dtypes (bfloat16,
            float8) are supported.

    Returns:
        Hex string produced by ``xxh3_64_hexdigest`` with ``seed=0``.
    """
    # Ops used here are chosen to minimize copies.
    is_float = torch.is_floating_point(x)
    # NumPy has no bfloat16/float8 dtypes, so `Tensor.numpy()` raises on them.
    # Upcasting to float32 is exact for those formats. The values are
    # quantized below anyway, so the digest matches an equal float32 tensor.
    if is_float and x.dtype not in (torch.float16, torch.float32, torch.float64):
        x = x.float()
    # `numpy(force=True)` detaches, copies to CPU if needed and resolves
    # conj/neg views. Hashing its buffer is faster than
    # `bytes(x.flatten().byte())`.
    arr: np.ndarray = x.numpy(force=True)
    # Deliberately lossy: quantize floats to 256 levels so floating-point
    # noise does not alter the digest (at the cost of more collisions).
    # Zero-size arrays are skipped because min()/max() raise on them.
    # NOTE(review): NaN/inf are not special-cased. With them present, the
    # rescale and uint8 cast may give platform-dependent bytes; confirm
    # whether non-finite weights need a portable digest.
    if is_float and arr.size > 0:
        lo, hi = arr.min(), arr.max()
        arr = np.interp(arr, (lo, hi), (0, 255)).astype(np.uint8, order="C")
    # Standardize to a C-contiguous buffer so the hashed byte order is fixed.
    arr = np.asarray(arr, order="C")
    return hexdigest(arr.data, seed=0)
def hash_model(m: nn.Module) -> str:
    """Return a deterministic hexdigest of a module's ``state_dict``.

    Entries are visited in sorted key order. Each entry contributes its key
    immediately followed by the per-tensor digest from ``hash_tensor``. The
    concatenation of all entries is then hashed with xxh3-64 (seed 0).

    Args:
        m: Module whose ``state_dict()`` entries are hashed.

    Returns:
        Hex string digest of the module's weights.
    """
    state = m.state_dict()
    # Dict keys are unique, so ordering by key alone is the same as ordering
    # the (key, value) pairs, and tensors are never compared.
    parts = [f"{name}{hash_tensor(state[name])}" for name in sorted(state)]
    return hexdigest("".join(parts), seed=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment