Skip to content

Instantly share code, notes, and snippets.

@madebyollin
Created July 29, 2023 02:36
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save madebyollin/034afe6670fc03966d075912cbccf797 to your computer and use it in GitHub Desktop.
Save madebyollin/034afe6670fc03966d075912cbccf797 to your computer and use it in GitHub Desktop.
Script for comparing the contents of two safetensors files.
#!/usr/bin/env python3
from pathlib import Path
from safetensors.torch import load_file
def summarize_tensor(x):
    """Return a short "(min, mean, max)" summary of a tensor, 3 decimals each.

    Returns "None" when `x` is None (e.g. the key is missing from one file) and
    "(empty)" for a zero-element tensor, whose min/max are undefined.
    """
    if x is None:
        return "None"
    # min()/max() raise on zero-element tensors, which safetensors files can hold.
    if x.numel() == 0:
        return "(empty)"
    # Cast to float32 so mean() works for integer/bool tensors and reductions
    # over half-precision inputs are computed in float32.
    x = x.float()
    return f"({x.min().item():.3f}, {x.mean().item():.3f}, {x.max().item():.3f})"
def compare_keys(dev_keys, ref_keys, dev_name, ref_name):
    """Describe how two key collections differ.

    Returns a string made of three lines, each preceded by a newline: the key
    count on each side, the keys present only in `ref_keys`, and the keys
    present only in `dev_keys`. Both collections must support `len()` and set
    difference (for example `dict.keys()` views or sets).
    """
    missing_from_dev = ref_keys - dev_keys
    missing_from_ref = dev_keys - ref_keys
    report_lines = [
        f"{ref_name} has {len(ref_keys)} keys; {dev_name} has {len(dev_keys)} keys",
        f"keys in {ref_name} but not in {dev_name}: {missing_from_dev}",
        f"keys in {dev_name} but not in {ref_name}: {missing_from_ref}",
    ]
    return "".join("\n" + line for line in report_lines)
def _tensors_identical(a_val, b_val):
    """Return True only if both tensors exist, share a shape, and match elementwise."""
    if a_val is None or b_val is None:
        return False
    # Compare shapes first: `==` broadcasts, so a (1, 4) tensor would otherwise
    # compare "equal" to a (4, 4) tensor, and non-broadcastable shapes would raise.
    return a_val.shape == b_val.shape and bool((a_val == b_val).all())


def _format_net_change(a_val, b_val):
    """Return the ANSI-colored ratio std(b) / std(a) for a key whose tensors differ.

    Empty (uncolored text) when either tensor is missing. Red when the ratio is
    above 1.5, green when below 0.5, otherwise the neutral color.
    """
    if a_val is None or b_val is None:
        net_change = 1.0
        net_change_str = ""
    else:
        # Cast to float: Tensor.std() raises for integer/bool dtypes.
        # The 1e-8 guards against division by zero for constant tensors.
        net_change = (b_val.float().std() / a_val.float().std().add(1e-8)).item()
        net_change_str = f"{net_change:.4f}x"
    if net_change > 1.5:
        color = "31"
    elif net_change < 0.5:
        color = "32"
    else:
        color = "30"
    return f"\033[{color}m{net_change_str}\033[0m"


def main(a_path, b_path):
    """Print a key-by-key comparison of two safetensors files.

    First prints the key counts and the keys unique to each file. Then, for every
    key in the union of both files (sorted), prints either "Identical" or a
    (min, mean, max) summary going from file a to file b, followed by the ratio
    std(b) / std(a).

    Args:
        a_path: path (str or PathLike) to the first file.
        b_path: path (str or PathLike) to the second file.

    Raises:
        FileNotFoundError: if either path does not exist.
    """
    a_path, b_path = Path(a_path), Path(b_path)
    # Explicit raise instead of assert so the check survives `python -O`.
    for path in (a_path, b_path):
        if not path.exists():
            raise FileNotFoundError(f"No such file: {path}")
    a_st = load_file(a_path)
    b_st = load_file(b_path)
    print(compare_keys(a_st.keys(), b_st.keys(), a_path, b_path))
    all_keys = sorted(a_st.keys() | b_st.keys())
    # default=0 avoids ValueError when neither file contains any tensors.
    key_col_width = max((len(k) for k in all_keys), default=0) + 1
    for k in all_keys:
        a_val = a_st.get(k)
        b_val = b_st.get(k)
        name = k.ljust(key_col_width)
        if _tensors_identical(a_val, b_val):
            print(f"{name} \033[37mIdentical\033[0m")
            continue
        diff = (
            f"\033[34m{summarize_tensor(a_val).ljust(32)} \033[30m->\033[0m "
            f"\033[36m{summarize_tensor(b_val).ljust(32)}\033[0m"
        )
        print(f"{name} {diff} {_format_net_change(a_val, b_val)}")
if __name__ == "__main__":
    import sys

    # Command-line entry point: expects exactly two safetensors paths.
    if len(sys.argv[1:]) != 2:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        # Unlike assert, this check is not stripped under `python -O`.
        sys.exit(f"Try: {sys.argv[0]} a.safetensors b.safetensors")
    main(*sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment