Skip to content

Instantly share code, notes, and snippets.

@ehartford
Created January 1, 2024 07:09
Show Gist options
  • Save ehartford/5d8452c1f2e8395398e86106388660df to your computer and use it in GitHub Desktop.
Save ehartford/5d8452c1f2e8395398e86106388660df to your computer and use it in GitHub Desktop.
Convert yayi2-30b to the Llama format. All credit goes to Charles Goddard and Weyaxi.
import copy
import os
import safetensors.torch
import glob
import json
def transform_st(path: str, out_dir: str):
    """Rewrite one safetensors shard with renamed layer-norm keys.

    Loads the tensors stored at ``path``, renames every key containing
    ``.ln1.`` to use ``.input_layernorm.`` and every key containing ``.ln2.``
    to use ``.post_attention_layernorm.``, then saves the result into
    ``out_dir`` under the same file name as the input.

    Args:
        path: Location of the safetensors file to read.
        out_dir: Directory that receives the rewritten file.
    """
    tensors = safetensors.torch.load_file(path)

    # Walk a snapshot of the names because the mapping is mutated in the loop.
    for name in tuple(tensors):
        renamed = name.replace(".ln1.", ".input_layernorm.").replace(
            ".ln2.", ".post_attention_layernorm."
        )
        if renamed == name:
            continue
        tensors[renamed] = tensors[name]
        del tensors[name]

    destination = os.path.join(out_dir, os.path.basename(path))
    safetensors.torch.save_file(tensors, destination, metadata={"format": "pt"})
def process_model(path: str, out_path: str):
    """Convert every shard and the weight index of a sharded checkpoint.

    Each ``model-*.safetensors`` file found in ``path`` is passed to
    ``transform_st`` and written to ``out_path``. The
    ``model.safetensors.index.json`` file is then rewritten with the same key
    renames (``.ln1.`` -> ``.input_layernorm.``, ``.ln2.`` ->
    ``.post_attention_layernorm.``) so that it keeps pointing each weight at
    its shard.

    Args:
        path: Directory holding the source shards and index file.
        out_path: Directory that receives the converted files; created if it
            does not exist yet.

    Raises:
        FileNotFoundError: If the index file is missing from ``path``.
        KeyError: If the index lacks a ``metadata`` or ``weight_map`` entry.
    """
    index_name = "model.safetensors.index.json"

    # Make sure there is somewhere to write before any shard is saved.
    os.makedirs(out_path, exist_ok=True)

    for p in glob.glob(os.path.join(path, "model-*.safetensors")):
        transform_st(p, out_path)

    # The mode and encoding belong to open(), not to os.path.join(); passing
    # them to join() produced a bogus path on read and a TypeError on write.
    with open(os.path.join(path, index_name), "r", encoding="utf-8") as fd:
        index_data = json.load(fd)

    new_index = {"metadata": copy.copy(index_data["metadata"]), "weight_map": {}}
    for key, shard in index_data["weight_map"].items():
        new_key = key.replace(".ln1.", ".input_layernorm.").replace(
            ".ln2.", ".post_attention_layernorm."
        )
        new_index["weight_map"][new_key] = shard

    with open(os.path.join(out_path, index_name), "w", encoding="utf-8") as fd:
        json.dump(new_index, fd)
# Run the conversion: read the yayi2 checkpoint and write the renamed copy.
process_model(
    path="/workspace/models/yayi2/",
    out_path="/workspace/yayi2-30b-llama/",
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment