Forked by @skeeet from Narsil/pure_torch.py, created November 21, 2022.
Loading a safetensors file with pure torch only
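
For reference, the layout this script parses is: an 8-byte little-endian unsigned integer giving the length of a JSON header, the JSON header itself (mapping each tensor name to its dtype, shape, and byte offsets into the data section), and then the raw tensor bytes. A minimal sketch that only dumps that header, independent of the loader below (the file path is a placeholder):

import json
import struct

def read_header(path):
    """Return the JSON header of a safetensors file."""
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))   # header length, little-endian u64
        return json.loads(f.read(n))            # name -> {dtype, shape, data_offsets}

# Example (path is hypothetical):
# print(json.dumps(read_header("model.safetensors"), indent=2))
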
import json
import mmap
import os

import torch
from huggingface_hub import hf_hub_download

# Map safetensors dtype strings to torch dtypes (GPT-2 only uses F32).
DTYPES = {"F32": torch.float32}


def create_tensor(storage, info, offset):
    """Build one tensor as a zero-copy view over the file-backed storage."""
    dtype = DTYPES[info["dtype"]]
    shape = info["shape"]
    # Byte offsets are relative to the start of the data section.
    start, stop = info["data_offsets"]
    return (
        torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8)
        .view(dtype=dtype)
        .reshape(shape)
    )


def load_file(filename, device="cpu"):
    """Load a safetensors file using only mmap, json and torch."""
    # The file is binary: 8-byte little-endian header length, JSON header, raw tensor data.
    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            header = m.read(8)
            n = int.from_bytes(header, "little")
            metadata_bytes = m.read(n)
            metadata = json.loads(metadata_bytes)
    size = os.stat(filename).st_size
    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    offset = n + 8  # tensor data starts right after the length prefix and the JSON header
    return {
        # .to(device) is a no-op for "cpu", so CPU loading stays zero-copy.
        name: create_tensor(storage, info, offset).to(device)
        for name, info in metadata.items()
        if name != "__metadata__"
    }


def main():
    filename = hf_hub_download("gpt2", filename="model.safetensors")
    weights = load_file(filename, device="cpu")
    print(weights.keys())


if __name__ == "__main__":
    main()
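
To sanity-check the loader without downloading GPT-2, one option is to write a tiny file with the official safetensors package and read it back with load_file above. A minimal sketch, assuming safetensors is installed and using "tiny.safetensors" as an example file name:

import torch
from safetensors.torch import save_file

reference = {"weight": torch.randn(4, 3), "bias": torch.zeros(3)}
save_file(reference, "tiny.safetensors")

loaded = load_file("tiny.safetensors", "cpu")
assert all(torch.equal(reference[k], loaded[k]) for k in reference)
print("round trip OK")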