Skip to content

Instantly share code, notes, and snippets.

@Narsil
Created November 10, 2022 15:06
Show Gist options
  • Star 20 You must be signed in to star a gist
  • Fork 3 You must be signed in to fork a gist
  • Save Narsil/3edeec2669a5e94e4707aa0f901d2282 to your computer and use it in GitHub Desktop.
Save Narsil/3edeec2669a5e94e4707aa0f901d2282 to your computer and use it in GitHub Desktop.
Loading a safetensors file using only pure torch
import mmap
import torch
import json
import os
from huggingface_hub import hf_hub_download
def load_file(filename, device):
    """Load every tensor stored in a safetensors file using only torch.

    The file is read as: an 8-byte little-endian integer ``n``, then ``n``
    bytes of JSON header, then the raw tensor bytes. Each header entry other
    than ``"__metadata__"`` describes one tensor and is handed to
    ``create_tensor`` together with a storage backed by the file.

    Args:
        filename: Path to the safetensors file.
        device: Device the returned tensors are moved to (e.g. ``"cpu"``).
            ``None`` leaves the tensors on the device they were created on.

    Returns:
        A dict mapping each tensor name from the header to its tensor.

    Raises:
        ValueError: If the file is too short to contain the length prefix
            or the header it announces, or if the header is not valid JSON.
    """
    # The file is binary, so open it as such instead of as UTF-8 text.
    with open(filename, mode="rb") as file_obj:
        header = file_obj.read(8)
        if len(header) != 8:
            raise ValueError(
                f"{filename!r} is too short to contain a safetensors header"
            )
        n = int.from_bytes(header, "little")
        metadata_bytes = file_obj.read(n)
    if len(metadata_bytes) != n:
        raise ValueError(
            f"{filename!r} is truncated: expected a {n}-byte header, "
            f"got {len(metadata_bytes)} bytes"
        )
    metadata = json.loads(metadata_bytes)

    size = os.stat(filename).st_size
    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    # Tensor offsets in the header are relative to the end of the header,
    # i.e. after the 8-byte length prefix plus the n header bytes.
    offset = n + 8

    tensors = {}
    for name, info in metadata.items():
        if name == "__metadata__":
            continue
        tensor = create_tensor(storage, info, offset)
        # Honour the requested device; previously this argument was ignored.
        if device is not None:
            tensor = tensor.to(device)
        tensors[name] = tensor
    return tensors
# Maps the dtype codes found in a safetensors header to torch dtypes.
DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}
# Device passed to load_file by the script entry point.
device = "cpu"
def create_tensor(storage, info, offset):
    """Build one tensor from its header entry and the file-backed storage.

    Args:
        storage: Byte storage that is sliced to obtain the tensor's bytes.
        info: Header entry providing the ``"dtype"``, ``"shape"`` and
            ``"data_offsets"`` keys.
        offset: Byte count added to both values of ``data_offsets`` before
            slicing ``storage``.

    Returns:
        The sliced bytes viewed as ``DTYPES[info["dtype"]]`` and reshaped to
        ``info["shape"]``.

    Raises:
        KeyError: If ``info["dtype"]`` has no entry in ``DTYPES``.
    """
    target_dtype = DTYPES[info["dtype"]]
    dims = info["shape"]
    begin, end = info["data_offsets"]
    # Take the tensor's bytes first, then reinterpret them as the real dtype.
    raw_bytes = torch.asarray(storage[begin + offset : end + offset], dtype=torch.uint8)
    typed = raw_bytes.view(dtype=target_dtype)
    return typed.reshape(dims)
def main():
    """Download ``model.safetensors`` from the ``gpt2`` repo and print its keys."""
    checkpoint_path = hf_hub_download("gpt2", filename="model.safetensors")
    tensors = load_file(checkpoint_path, device)
    print(tensors.keys())


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
@Jeffwan
Copy link

Jeffwan commented Sep 8, 2023

It seems the `device` argument of `load_file` is not being used here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment