# Source: GitHub gist by @Narsil, created December 20, 2022.
# Gist ID: d5b0d747e5c8c299eb6d82709e480e3d
# Converts the GPT-2 Flax checkpoint (flax_model.msgpack) into model.safetensors.
from collections.abc import Mapping

import numpy as np
from flax.serialization import msgpack_restore
from huggingface_hub import hf_hub_download
from safetensors.flax import save_file
# Download GPT-2's Flax checkpoint from the Hugging Face Hub; the returned
# value is the local path of the downloaded file, opened just below.
filename = hf_hub_download(repo_id="gpt2", filename="flax_model.msgpack")

# Read the raw msgpack bytes; the handle is closed before decoding starts.
with open(filename, mode="rb") as f:
    data = f.read()

# Decode the msgpack payload into a nested mapping of parameters.
flax_weights = msgpack_restore(data)
def flatten(weights, prefix="", *, verbose=True):
    """Flatten a nested mapping of arrays into a single-level dict.

    Nested keys are joined with ``"."``, so ``{"a": {"b": arr}}`` becomes
    ``{"a.b": arr}``, the flat ``name -> array`` layout passed to
    ``save_file`` below.

    Args:
        weights: Nested mapping whose leaves are ``np.ndarray`` values. Any
            ``collections.abc.Mapping`` is traversed, not only ``dict``, so
            immutable mapping wrappers are flattened instead of being
            silently dropped.
        prefix: Dotted path prepended to every key produced from ``weights``;
            empty at the top level.
        verbose: When true (the default, matching the original behavior),
            print every dotted key as it is visited, including keys of
            intermediate mappings.

    Returns:
        A new ``dict`` mapping dotted key paths to the ``np.ndarray`` leaves.
        Leaves that are neither mappings nor ``np.ndarray`` are skipped.
    """
    values = {}
    for k, v in weights.items():
        # Top-level keys carry no leading dot.
        newprefix = f"{prefix}.{k}" if prefix else f"{k}"
        if verbose:
            print(newprefix)
        # Mapping (not just dict) so non-dict mapping types are recursed into.
        if isinstance(v, Mapping):
            values.update(flatten(v, prefix=newprefix, verbose=verbose))
        elif isinstance(v, np.ndarray):
            values[newprefix] = v
    return values
# Collapse the nested parameter tree into dotted-name -> array pairs and
# write them to model.safetensors (a path relative to the working directory).
weights = flatten(flax_weights)
save_file(weights, filename="model.safetensors")
# Comments and discussion for this script live on the original GitHub gist (sign-in required).