Created
April 18, 2023 14:37
-
-
Save SlyEcho/477554916bfc1a9e338240eee6396fbd to your computer and use it in GitHub Desktop.
Convert gpt4all-alpaca-oa-codealpaca-lora-13b to a HF checkpoint
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Combined https://github.com/tloen/alpaca-lora/blob/main/export_hf_checkpoint.py and
# https://huggingface.co/jordiclive/gpt4all-alpaca-oa-codealpaca-lora-13b#example-inference-code-note-several-embeddings-need-to-be-loaded-along-with-the-lora-weights-assumes-on-gpu-and-torchfloat16
import os
from typing import List, NamedTuple

import torch
import transformers
from huggingface_hub import hf_hub_download
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
# Convert the gpt4all-alpaca-oa-codealpaca LoRA adapter into a standalone
# Hugging Face checkpoint: load the base model, attach the adapter, copy in the
# extra special-token embeddings, merge the adapter into the base weights and
# write the result to ./hf_ckpt.

# Device used for every load below (model, adapter and extra embeddings).
device = "cpu"

# Hub repository that provides the tokenizer, the LoRA adapter and the
# extra_embeddings.pt file.
lora_repo = "jordiclive/gpt4all-alpaca-oa-codealpaca-lora-13b"

# Embedding rows [first_extra_row, resized_vocab_size) are overwritten with the
# special-token embeddings shipped in the adapter repository.
first_extra_row = 32000
resized_vocab_size = 32005  # original had 32016 ??

# NOTE(review): the tokenizer is loaded here but is neither used nor saved by
# the rest of this script.
tokenizer = LlamaTokenizer.from_pretrained(lora_repo)

base_model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-13b-hf",
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": device},
)
# This model repo also contains several embeddings for special tokens that
# need to be loaded, so the embedding matrix is resized to make room for them.
base_model.resize_token_embeddings(resized_vocab_size)

# Keep a reference to one base weight plus a snapshot of its current values so
# that the effect of attaching / merging the adapter can be checked later.
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()

lora_model = PeftModel.from_pretrained(
    base_model,
    lora_repo,
    device_map={"": device},
    torch_dtype=torch.float16,
)

# Load embeddings for special tokens.
# SECURITY: torch.load unpickles the downloaded file, which can execute
# arbitrary code; only run this against a repository you trust.
filename = hf_hub_download(lora_repo, "extra_embeddings.pt")
embed_weights = torch.load(filename, map_location=torch.device(device))

# Add special token embeddings. The copy is an in-place write into a model
# parameter, so it is done under no_grad: autograd rejects such writes on a
# parameter that still requires grad, and the guard makes the copy valid
# regardless of how the parameter is flagged.
with torch.no_grad():
    embed_matrix = base_model.model.embed_tokens.weight
    embed_matrix[first_extra_row:resized_vocab_size, :] = embed_weights[0:5, :].to(
        embed_matrix.dtype
    )

# Attaching the adapter alone must leave the base weight untouched.
# Explicit raises are used instead of assert so the checks survive `python -O`.
if not torch.allclose(first_weight_old, first_weight):
    raise RuntimeError("base weights changed before the LoRA adapter was merged")

# merge weights - new merging method from peft
# pip install git+https://github.com/huggingface/peft
lora_model = lora_model.merge_and_unload()
lora_model.train(False)

# did we do anything? After merging, the tracked weight must differ from the
# snapshot taken before the adapter was attached.
if torch.allclose(first_weight_old, first_weight):
    raise RuntimeError("merging the LoRA adapter did not change the base weights")

# Drop any remaining adapter tensors and strip the PEFT wrapper prefix from the
# parameter names so the keys match a plain LlamaForCausalLM state dict.
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
    k.replace("base_model.model.", ""): v
    for k, v in lora_model_sd.items()
    if "lora" not in k
}

base_model.save_pretrained(
    "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="1000MB"
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
huggingface-hub==0.13.4
peft @ git+https://github.com/huggingface/peft@0bdb54f03f651dab7818056f876c6629ef58f568
torch==2.0.0
transformers==4.28.1
typing_extensions==4.5.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment