@zeux
Created January 18, 2024 20:14
Safetensors load/save benchmark (assumes input model is fp16 and converts to bf16)
import argparse
import json
import os
import safetensors
import safetensors.torch
import sys
import time
import torch

def fast_save_file(tensors, filename, metadata=None):
    # safetensors dtype identifiers for the tensor types we may encounter
    _TYPES = {
        torch.float32: "F32",
        torch.float16: "F16",
        torch.bfloat16: "BF16",
        getattr(torch, "float8_e5m2", None): "F8_E5M2",
        getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
        torch.int32: "I32",
        torch.int16: "I16",
        torch.int8: "I8",
        torch.uint8: "U8",
    }
    _ALIGN = 256

    # build the JSON header: dtype, shape and byte offsets for every tensor
    header = {}
    offset = 0
    if metadata:
        header["__metadata__"] = metadata
    for k, v in tensors.items():
        size = v.numel() * v.element_size()
        header[k] = { "dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size] }
        offset += size

    # pad the header with spaces so the tensor data starts at an aligned offset
    hjson = json.dumps(header).encode("utf-8")
    hjson += b" " * (-(len(hjson) + 8) % _ALIGN)

    # write the 8-byte header length, the header, then the raw tensor bytes
    with open(filename, "wb") as f:
        f.write(len(hjson).to_bytes(8, byteorder="little"))
        f.write(hjson)
        for k, v in tensors.items():
            assert v.layout == torch.strided and v.is_contiguous()
            v.view(torch.uint8).cpu().numpy().tofile(f)
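
# For reference, the file written above follows the safetensors on-disk layout:
#   [8-byte little-endian header length][JSON header, space-padded][raw tensor bytes]
# As a hypothetical example, the header for a single 2x2 bf16 tensor named "w"
# would serialize as:
#   {"w": {"dtype": "BF16", "shape": [2, 2], "data_offsets": [0, 8]}}
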
argp = argparse.ArgumentParser()
argp.add_argument("input", type=str)
argp.add_argument("output", type=str)
argp.add_argument("--fast", action="store_true")
argp.add_argument("--device", type=str, default="cpu")
args = argp.parse_args()
size = os.path.getsize(args.input)
beg = time.time()
# load tensors from the input file and convert them to bfloat16
weights = {}
with safetensors.safe_open(args.input, framework="pt", device=args.device) as f:
    for k in f.keys():
        assert k not in weights
        v = f.get_tensor(k)
        v = v.to(torch.bfloat16)
        weights[k] = v
mid = time.time()
# save tensors to disk
if args.fast:
    fast_save_file(weights, args.output)
else:
    safetensors.torch.save_file(weights, args.output)
end = time.time()
rsize = os.path.getsize(args.output)
print(f"load: {size / 1024 / 1024 / 1024:.2f} GiB, {mid - beg:.3f} sec, {size / (mid - beg) / 1e9:.2f} GB/s")
print(f"save: {rsize / 1024 / 1024 / 1024:.2f} GiB, {end - mid:.3f} sec, {rsize / (end - mid) / 1e9:.2f} GB/s")