Skip to content

Instantly share code, notes, and snippets.

@lucataco
Created January 26, 2023 20:24
Show Gist options
  • Save lucataco/14cf7bc12a187885145ebbaa92572428 to your computer and use it in GitHub Desktop.
Safetensors speed comparison with flan-t5-large
import os
import datetime
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch

# Speed comparison: loading flan-t5-large weights from a safetensors file vs.
# the PyTorch pickle checkpoint, first on CPU, then on GPU.
sf_filename = "./model.safetensors"
pt_filename = hf_hub_download("google/flan-t5-large", filename="pytorch_model.bin")

# ---- CPU ----
start_st = datetime.datetime.now()
weights = load_file(sf_filename, device="cpu")
load_time_st = datetime.datetime.now() - start_st
print(f"Loaded safetensors {load_time_st}")
del weights  # release before the next load so both runs face the same memory pressure

start_pt = datetime.datetime.now()
# NOTE(review): torch.load unpickles arbitrary code; for files fetched from the
# Hub prefer torch.load(..., weights_only=True) (available from PyTorch 1.13).
weights = torch.load(pt_filename, map_location="cpu")
load_time_pt = datetime.datetime.now() - start_pt
print(f"Loaded pytorch {load_time_pt}")
del weights
print(f"on CPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")

# ---- GPU ----
# This is required because this feature hasn't been fully verified yet
os.environ["SAFETENSORS_FAST_GPU"] = "1"
# CUDA startup out of the measurement
torch.zeros((2, 2)).cuda()
torch.cuda.synchronize()  # make sure the warm-up has finished before the clock starts

start_st = datetime.datetime.now()
weights = load_file(sf_filename, device="cuda:0")
torch.cuda.synchronize()  # CUDA work is asynchronous; wait before stopping the clock
load_time_st = datetime.datetime.now() - start_st
print(f"Loaded safetensors {load_time_st}")
del weights

start_pt = datetime.datetime.now()
weights = torch.load(pt_filename, map_location="cuda:0")
torch.cuda.synchronize()  # same: include pending device copies in the measurement
load_time_pt = datetime.datetime.now() - start_pt
print(f"Loaded pytorch {load_time_pt}")
print(f"on GPU, safetensors is faster than pytorch by: {load_time_pt/load_time_st:.1f} X")
@lucataco
Copy link
Author

lucataco commented Jan 26, 2023

HF Documentation: https://huggingface.co/docs/safetensors/speed#gpu-benchmark
CPU: Ryzen 7 5800X
GPU: RTX 4080

Loaded safetensors 0:00:00.007597
Loaded pytorch 0:00:00.642529
on CPU, safetensors is faster than pytorch by: 84.6 X
Loaded safetensors 0:00:00.292713
Loaded pytorch 0:00:00.495908
on GPU, safetensors is faster than pytorch by: 1.7 X

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment