Skip to content

Instantly share code, notes, and snippets.

@Raibows
Created May 13, 2024 02:05
Show Gist options
  • Save Raibows/079713a060f0c49c8f3b47c227aff722 to your computer and use it in GitHub Desktop.
Save Raibows/079713a060f0c49c8f3b47c227aff722 to your computer and use it in GitHub Desktop.
Split the merged qkv_proj LoRA weights of a Phi-3 adapter into separate q_proj/k_proj/v_proj weights; useful for enabling LoRA in vLLM.
import torch
import os
import json
from safetensors.torch import load_file, save_file
def replicate_lora_a(name: str, weight: "torch.Tensor") -> dict[str, "torch.Tensor"]:
    """Give q_proj, k_proj and v_proj each its own copy of a fused LoRA A tensor.

    The fused LoRA update is B @ A. Only B's rows are partitioned between the
    q, k and v outputs, so A is common to all three. "Splitting" A therefore
    means duplicating it under the per-projection keys.

    Args:
        name: State-dict key that contains ``qkv_proj`` exactly once (the
            two-way unpack below raises ``ValueError`` otherwise).
        weight: The fused LoRA A tensor.

    Returns:
        A dict with three entries, ordered q, k, v. The key is *name* with
        ``qkv_proj`` replaced by the projection name, and each value is an
        independent ``clone()`` of *weight*.
    """
    head, tail = name.split('qkv_proj')
    return {
        head + proj + tail: weight.clone()
        for proj in ('q_proj', 'k_proj', 'v_proj')
    }
def split_lora_b(
    name: str,
    weight: "torch.Tensor",
    sizes: "tuple[int, int, int] | None" = None,
) -> dict[str, "torch.Tensor"]:
    """Split a fused ``qkv_proj`` LoRA B tensor into q, k and v pieces along dim 0.

    Assumes dim 0 of *weight* indexes the fused projection's output features,
    stacked as q rows, then k rows, then v rows (as in PEFT's
    (out_features, r) layout for LoRA B).

    Args:
        name: State-dict key that contains ``qkv_proj`` exactly once.
        weight: The fused LoRA B tensor.
        sizes: Row counts for (q, k, v). When omitted, dim 0 is cut into three
            equal parts, which is what the original converter assumed. Pass
            explicit sizes for models whose k/v projections are narrower than
            q (grouped-query attention).

    Returns:
        A dict with three entries, ordered q, k, v. Keys are *name* with
        ``qkv_proj`` replaced by the projection name. Values are contiguous
        row slices of *weight*.

    Raises:
        ValueError: if *name* does not contain ``qkv_proj`` exactly once; if
            *sizes* is omitted and dim 0 is not divisible by 3; or if *sizes*
            does not have three entries summing to dim 0.
    """
    prefix, suffix = name.split('qkv_proj')
    rows = weight.shape[0]

    if sizes is None:
        # Without this check, split(rows // 3) would return a 4th remainder
        # chunk that zip() silently discards, yielding mis-sized q/k/v tensors.
        if rows == 0 or rows % 3 != 0:
            raise ValueError(
                f"cannot split {name!r}: dim 0 has {rows} rows, which is not a "
                f"positive multiple of 3; pass explicit q/k/v sizes"
            )
        sizes = (rows // 3,) * 3
    elif len(sizes) != 3 or sum(sizes) != rows:
        raise ValueError(
            f"cannot split {name!r}: sizes {tuple(sizes)} must be three row "
            f"counts summing to {rows}"
        )

    pieces = weight.split(list(sizes), dim=0)
    # safetensors.save_file refuses non-contiguous tensors. Row slices of a
    # contiguous tensor are already contiguous, so this is usually a no-op.
    return {
        f"{prefix}{proj}{suffix}": piece.contiguous()
        for proj, piece in zip(('q_proj', 'k_proj', 'v_proj'), pieces)
    }
def convert_qkv_lora_to_splits_vllm(
    adapter_folder_path: str,
    output_folder_path: str,
    *,
    expected_base_model: "str | None" = "microsoft/Phi-3-mini-4k-instruct",
) -> dict[str, torch.Tensor]:
    """Rewrite a LoRA adapter with a fused ``qkv_proj`` into separate q/k/v modules.

    Processing steps:
    - Read ``adapter_model.safetensors`` and ``adapter_config.json`` from
      *adapter_folder_path*.
    - Replace every ``qkv_proj`` LoRA A tensor with per-projection copies
      (``replicate_lora_a``).
    - Split every ``qkv_proj`` LoRA B tensor into q/k/v slices
      (``split_lora_b``).
    - Rewrite ``target_modules`` to list ``q_proj``/``k_proj``/``v_proj``
      instead of ``qkv_proj``.
    - Write both files to *output_folder_path*, creating it if needed.
      Other files in the adapter folder are not copied.

    Args:
        adapter_folder_path: Folder containing the original adapter files.
        output_folder_path: Folder to write the converted adapter files to.
        expected_base_model: Value that ``base_model_name_or_path`` in the
            config must equal. ``None`` disables the check. The default keeps
            the original Phi-3-mini-4k-only behaviour.

    Returns:
        The converted tensor dict that was written to disk.

    Raises:
        ValueError: if ``target_modules`` does not include ``qkv_proj``, or if
            the base model does not match *expected_base_model*.
    """
    adapter_bin_name = 'adapter_model.safetensors'
    adapter_config_name = 'adapter_config.json'

    lora = load_file(os.path.join(adapter_folder_path, adapter_bin_name))
    with open(os.path.join(adapter_folder_path, adapter_config_name), 'r', encoding='utf-8') as f:
        lora_config = json.load(f)

    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    target_modules = lora_config['target_modules']
    # target_modules may be a single string in PEFT configs (interpreted as a
    # regex). Wrap it so the check below is list membership rather than a
    # substring match. Otherwise the rebuild further down would iterate the
    # string character by character.
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    if 'qkv_proj' not in target_modules:
        raise ValueError(
            f"adapter target_modules {lora_config['target_modules']!r} does not "
            f"list 'qkv_proj'; nothing to split"
        )
    base_model = lora_config.get('base_model_name_or_path')
    if expected_base_model is not None and base_model != expected_base_model:
        raise ValueError(
            f"adapter base model is {base_model!r}, expected {expected_base_model!r}"
        )

    # Converting weights: fused qkv tensors are expanded, everything else is kept.
    # NOTE(review): qkv_proj tensors that are neither lora_A nor lora_B (e.g. a
    # DoRA magnitude vector, if present) are copied unchanged; confirm whether
    # such adapters need support.
    res = {}
    for k, v in lora.items():
        if 'qkv_proj' in k and 'lora_A' in k:
            res.update(replicate_lora_a(k, v))
        elif 'qkv_proj' in k and 'lora_B' in k:
            res.update(split_lora_b(k, v))
        else:
            res[k] = v

    # Converting config: the three split projections replace the fused one.
    lora_config['target_modules'] = ['q_proj', 'k_proj', 'v_proj'] + [
        t for t in target_modules if t != 'qkv_proj'
    ]

    # Saving
    os.makedirs(output_folder_path, exist_ok=True)
    save_file(res, os.path.join(output_folder_path, adapter_bin_name), metadata={"format": "pt"})
    with open(os.path.join(output_folder_path, adapter_config_name), 'w', encoding='utf-8') as f:
        json.dump(lora_config, f, indent=4)
    return res
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment