Skip to content

Instantly share code, notes, and snippets.

@Quentin-Anthony
Created May 27, 2023 20:50
Show Gist options
  • Save Quentin-Anthony/1648b14e1ea798e7901da985760ca53f to your computer and use it in GitHub Desktop.
Save Quentin-Anthony/1648b14e1ea798e7901da985760ca53f to your computer and use it in GitHub Desktop.
Compares numpy, native torch, and safetensors formats for tensor save/load speed.
import torch
from safetensors.torch import save_file, load_file
import numpy as np
import argparse
import os
import time
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--no-save", action="store_false", help="disables saving initial tensors")
parser.add_argument("--path", type=str, default=os.getcwd(), help="where to save/load tensors")
parser.add_argument("--no-safetensors", action="store_false", help="disables saving/loading safetensors format")
parser.add_argument("--no-numpy", action="store_false", help="disables saving/loading numpy format")
parser.add_argument("--no-torch", action="store_false", help="disables saving/loading torch format")
parser.add_argument("--size", type=int, default=int(1000000000), help="number of tensor elements to save/load. Defaults to 1B.")
parser.add_argument("--dtype", type=str, default='float', help="torch tensor dtype to save/load. Defaults to float")
parser.add_argument("--no-load", action="store_false", help="whether to load tensors")
args = parser.parse_args()
base_path = os.path.join(args.path, f'tensor_{args.size}_{args.dtype}')
if args.no_save:
t = torch.ones(args.size, dtype=getattr(torch, args.dtype), device='cuda')
if args.no_numpy:
t1 = time.time()
np.save(base_path, t.detach().cpu().numpy())
print(f'time to save numpy: {time.time() - t1}')
if args.no_torch:
t1 = time.time()
torch.save(t, f'{base_path}.pt')
print(f'time to save torch: {time.time() - t1}')
if args.no_safetensors:
t1 = time.time()
save_file({'t': t}, f'{base_path}.safetensors')
print(f'time to save safetensors: {time.time() - t1}')
if args.no_load:
if args.no_numpy:
t1 = time.time()
a = np.load(f'{base_path}.npy')
t = torch.from_numpy(a).to('cuda')
print(f'time to load numpy: {time.time() - t1}')
if args.no_torch:
t1 = time.time()
t = torch.load(f'{base_path}.pt', map_location=torch.device('cuda'))
print(f'time to load torch: {time.time() - t1}')
if args.no_safetensors:
os.environ["SAFETENSORS_FAST_GPU"] = "1"
t1 = time.time()
t = load_file(f'{base_path}.safetensors', device='cuda')
print(f'time to load safetensors: {time.time() - t1}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment