Cutting up a llama and putting it back together
# A simple script to demonstrate the slicing and recombination of models at runtime
# inspired by mergekit
# Sadly, it doesn't work with quantized models.
#
# public domain - silphendio
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import torch
model_path = 'gpt2' # huggingface name or local folder
output_folder = 'sliced_llama'
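# repeat the middle of the network: layers 0-7 followed by layers 4-11
# (16 layers total, with layers 4-7 appearing twice)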
layer_arrangement = list(range(0,8)) + list(range(4,12))
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# rearrange layers
new_state_dict = model.state_dict().copy()
# parameter names of one transformer block, with the layer index as a placeholder
# (for gpt2 e.g. 'transformer.h.{}.attn.c_attn.weight')
layer_keys_template = [key.replace('.0.', '.{}.') for key in model.state_dict() if '.0.' in key]
for new_layer, old_layer in enumerate(layer_arrangement):
    for key in layer_keys_template:
        new_state_dict[key.format(new_layer)] = model.state_dict()[key.format(old_layer)]
new_config = model.config
new_config.n_layer = len(layer_arrangement) # for gpt2
new_config.num_hidden_layers = len(layer_arrangement) # for mistral / llama
# save the merged model
new_config.save_pretrained(output_folder)
torch.save(new_state_dict, output_folder + '/pytorch_model.bin')
# load the merged model directly from the in-memory state dict
model = AutoModelForCausalLM.from_pretrained(model_path, config=new_config, state_dict=new_state_dict)
del new_state_dict # don't need this anymore (too bad transformers couldn't reuse the memory)
##### test the merged model
prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains."
inputs = tokenizer(prompt, return_tensors="pt")
streamer = TextStreamer(tokenizer)
model.generate(**inputs, streamer=streamer, do_sample=True, max_new_tokens=250)
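# To reload the sliced model from disk in a later session (a minimal sketch;
# it assumes config and weights were saved to output_folder as above):
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# model = AutoModelForCausalLM.from_pretrained(output_folder)  # reads config.json + pytorch_model.bin
# tokenizer = AutoTokenizer.from_pretrained(model_path)        # the tokenizer itself is unchanged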