Skip to content

Instantly share code, notes, and snippets.

@goddoe
Last active December 8, 2023 05:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save goddoe/a8a172442da7fdd4f4198ff3c3b4da90 to your computer and use it in GitHub Desktop.
Save goddoe/a8a172442da7fdd4f4198ff3c3b4da90 to your computer and use it in GitHub Desktop.
8k_to_4k.py
"""Truncate a GPT-2-style model's learned positional embeddings (e.g. 8k -> 4k).

Loads a causal LM and its tokenizer from ``input_path``, shortens the
position-embedding table (``model.transformer.wpe``) to ``new_max_length``
rows, keeps the config and tokenizer limits in sync, and saves everything
to ``output_path``.

NOTE(review): this assumes a GPT-2-style architecture that exposes its
position table as ``model.transformer.wpe`` — other architectures name it
differently (e.g. ``model.model.embed_positions``); verify before reuse.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

input_path = "./model_in"
output_path = "./model_out"
max_shard_size = "5GB"
new_max_length = 4096

print("load model...start")
model = AutoModelForCausalLM.from_pretrained(input_path)
print("load model...done")
print(model.dtype)

print("load tokenizer...start")
tokenizer = AutoTokenizer.from_pretrained(input_path)
print("load tokenizer...done")

# Keep the tokenizer's advertised limit in sync with the new model limit.
tokenizer.model_max_length = new_max_length

# New positional embedding: keep only the first `new_max_length` rows.
with torch.no_grad():
    old_embeddings = model.transformer.wpe.weight
    if old_embeddings.shape[0] < new_max_length:
        raise ValueError(
            f"model has only {old_embeddings.shape[0]} positions, "
            f"cannot truncate to {new_max_length}"
        )
    # `.clone()` matters: a bare slice is a view that keeps the entire
    # original (e.g. 8k-row) storage alive and could serialize it whole.
    new_embeddings = old_embeddings[:new_max_length, :].clone()
    model.transformer.wpe.weight = torch.nn.Parameter(new_embeddings)
    # Keep the Embedding module's metadata consistent with its new weight.
    model.transformer.wpe.num_embeddings = new_max_length

# Update config so downstream code sees the shorter context window.
model.config.n_positions = new_max_length

# Save model and tokenizer side by side so the output dir is self-contained.
model.save_pretrained(output_path, max_shard_size=max_shard_size)
tokenizer.save_pretrained(output_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment