-
-
Save amytimed/8acd6867c0d00ed4dcd7c3d1768678b7 to your computer and use it in GitHub Desktop.
Merging script for a trained model (folds LoRA adapter weights into the base model).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from unsloth import FastModel | |
| import torch | |
| import json | |
| import os | |
| import sys | |
| from safetensors.torch import load_file | |
def main():
    """Merge LoRA adapter weights into the base model and save the result.

    Loads the base model via unsloth's FastModel, manually folds every
    lora_A/lora_B pair from the adapter's safetensors file into the matching
    base weight (W += scaling * B @ A), writes the merged model and tokenizer
    to ``output_model_path``, then replaces the current process with
    ``inference.py`` (this function never returns on success).
    """
    base_model_path = "Qwen/Qwen3.5-2B"
    lora_model_path = "./vivian_model"
    output_model_path = "vivian_merged"

    print("[1/3] Loading base model")
    model, tokenizer = FastModel.from_pretrained(
        model_name=base_model_path,
        max_seq_length=8192,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
        use_exact_model_name=True,
    )

    print("[2/3] Manually merging LoRA weights")
    # Load adapter config to get scaling (standard LoRA: alpha / r).
    with open(os.path.join(lora_model_path, "adapter_config.json")) as f:
        adapter_config = json.load(f)
    r = adapter_config["r"]
    lora_alpha = adapter_config["lora_alpha"]
    scaling = lora_alpha / r
    # NOTE(review): if the adapter was trained with use_rslora=True, PEFT
    # uses alpha / sqrt(r) instead — confirm against the training config.

    # Load LoRA weights
    lora_weights = load_file(os.path.join(lora_model_path, "adapter_model.safetensors"))

    # Group into A/B pairs and merge into the base weights in place.
    merged_count = 0
    base_state = model.state_dict()
    # Find all lora_A keys; each must have a matching lora_B partner.
    for key in list(lora_weights.keys()):
        if "lora_A" not in key:
            continue
        b_key = key.replace("lora_A", "lora_B")
        if b_key not in lora_weights:
            continue
        # Map the adapter key back to the base model's state-dict key.
        base_key = (key
            .replace("base_model.model.", "")
            .replace(".lora_A.weight", ".weight")
        )
        if base_key not in base_state:
            print(f" Warning: {base_key} not found in model, skipping")
            continue
        target = base_state[base_key]
        # Fix: cast the adapter matrices to the target weight's own dtype
        # (and device) instead of hard-coding bfloat16 — a float32 base
        # weight would otherwise have its update silently degraded to bf16
        # precision before the in-place add.
        A = lora_weights[key].to(dtype=target.dtype, device=target.device)
        B = lora_weights[b_key].to(dtype=target.dtype, device=target.device)
        # LoRA merge: W = W + scaling * B @ A
        target.data += scaling * (B @ A)
        merged_count += 1

    print(f" Merged {merged_count} LoRA pairs (scaling={scaling})")
    # The tensors were mutated in place; reloading the state dict is
    # belt-and-braces to keep any tied/cached references consistent.
    model.load_state_dict(base_state)

    print(f"[3/3] Saving to {output_model_path}")
    os.makedirs(output_model_path, exist_ok=True)
    model.save_pretrained(output_model_path)
    tokenizer.save_pretrained(output_model_path)
    print("Done!")

    # Hand off to the inference script, replacing this process entirely.
    python_executable = sys.executable
    args = [python_executable, "inference.py"]
    os.execv(python_executable, args)
| if __name__ == "__main__": | |
| main() |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.