@zphang
Created July 7, 2022 23:29
GPT-NeoX-20B HF Conversion
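The snippet below merges a tensor-parallel (TP=2) GPT-NeoX-20B checkpoint, stored as per-layer files of the form layer_XX-model_YY-model_states.pt produced by the GPT-NeoX training library, into a single Hugging Face GPTNeoXForCausalLM and saves it with save_pretrained. The checkpoint and output paths are placeholders, and the import locations shown are an assumption for a transformers version that ships GPT-NeoX support.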
import os

import torch
from tqdm.auto import trange
# Assumption: GPT-NeoX support is available under transformers.models.gpt_neox.
from transformers.models.gpt_neox import configuration_gpt_neox, modeling_gpt_neox

# The default GPTNeoXConfig corresponds to the 20B architecture
# (44 layers, hidden size 6144, 64 attention heads).
config = configuration_gpt_neox.GPTNeoXConfig()
hf_model = modeling_gpt_neox.GPTNeoXForCausalLM(config).half().cuda()

checkpoint_path = "/path/to/global_step150000"

# Merge the vocab-parallel input embeddings (split along the vocab dimension).
loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_00-model_00-model_states.pt"))
loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_00-model_01-model_states.pt"))
hf_model.gpt_neox.embed_in.load_state_dict({"weight": torch.cat([
    loaded_tp1["word_embeddings.weight"],
    loaded_tp2["word_embeddings.weight"],
], dim=0)})
# Transformer layer i is stored in the files for layer_{i + 2:02d}
# (layer_00 holds the input embeddings).
for layer_i in trange(44):
    layer = hf_model.gpt_neox.layers[layer_i]
    filename_tp1 = f"layer_{layer_i + 2:02d}-model_00-model_states.pt"
    filename_tp2 = f"layer_{layer_i + 2:02d}-model_01-model_states.pt"
    loaded_tp1 = torch.load(os.path.join(checkpoint_path, filename_tp1))
    loaded_tp2 = torch.load(os.path.join(checkpoint_path, filename_tp2))
    state_dict = {}
    # Row-parallel weights: concatenate along the input dimension.
    for key in [
        "attention.dense.weight",
        "mlp.dense_4h_to_h.weight",
    ]:
        state_dict[key] = torch.cat([loaded_tp1[key], loaded_tp2[key]], dim=1)
    # LayerNorm parameters are replicated across TP ranks; average the copies.
    state_dict["input_layernorm.weight"] = (
        loaded_tp1["input_layernorm.weight"] + loaded_tp2["input_layernorm.weight"]) / 2
    state_dict["input_layernorm.bias"] = (
        loaded_tp1["input_layernorm.bias"] + loaded_tp2["input_layernorm.bias"]) / 2
    state_dict["post_attention_layernorm.weight"] = (
        loaded_tp1["post_attention_layernorm.weight"] + loaded_tp2["post_attention_layernorm.weight"]) / 2
    state_dict["post_attention_layernorm.bias"] = (
        loaded_tp1["post_attention_layernorm.bias"] + loaded_tp2["post_attention_layernorm.bias"]) / 2
    # LinearWithTPMerge: column-parallel weights/biases, concatenate along the output dimension.
    state_dict["mlp.dense_h_to_4h.weight"] = torch.cat([
        loaded_tp1["mlp.dense_h_to_4h.weight"],
        loaded_tp2["mlp.dense_h_to_4h.weight"],
    ], dim=0)
    state_dict["mlp.dense_h_to_4h.bias"] = torch.cat([
        loaded_tp1["mlp.dense_h_to_4h.bias"],
        loaded_tp2["mlp.dense_h_to_4h.bias"],
    ], dim=0)
    state_dict["attention.query_key_value.weight"] = torch.cat([
        loaded_tp1["attention.query_key_value.weight"],
        loaded_tp2["attention.query_key_value.weight"],
    ], dim=0)
    state_dict["attention.query_key_value.bias"] = torch.cat([
        loaded_tp1["attention.query_key_value.bias"],
        loaded_tp2["attention.query_key_value.bias"],
    ], dim=0)
    # LinearWithTPSplitBias: the bias is split across TP ranks, so sum the shards.
    state_dict["mlp.dense_4h_to_h.bias"] = (
        loaded_tp1["mlp.dense_4h_to_h.bias"]
        + loaded_tp2["mlp.dense_4h_to_h.bias"]
    )
    state_dict["attention.dense.bias"] = (
        loaded_tp1["attention.dense.bias"]
        + loaded_tp2["attention.dense.bias"]
    )
    # Identical across ranks: just take one copy.
    state_dict["attention.rotary_emb.inv_freq"] = loaded_tp1["attention.rotary_emb.inv_freq"]
    # Causal-mask buffers are not in the NeoX checkpoint; reuse the freshly initialized ones.
    state_dict["attention.bias"] = layer.state_dict()["attention.bias"]
    state_dict["attention.masked_bias"] = layer.state_dict()["attention.masked_bias"]
    layer.load_state_dict(state_dict)
# Load final layer norm
loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_47-model_00-model_states.pt"))
loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_47-model_01-model_states.pt"))
hf_model.gpt_neox.final_layer_norm.load_state_dict({
    "weight": (loaded_tp1["norm.weight"] + loaded_tp2["norm.weight"]) / 2,
    "bias": (loaded_tp1["norm.bias"] + loaded_tp2["norm.bias"]) / 2,
})
del loaded_tp1
del loaded_tp2
# Load output embedding
loaded_tp1 = torch.load(os.path.join(checkpoint_path, "layer_48-model_00-model_states.pt"))
loaded_tp2 = torch.load(os.path.join(checkpoint_path, "layer_48-model_01-model_states.pt"))
hf_model.embed_out.load_state_dict({
    "weight": torch.cat([
        loaded_tp1["final_linear.weight"],
        loaded_tp2["final_linear.weight"],
    ], dim=0),
})
del loaded_tp1
del loaded_tp2
# Save in Hugging Face format, sharded into ~1GB files.
hf_model.save_pretrained(
    "/path/to/neox20b",
    max_shard_size="1GB",
)
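
As a quick sanity check, the saved directory can be reloaded through the standard from_pretrained API. This is a minimal sketch, assuming the placeholder output path above and a GPU with enough memory for the fp16 weights; the random input IDs are only there to confirm the forward pass runs.

# Sanity-check sketch: reload the converted checkpoint and run a dummy forward pass.
import torch
from transformers import GPTNeoXForCausalLM

model = GPTNeoXForCausalLM.from_pretrained(
    "/path/to/neox20b",          # placeholder output path from the conversion above
    torch_dtype=torch.float16,
).cuda()
model.eval()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16)).cuda()
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)  # expected: (1, 16, vocab_size)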