Skip to content

Instantly share code, notes, and snippets.

@so298
Created February 21, 2024 17:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save so298/b5fc4127f161dbd65429f5756d771d88 to your computer and use it in GitHub Desktop.
Save so298/b5fc4127f161dbd65429f5756d771d88 to your computer and use it in GitHub Desktop.
Comparing two `.safetensors` files
import argparse
from safetensors import safe_open
import torch
def load_tensors(path, label):
    """Read every tensor stored in the safetensors file at *path*.

    Args:
        path: Filesystem path of the ``.safetensors`` file.
        label: Short name printed in the progress message (e.g. ``"tensor1"``).

    Returns:
        A dict mapping each key in the file to its tensor, loaded on the CPU.
    """
    with safe_open(path, framework="pt", device="cpu") as f:
        print(f"Loading {label}")
        return {key: f.get_tensor(key) for key in f.keys()}


def tensors_equal(left, right):
    """Return True when both tensors have the same dtype, shape and values.

    ``torch.eq`` broadcasts and type-promotes its operands, so on its own it
    can report tensors of different shape or dtype as equal, and it raises
    when the shapes cannot be broadcast. Checking the metadata first avoids
    both problems and keeps ``torch.equal`` away from mixed-dtype inputs.
    """
    if left.dtype != right.dtype or left.shape != right.shape:
        return False
    return bool(torch.equal(left, right))


def find_first_mismatch(tensors1, tensors2):
    """Return the first key whose tensors differ, or None if all match.

    Both dicts must contain the same keys; keys are visited in the iteration
    order of *tensors1*.
    """
    for key, left in tensors1.items():
        if not tensors_equal(left, tensors2[key]):
            return key
    return None


def main(argv=None):
    """Compare two safetensors files and report the first difference.

    Args:
        argv: Command-line arguments without the program name. ``None`` makes
            argparse read ``sys.argv[1:]``.

    Returns:
        0 when both files hold the same keys and tensors, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description='Compare two safetensors')
    parser.add_argument('tensor1', type=str, help='First tensor to compare')
    parser.add_argument('tensor2', type=str, help='Second tensor to compare')
    args = parser.parse_args(argv)

    tensor1 = load_tensors(args.tensor1, "tensor1")
    tensor2 = load_tensors(args.tensor2, "tensor2")

    # Check if the keys are the same
    if tensor1.keys() != tensor2.keys():
        print("Keys are not the same")
        print("First tensor keys: ", tensor1.keys())
        print("Second tensor keys: ", tensor2.keys())
        return 1

    # Check if the tensors are the same
    key = find_first_mismatch(tensor1, tensor2)
    if key is not None:
        first, second = tensor1[key], tensor2[key]
        print("Tensors are not the same")
        print("Key: ", key)
        if first.dtype != second.dtype or first.shape != second.shape:
            # Printed values alone may look identical in this case, so state
            # the metadata difference explicitly.
            print("Shapes/dtypes: ", tuple(first.shape), first.dtype,
                  "vs", tuple(second.shape), second.dtype)
        print("First tensor: ", first)
        print("Second tensor: ", second)
        return 1

    print("Tensors are the same")
    return 0


if __name__ == "__main__":
    # SystemExit is a builtin exception; the bare ``exit`` helper is only
    # injected by the ``site`` module and is not always available.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment