Load a Hugging Face causal-LM checkpoint and pad its embedding layers to a multiple of N (e.g. 128). Rounding the vocab dimension up to such a multiple typically keeps the embedding and LM-head matmuls GPU-friendly.
import argparse
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str, help="checkpoint path or model name")
    parser.add_argument("--dtype", type=str, default="auto", help="auto, fp16, bf16 or fp32")
    parser.add_argument("--output_folder", type=str, help="output folder path", required=True)
    parser.add_argument("--max_shard_size", type=str, default="10GB")
    parser.add_argument("--cache_dir", type=str)
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="allow custom model code (required for Falcon)",
    )
    parser.add_argument(
        "--pad_vocab_size_to_multiple_of", type=int, default=128, help="make vocab size divisible by this number"
    )
    return parser.parse_args()


@torch.no_grad()
def main():
    args = parse_args()
    print(args)
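
    # Map the --dtype flag onto a torch dtype for from_pretrained
    # (torch_dtype=None leaves the dtype choice to the library).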
    if args.dtype == "auto":
        torch_dtype = None
    elif args.dtype in ("float16", "fp16"):
        torch_dtype = torch.float16
    elif args.dtype in ("float32", "fp32"):
        torch_dtype = torch.float32
    elif args.dtype in ("bfloat16", "bf16"):
        torch_dtype = torch.bfloat16
    else:
        print(f"Unsupported dtype: {args.dtype}")
        sys.exit(1)

    print(f"Loading tokenizer '{args.model_name}' ...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print(f"{type(tokenizer).__name__} (vocab_size={len(tokenizer)})")
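    # Note: len(tokenizer) includes added special tokens (unlike
    # tokenizer.vocab_size), so it is the row count the embeddings must cover.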

    print(f"Loading model '{args.model_name}' ({args.dtype}) ...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch_dtype,
        cache_dir=args.cache_dir,
        trust_remote_code=args.trust_remote_code,
    )
    print(f"{type(model).__name__} (num_parameters={model.num_parameters()})")
    print("Model architecture:")
    print(model)
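
    # Record the embedding shapes before resizing so the newly added
    # rows can be identified (and overwritten) afterwards.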
    old_input_shape = model.get_input_embeddings().weight.shape
    print("input_embeddings shape before resize:", old_input_shape)
    if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
        old_output_shape = model.get_output_embeddings().weight.shape
        print("output_embeddings shape before resize:", old_output_shape)
    else:
        print(f"no (separate) output embeddings: {model.config.tie_word_embeddings=}")
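
    # Round the vocab dimension of the input embeddings (and of the LM head,
    # if it is untied) up to the next multiple of --pad_vocab_size_to_multiple_of;
    # transformers fills the new rows using the model's default weight init.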
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=args.pad_vocab_size_to_multiple_of)

    input_embeddings = model.get_input_embeddings()
    print("new input_embeddings shape", input_embeddings.weight.shape)

    # Overwrite every freshly added padding row with a copy of an existing
    # "dummy" embedding so the extra rows never carry random values.
    def fill_dummy(embeddings, old_shape: torch.Size, dummy_id: int):
        old_size = old_shape[0]
        new_size = embeddings.weight.shape[0]
        for i in range(old_size, new_size):
            print(f"overwriting embedding[{i}] with embedding[{dummy_id}]")
            embeddings.weight.data[i] = embeddings.weight.data[dummy_id]

    # "is not None" matters here: unk_token_id is 0 for some tokenizers,
    # and a plain truthiness check would wrongly skip them.
    if tokenizer.unk_token_id is not None:
        fill_dummy(input_embeddings, old_input_shape, tokenizer.unk_token_id)

    output_embeddings = model.get_output_embeddings()
    if output_embeddings is not None and not model.config.tie_word_embeddings:
        print("new output_embeddings shape", output_embeddings.weight.shape)
        if tokenizer.unk_token_id is not None:
            fill_dummy(output_embeddings, old_output_shape, tokenizer.unk_token_id)
    else:
        print(f"no (separate) output embeddings: {model.config.tie_word_embeddings=}")

    # write model & tokenizer to the output dir
    if args.output_folder:
        print(f"Saving model to: {args.output_folder}")
        model.save_pretrained(args.output_folder, max_shard_size=args.max_shard_size)
        print(f"Saving tokenizer to: {args.output_folder}")
        tokenizer.save_pretrained(args.output_folder)


if __name__ == "__main__":
    main()
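
Example invocation (the script filename and output path below are placeholders, not part of the gist). Run against gpt2, whose 50257-row vocab is not a multiple of 128; the script should report a new input_embeddings shape of 50304 x 768 (50304 = 393 * 128):

    python pad_embeddings.py gpt2 --dtype fp32 --output_folder ./gpt2-padded

A quick sanity check on the saved copy, as a minimal sketch:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("./gpt2-padded")
    rows = model.get_input_embeddings().weight.shape[0]
    assert rows % 128 == 0, rows  # 50304 for gpt2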