Skip to content

Instantly share code, notes, and snippets.

@Narsil
Created November 10, 2022 15:06
Show Gist options
  • Save Narsil/3edeec2669a5e94e4707aa0f901d2282 to your computer and use it in GitHub Desktop.
Save Narsil/3edeec2669a5e94e4707aa0f901d2282 to your computer and use it in GitHub Desktop.
Loading a safetensors file with pure torch only
import mmap
import torch
import json
import os
from huggingface_hub import hf_hub_download
def load_file(filename, device):
    """Load every tensor stored in a safetensors file using only torch.

    The layout read here is: an 8-byte little-endian integer ``n``, then
    ``n`` bytes of JSON header, then the raw tensor bytes. Every header
    entry except ``"__metadata__"`` is turned into a tensor by
    ``create_tensor``.

    Args:
        filename: Path of the safetensors file to read.
        device: Device the returned tensors are moved to (e.g. ``"cpu"``).

    Returns:
        A dict mapping each tensor name from the header to its tensor.

    Raises:
        ValueError: If the file is too short to hold the 8-byte length
            prefix, or the declared header length exceeds the file size.
    """
    size = os.stat(filename).st_size
    # Binary mode: only the file descriptor is needed for the mmap, and the
    # content is raw bytes, not text.
    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            header = m.read(8)
            if len(header) != 8:
                raise ValueError(
                    f"{filename}: file is too short to contain a header length"
                )
            n = int.from_bytes(header, "little")
            if n > size - 8:
                raise ValueError(
                    f"{filename}: header length {n} exceeds file size {size}"
                )
            metadata = json.loads(m.read(n))

    # Back the tensors by the file itself; shared=False is passed so the
    # mapping is not opened in shared mode.
    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    # "data_offsets" in the header are relative to the end of the header.
    offset = n + 8
    return {
        name: create_tensor(storage, info, offset).to(device)
        for name, info in metadata.items()
        if name != "__metadata__"
    }
# Maps the dtype codes found in a safetensors header to torch dtypes.
# create_tensor looks codes up here, so a code missing from this table
# raises KeyError there.
DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}
# Device that main() passes to load_file.
device = "cpu"
def create_tensor(storage, info, offset):
    """Build one tensor from a byte range of ``storage``.

    Args:
        storage: Storage holding the bytes of the whole file.
        info: Header entry for the tensor; its ``"dtype"``, ``"shape"`` and
            ``"data_offsets"`` keys are read.
        offset: Number of bytes to add to both data offsets to obtain
            positions inside ``storage``.

    Returns:
        The selected bytes reinterpreted as the dtype that ``DTYPES`` maps
        ``info["dtype"]`` to, reshaped to ``info["shape"]``.
    """
    target_dtype = DTYPES[info["dtype"]]
    dims = info["shape"]
    begin, end = info["data_offsets"]
    # Slice the raw bytes first, then reinterpret them as the target dtype.
    raw_bytes = torch.asarray(
        storage[begin + offset : end + offset], dtype=torch.uint8
    )
    typed = raw_bytes.view(dtype=target_dtype)
    return typed.reshape(dims)
def main():
    """Load gpt2's ``model.safetensors`` and print the tensor names.

    The file path is obtained from ``hf_hub_download``; the tensors are
    loaded with ``load_file`` onto the module-level ``device``.
    """
    checkpoint_path = hf_hub_download("gpt2", filename="model.safetensors")
    tensors = load_file(checkpoint_path, device)
    print(tensors.keys())


if __name__ == "__main__":
    main()
@Jeffwan
Copy link

Jeffwan commented Sep 8, 2023

Seems device is not being used here.

@s-smits
Copy link

s-smits commented Nov 5, 2024

When will torch get default safetensors support?

@julien-c
Copy link

julien-c commented Nov 7, 2024

You should open that request in the torch repo; I think it'd be awesome.

@s-smits
Copy link

s-smits commented Nov 7, 2024

Great, will do that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment