Load a model and pad its embedding layers to a multiple of N (e.g. 128).
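# Example invocation (a sketch; the script filename, checkpoint path, and output
# folder below are illustrative placeholders, not taken from the gist):
#
#   python pad_vocab.py ./my-checkpoint --dtype bf16 --output_folder ./my-checkpoint-padded
#
# Padding the vocab/embedding size to a multiple such as 128 is commonly done so the
# embedding and LM-head matrices align well with GPU kernels and divide evenly under
# tensor parallelism.
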
import argparse
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str, help="checkpoint path or model name")
    parser.add_argument("--dtype", type=str, default="auto", help="auto, fp16, bf16 or fp32")
    parser.add_argument("--output_folder", type=str, help="output folder path", required=True)
    parser.add_argument("--max_shard_size", type=str, default="10GB")
    parser.add_argument("--cache_dir", type=str)
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="allow custom model code (required for Falcon)",
    )
    parser.add_argument(
        "--pad_vocab_size_to_multiple_of", type=int, default=128, help="make vocab size divisible by this number"
    )
    return parser.parse_args()


@torch.no_grad()
def main():
    args = parse_args()
    print(args)

    # translate the --dtype flag into a torch dtype (None keeps from_pretrained's default)
    if args.dtype == "auto":
        torch_dtype = None
    elif args.dtype in ("float16", "fp16"):
        torch_dtype = torch.float16
    elif args.dtype in ("float32", "fp32"):
        torch_dtype = torch.float32
    elif args.dtype in ("bfloat16", "bf16"):
        torch_dtype = torch.bfloat16
    else:
        print(f"Unsupported dtype: {args.dtype}")
        sys.exit(1)
print(f"Loading tokenizer '{args.model_name}' ...")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
print(f"{type(tokenizer).__name__} (vocab_size={len(tokenizer)})")
print(f"Loading model '{args.model_name}' ({args.dtype}) ...")
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
torch_dtype=torch_dtype,
cache_dir=args.cache_dir,
trust_remote_code=args.trust_remote_code,
)
print(f"{type(model).__name__} (num_parameters={model.num_parameters()})")
print("Model architecture:")
print(model)
old_input_shape = model.get_input_embeddings().weight.shape
print("input_embeddings shape before resize:", old_input_shape)
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
old_output_shape = model.get_output_embeddings().weight.shape
print("output_embeddings shape before resize:", old_output_shape)
else:
print(f"no (separate) output embeddings: {model.config.tie_word_embeddings=}")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=args.pad_vocab_size_to_multiple_of)
input_embeddings = model.get_input_embeddings()
print("new input_embeddings shape", input_embeddings.weight.shape)
    # overwrite the newly added (padding) rows with a copy of an existing "dummy" embedding
    def fill_dummy(embeddings, old_shape: torch.Size, dummy_id: int):
        old_size = old_shape[0]
        new_size = embeddings.weight.shape[0]
        for i in range(old_size, new_size):
            print(f"overwriting embedding[{i}] with embedding[{dummy_id}]")
            embeddings.weight.data[i] = embeddings.weight.data[dummy_id]

    # compare against None explicitly: an unk_token_id of 0 is valid but falsy
    if tokenizer.unk_token_id is not None:
        fill_dummy(input_embeddings, old_input_shape, tokenizer.unk_token_id)

    output_embeddings = model.get_output_embeddings()
    if output_embeddings is not None and not model.config.tie_word_embeddings:
        print("new output_embeddings shape", output_embeddings.weight.shape)
        if tokenizer.unk_token_id is not None:
            fill_dummy(output_embeddings, old_output_shape, tokenizer.unk_token_id)
    else:
        print(f"no (separate) output embeddings: {model.config.tie_word_embeddings=}")

    # write model & tokenizer to the output dir
    if args.output_folder:
        print(f"Saving model to: {args.output_folder}")
        model.save_pretrained(args.output_folder, max_shard_size=args.max_shard_size)
        print(f"Saving tokenizer to: {args.output_folder}")
        tokenizer.save_pretrained(args.output_folder)


if __name__ == "__main__":
    main()
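
# A quick post-run sanity check (a sketch; "./my-checkpoint-padded" stands for whatever
# --output_folder was used, and 128 for --pad_vocab_size_to_multiple_of):
#
#   from transformers import AutoModelForCausalLM
#   padded = AutoModelForCausalLM.from_pretrained("./my-checkpoint-padded")
#   vocab_rows = padded.get_input_embeddings().weight.shape[0]
#   assert vocab_rows % 128 == 0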