Skip to content

Instantly share code, notes, and snippets.

@amytimed
Created March 3, 2026 00:01
Show Gist options
  • Select an option

  • Save amytimed/8acd6867c0d00ed4dcd7c3d1768678b7 to your computer and use it in GitHub Desktop.

Select an option

Save amytimed/8acd6867c0d00ed4dcd7c3d1768678b7 to your computer and use it in GitHub Desktop.
Merging script for a trained LoRA model: folds the adapter weights back into the base model and saves the result.
from unsloth import FastModel
import torch
import json
import os
import sys
from safetensors.torch import load_file
def main():
    """Merge trained LoRA adapter weights into the base model and save the result.

    Loads the base model via unsloth, manually applies the LoRA update
    ``W <- W + scaling * (B @ A)`` for every A/B adapter pair found in the
    adapter safetensors file, saves the merged model + tokenizer, then
    replaces this process with ``inference.py`` via ``os.execv``.
    """
    base_model_path = "Qwen/Qwen3.5-2B"
    lora_model_path = "./vivian_model"
    output_model_path = "vivian_merged"

    print("[1/3] Loading base model")
    model, tokenizer = FastModel.from_pretrained(
        model_name=base_model_path,
        max_seq_length=8192,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
        use_exact_model_name=True,
    )

    print("[2/3] Manually merging LoRA weights")
    # Load adapter config to get scaling
    with open(os.path.join(lora_model_path, "adapter_config.json")) as f:
        adapter_config = json.load(f)
    r = adapter_config["r"]
    lora_alpha = adapter_config["lora_alpha"]
    # Rank-stabilized LoRA (rsLoRA) adapters are trained with alpha/sqrt(r);
    # merging them with plain alpha/r would apply the wrong magnitude.
    if adapter_config.get("use_rslora"):
        scaling = lora_alpha / (r ** 0.5)
    else:
        scaling = lora_alpha / r

    # Load LoRA weights
    lora_weights = load_file(os.path.join(lora_model_path, "adapter_model.safetensors"))

    # Group into A/B pairs and merge.
    merged_count = 0
    base_state = model.state_dict()
    for key in lora_weights:
        # Only drive the loop from lora_A keys; the matching lora_B is looked up.
        if "lora_A" not in key:
            continue
        b_key = key.replace("lora_A", "lora_B")
        if b_key not in lora_weights:
            continue
        # PEFT key "base_model.model.<module>.lora_A.weight" -> model key
        # "<module>.weight".
        base_key = (key
            .replace("base_model.model.", "")
            .replace(".lora_A.weight", ".weight")
        )
        if base_key not in base_state:
            print(f" Warning: {base_key} not found in model, skipping")
            continue
        target = base_state[base_key]
        # Match the target weight's dtype/device instead of forcing bfloat16,
        # so fp32 base weights don't lose update precision.
        A = lora_weights[key].to(target.dtype).to(target.device)
        B = lora_weights[b_key].to(target.dtype).to(target.device)
        # LoRA merge: W = W + scaling * B @ A
        target.data += scaling * (B @ A)
        merged_count += 1

    if merged_count == 0:
        # Zero merges means a path/key-scheme mismatch; warn loudly instead of
        # silently saving an unmodified base model.
        print(" Warning: no LoRA pairs were merged -- check adapter paths and key names")
    print(f" Merged {merged_count} LoRA pairs (scaling={scaling})")
    # state_dict() returns references to the live parameters, which were
    # mutated in place above; this load keeps the intent explicit.
    model.load_state_dict(base_state)

    print(f"[3/3] Saving to {output_model_path}")
    os.makedirs(output_model_path, exist_ok=True)
    model.save_pretrained(output_model_path)
    tokenizer.save_pretrained(output_model_path)
    print("Done!")

    # Replace the current process with the inference script (never returns).
    python_executable = sys.executable
    args = [python_executable, "inference.py"]
    os.execv(python_executable, args)
# Script entry point: run the merge only when executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment