@malteos
Created August 16, 2023 12:04
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
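
# Convert a Falcon checkpoint trained with Megatron into the Hugging Face Transformers
# format. Example invocation (the script filename and all paths below are placeholders):
#
#   python megatron_falcon_to_hf.py \
#       --input_dir /path/to/megatron/checkpoint \
#       --output_dir /path/to/hf/output \
#       --input_tokenizer_path tiiuae/falcon-7b \
#       --falcon_size 7 \
#       --validate_shapes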
import sys
import argparse
import gc
import json
import os
import warnings
from tempfile import TemporaryDirectory
import torch
from transformers import AutoTokenizer
# TODO (files copied from the HF Hub): this can probably be done with AutoModel + trust_remote_code
from custom_models.falcon_7b import RWConfig, RWForCausalLM
from permute_qkv import permute_qkv
def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)
def convert_wqkv(llama_mega, layer_idx=0, n_heads=32, n_heads_kv=8):
    # The fused QKV weight is assumed to store per-head Q, K and V blocks
    # interleaved along dim 0; de-interleave them into separate projections.
    mega_qkv = llama_mega["transformer"][f'layers.{layer_idx}.attention.query_key_value.weight']
    n_hidden_per_head = mega_qkv.shape[1] // n_heads
    # mega_qkv = permute_qkv(mega_qkv, mega_qkv.shape[1], n_heads, n_heads_kv, revert=True)
    mega_qkv_chunk = torch.split(mega_qkv, n_hidden_per_head, dim=0)
    wq_proj, wk_proj, wv_proj = [], [], []
    for i, chk in enumerate(mega_qkv_chunk):
        if i % 3 == 0:
            wq_proj.append(chk)
        elif i % 3 == 1:
            wk_proj.append(chk)
        else:
            wv_proj.append(chk)
    wq_proj = torch.concat(wq_proj, dim=0)
    wk_proj = torch.concat(wk_proj, dim=0)
    wv_proj = torch.concat(wv_proj, dim=0)
    return wq_proj, wk_proj, wv_proj
def convert_ffn(llama_mega, layer_idx=0, n_dense=11008):
    mega_ffn = llama_mega["transformer"][f'layers.{layer_idx}.mlp.dense_h_to_4h.weight']
    ffn_w3, ffn_w1 = mega_ffn.split(n_dense, dim=0)
    return ffn_w1, ffn_w3
def write_model(model_path,
                input_base_path,
                num_output_shards=2,
                norm_eps=1e-05,
                falcon_size: int = 7,
                validate_shapes: bool = False):
    # Preliminaries
    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    os.makedirs(model_path, exist_ok=True)
    base = 10000.0
    with open(os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')) as f:
        iteration = f.read()
    if iteration != "release":
        iteration = f"iter_{int(iteration):07d}"
    print(f"Fetching iteration {iteration}")

    # Load weights
    loaded = torch.load(os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt'), map_location="cpu")
    args = loaded['args']
    loaded = loaded['model']['language_model']
    if 'transformer' not in loaded:  # normalize key names
        loaded["transformer"] = loaded.pop("encoder")
        for key in list(loaded["transformer"].keys()):
            loaded["transformer"][key.replace("self_attention", "attention")] = loaded["transformer"].pop(key)
        loaded["embedding"]["word_embeddings.weight"] = loaded["embedding"].pop("word_embeddings")["weight"]
        args.num_layers = args.encoder_num_layers

    # Load arguments
    n_layers = args.num_layers
    n_heads = args.num_attention_heads
    n_heads_kv = getattr(args, "num_attention_heads_kv", n_heads)
    n_dense = args.ffn_hidden_size
    n_hidden = args.hidden_size
    hidden_per_head = n_hidden // n_heads
    intermediate_size = args.ffn_hidden_size
    inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
    print('Falcon-Megatron Loaded!')
    param_count = 0
    index_dict = {"weight_map": {}}
    hf_config = RWConfig(
        vocab_size=args.padded_vocab_size,
        hidden_size=n_hidden,
        intermediate_size=intermediate_size,
        num_attention_heads=n_heads,
        num_hidden_layers=n_layers,
        rms_norm_eps=norm_eps,
        parallel_attn=True if falcon_size == 7 else False,  # TODO difference falcon7 vs 40
    )
    if validate_shapes:
        print("Initializing dummy model...")
        hf_model = RWForCausalLM(
            config=hf_config
        )
        hf_sd = hf_model.state_dict()

    """
    The HF state dict has the following format:

    transformer.word_embeddings.weight
    transformer.h.0
        - .input_layernorm
            - .weight
            - .bias
        - .self_attention
            - .query_key_value
            - .post_attention
        - .mlp.
    ...
    - transformer.ln_f.weight/bias
    - .lm_head.weight
    """
    # loaded = source weights from Megatron

    def permute(qkv_w):
        # return permute_qkv(qkv_w, dim, n_heads, n_heads_kv)
        return permute_qkv(qkv_w, n_hidden, n_heads, n_heads_kv, revert=True)
    # Start conversion
    with TemporaryDirectory() as tmp_model_path:
        print(f'Converting weights for {n_layers} layers...')
        for layer_i in range(n_layers):
            filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
            hf_prefix = f"transformer.h.{layer_i}"
            meg_prefix = f"layers.{layer_i}"
            key_mapping = {
                # mlp
                f"{hf_prefix}.mlp.dense_h_to_4h.weight": {"meg_key": f"{meg_prefix}.mlp.dense_h_to_4h.weight"},
                f"{hf_prefix}.mlp.dense_4h_to_h.weight": {"meg_key": f"{meg_prefix}.mlp.dense_4h_to_h.weight"},
                # qkv weights
                f"{hf_prefix}.self_attention.query_key_value.weight": {"meg_key": f"{meg_prefix}.attention.query_key_value.weight", "permute": True},
                # dense
                f"{hf_prefix}.self_attention.dense.weight": {"meg_key": f"{meg_prefix}.attention.dense.weight"},
            }
            # falcon-7b and falcon-40b differ in the input layernorms
            if falcon_size == 7:
                key_mapping.update({
                    f"{hf_prefix}.input_layernorm.weight": {"meg_key": f"{meg_prefix}.input_layernorm.weight"},
                    f"{hf_prefix}.input_layernorm.bias": {"meg_key": f"{meg_prefix}.input_layernorm.bias"},
                })
            else:
                key_mapping.update({
                    f"{hf_prefix}.ln_attn.weight": {"meg_key": f"{meg_prefix}.input_layernorm.weight"},
                    f"{hf_prefix}.ln_mlp.weight": {"meg_key": f"{meg_prefix}.mlp_layernorm.weight"},
                    f"{hf_prefix}.ln_attn.bias": {"meg_key": f"{meg_prefix}.input_layernorm.bias"},
                    f"{hf_prefix}.ln_mlp.bias": {"meg_key": f"{meg_prefix}.mlp_layernorm.bias"},
                })
            # convert based on mapping
            state_dict = {}
            for hf_key, mapping_item in key_mapping.items():
                meg_v = loaded["transformer"][mapping_item["meg_key"]]
                if "permute" in mapping_item:
                    meg_v = permute(meg_v)
                if validate_shapes:
                    hf_v = hf_sd[hf_key]
                    assert hf_v.shape == meg_v.shape
                state_dict[hf_key] = meg_v
            # count parameters
            for k, v in state_dict.items():
                index_dict["weight_map"][k] = filename
                param_count += v.numel()
            torch.save(state_dict, os.path.join(tmp_model_path, filename))
            print(f'Sharded file saved to {filename}')
        # Last layer + embeddings
        filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
        state_dict = {
            "transformer.word_embeddings.weight": loaded["embedding"]["word_embeddings.weight"],
            "transformer.ln_f.weight": loaded["transformer"]["final_layernorm.weight"],
            "transformer.ln_f.bias": loaded["transformer"]["final_layernorm.bias"],
        }
        if validate_shapes:
            assert hf_sd["transformer.word_embeddings.weight"].shape == loaded["embedding"]["word_embeddings.weight"].shape
            assert hf_sd["transformer.ln_f.weight"].shape == loaded["transformer"]["final_layernorm.weight"].shape
            assert hf_sd["transformer.ln_f.bias"].shape == loaded["transformer"]["final_layernorm.bias"].shape
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch_dtype = state_dict["transformer.ln_f.weight"].dtype
        torch.save(state_dict, os.path.join(tmp_model_path, filename))
        print(f'Sharded file saved to {filename}')

        # Write configs and save
        index_dict["metadata"] = {"total_size": param_count * 2}
        write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
        hf_config.save_pretrained(tmp_model_path)

        # Make space so we can load the model properly now.
        del state_dict
        del loaded
        gc.collect()

        print("Loading the checkpoint in a Falcon model...")
        model = RWForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch_dtype)
        # Avoid saving this as part of the config.
        del model.config._name_or_path

    print("Saving in the Transformers format.")
    max_num_params_per_shard = param_count * 2 // max(1, (num_output_shards - 1))
    model.save_pretrained(model_path, max_shard_size=max_num_params_per_shard)
def write_tokenizer(tokenizer_path, input_tokenizer_path):
    tokenizer = AutoTokenizer.from_pretrained(input_tokenizer_path)
    tokenizer.save_pretrained(tokenizer_path)
def main():
    # make sure megatron is importable
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of the Falcon-Megatron weights",
    )
    parser.add_argument(
        "--num_output_shards",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--falcon_size",
        type=int,
        default=7,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Location to write the HF model and tokenizer",
    )
    parser.add_argument(
        "--input_tokenizer_path",
        type=str,
        help="Location to load the tokenizer from",
    )
    parser.add_argument(
        "--validate_shapes",
        action="store_true",
        help="Check that the shapes of the HF and Megatron weights are identical (requires loading a dummy model)",
    )
    args = parser.parse_args()
    write_model(
        model_path=args.output_dir,
        input_base_path=args.input_dir,
        num_output_shards=args.num_output_shards,
        validate_shapes=args.validate_shapes,
        falcon_size=args.falcon_size,
    )
    if args.input_tokenizer_path:
        write_tokenizer(args.output_dir, args.input_tokenizer_path)
if __name__ == "__main__":
    main()
malteos commented Aug 16, 2023

Output of verify_correctness.py for a 340M-parameter model with a falcon-7b-like architecture:

Max absolute error in the logits: max=0.155545, avg=0.007444
Abs loss error: 0.003520 Our loss: 6.099, theirs: 6.103
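
As a quick sanity check of a converted checkpoint, something like the following can be used (a minimal sketch; the output directory and prompt are placeholders, and it assumes the tokenizer was exported via --input_tokenizer_path and that the same custom RW model code as in the script is importable):

import torch
from transformers import AutoTokenizer
from custom_models.falcon_7b import RWForCausalLM  # same custom code as the conversion script

model_dir = "/path/to/hf/output"  # placeholder: directory written by write_model
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = RWForCausalLM.from_pretrained(model_dir)
model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)  # (1, seq_len, vocab_size); compare against the Megatron logits as in verify_correctness.py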
