Last active
February 25, 2026 06:29
-
-
Save romitjain/89b64e31ec646b1e022bef42741f25ca to your computer and use it in GitHub Desktop.
PEFT LoRA weight tying in modules_to_save
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import torch | |
| from tqdm import tqdm | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import LoraConfig, PeftModel, get_peft_model | |
| torch.manual_seed(42) | |
@torch.no_grad()
def print_means(m, adapter_name):
    """Print the mean of the input/output embedding weights and check tying.

    Prints the mean of ``m.get_input_embeddings().weight`` and
    ``m.get_output_embeddings().weight``. When either layer exposes a
    ``modules_to_save`` attribute, the mean of the copy stored under
    ``adapter_name`` is printed as well. Finally reports whether the two
    weights are equal in value *and* share the same storage address.

    Args:
        m: Object exposing ``get_input_embeddings()`` and
            ``get_output_embeddings()``, each returning a layer with a
            ``weight`` tensor.
        adapter_name: Key used to index ``modules_to_save`` on a wrapped layer.
    """
    emb = m.get_input_embeddings()
    lm = m.get_output_embeddings()

    print(f"Embedding layer mean: {emb.weight.mean().item():.2e}")
    print(f"LM Head layer mean: {lm.weight.mean().item():.2e}")

    # The presence of `modules_to_save` indicates the layer was added in
    # `modules_to_save`, so it carries a per-adapter copy.
    if hasattr(emb, "modules_to_save"):
        print(f"Embedding layer module wrapper mean: {emb.modules_to_save[adapter_name].weight.mean().item():.2e}")
    if hasattr(lm, "modules_to_save"):
        print(f"LM Head layer module wrapper mean: {lm.modules_to_save[adapter_name].weight.mean().item():.2e}")

    try:
        # Explicit checks instead of `assert`: assertions are stripped under
        # `python -O`, which would make this always report success.
        if not torch.allclose(emb.weight, lm.weight):
            raise ValueError("Embedding and LM layer are not equal")
        if emb.weight.data_ptr() != lm.weight.data_ptr():
            raise ValueError("Embedding and LM layer do not have the same memory address")
    except Exception as err:
        # Best-effort diagnostic: report any failure (including errors raised
        # by the comparison itself) rather than aborting the script.
        print(f"❌ {err}")
    else:
        print("✅ Embedding and LM layer are equal")
def main(args):
    """Run the weight-tying demo: load, add LoRA, train, save, reload, merge.

    After each stage the embedding / LM-head statistics are printed with
    ``print_means`` so the tying state can be compared across stages. The
    trained adapter is written to the local directory ``tmp``.

    Args:
        args: Parsed command-line namespace. ``args.modules_to_save`` and
            ``args.ensure_weight_tying`` are forwarded to ``LoraConfig``.
    """
    model_name = "trl-internal-testing/tiny-Gemma2ForCausalLM"
    adapter_name = "default"
    # Fall back to CPU so the demo still runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    save_dir = "tmp"

    print("\nStep 1: Loading a model with tied weights")
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
    tok = AutoTokenizer.from_pretrained(model_name)

    vocab = tok.vocab_size  # vocab size
    seq_len = 64  # sequence length
    bsz = 4  # batch size
    steps = 10  # total train steps

    print(f"Is weight tying enabled?: {model.config.tie_word_embeddings}")
    print_means(model, adapter_name)

    modules_to_save = args.modules_to_save
    target_modules = ["q_proj"]
    lora_cfg = LoraConfig(
        modules_to_save=modules_to_save,
        target_modules=target_modules,
        task_type="CAUSAL_LM",
        ensure_weight_tying=args.ensure_weight_tying,
    )

    print("\nStep 2: Adding LoRA adapters and trainable modules to the model")
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    print_means(model, adapter_name)

    print("\nStep 3: Training the model with some dummy data")
    # Small training run for demonstration, on random token ids.
    model.train()
    optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-4)
    for _ in tqdm(range(steps), desc="training"):
        input_ids = torch.randint(low=0, high=vocab, size=(bsz, seq_len), device=device)
        # `ones_like` already inherits the device of `input_ids`.
        attn = torch.ones_like(input_ids)
        out = model(input_ids=input_ids, attention_mask=attn, labels=input_ids)
        out.loss.backward()
        optim.step()
        optim.zero_grad()
    model.eval()

    print("\nStep 4: Post training, saving the model to a directory")
    print_means(model, adapter_name)
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    print(f"\nSaved the model to {save_dir}")

    print("\nStep 5: Loading the saved model")
    reload_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
    reload_peft = PeftModel.from_pretrained(reload_model, save_dir)
    print(f"Is weight tying enabled?: {reload_model.config.tie_word_embeddings}")
    print_means(reload_peft, adapter_name)

    print("\nStep 6: Merging the adapters")
    merged_peft = reload_peft.merge_and_unload()  # type: ignore
    merged_peft.eval()
    print_means(merged_peft, adapter_name)

    print("\nDone")
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Command-line entry point: collect the two demo options and run `main`.
    cli = ArgumentParser()
    # One or more module names, forwarded to `main` as a list.
    cli.add_argument(
        "--modules_to_save",
        nargs="+",
        type=str,
        help="List of module names to save",
    )
    # Boolean switch; False unless the flag is given.
    cli.add_argument(
        "--ensure_weight_tying",
        action="store_true",
        help="Enable weight tying",
    )
    args = cli.parse_args()
    main(args)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
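Run command (after installing PEFT from main), where `<script>.py` is this file saved locally:

If targeting only `embed_tokens`:
`python <script>.py --modules_to_save embed_tokens`

If targeting both `embed_tokens` and `lm_head`:
`python <script>.py --modules_to_save embed_tokens lm_head`

If you want to ensure weight tying, just add the `--ensure_weight_tying` flag.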