Skip to content

Instantly share code, notes, and snippets.

@romitjain
Last active February 25, 2026 06:29
Show Gist options
  • Select an option

  • Save romitjain/89b64e31ec646b1e022bef42741f25ca to your computer and use it in GitHub Desktop.

Select an option

Save romitjain/89b64e31ec646b1e022bef42741f25ca to your computer and use it in GitHub Desktop.
PEFT LoRA weight tying in modules_to_save
import os
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftModel, get_peft_model
# Seed torch's global RNG so the random dummy batches drawn during training are
# reproducible from run to run.
torch.manual_seed(seed=42)
@torch.no_grad()
def print_means(m, adapter_name):
    """
    Utility function to print and compare the means of the
    input/output embedding layers.

    Prints the mean of the input-embedding and output-embedding (LM head)
    weights of ``m``. If either layer carries a ``modules_to_save`` container
    (i.e. it was listed in PEFT's ``modules_to_save``), the mean of the copy
    stored under ``adapter_name`` is printed as well. Finally reports whether
    the two weights are numerically equal AND share the same storage, i.e.
    whether they are actually tied.

    Args:
        m: Model exposing ``get_input_embeddings()`` and
            ``get_output_embeddings()``.
        adapter_name: Key used to index ``modules_to_save`` on wrapped layers.
    """
    emb = m.get_input_embeddings()
    lm = m.get_output_embeddings()
    print(f"Embedding layer mean: {emb.weight.mean().item():.2e}")
    print(f"LM Head layer mean: {lm.weight.mean().item():.2e}")
    # Indicates if the embedding layer was added in `modules_to_save`
    if hasattr(emb, "modules_to_save"):
        print(
            "Embedding layer module wrapper mean: "
            f"{emb.modules_to_save[adapter_name].weight.mean().item():.2e}"
        )
    if hasattr(lm, "modules_to_save"):
        print(
            "LM Head layer module wrapper mean: "
            f"{lm.modules_to_save[adapter_name].weight.mean().item():.2e}"
        )

    # Explicit checks rather than `assert`: asserts are stripped under
    # `python -O`, which would make this report "equal" unconditionally.
    emb_w, lm_w = emb.weight, lm.weight
    try:
        # `allclose` broadcasts, so require identical shapes first; tied weights
        # must have exactly the same shape.
        same_values = emb_w.shape == lm_w.shape and torch.allclose(emb_w, lm_w)
    except RuntimeError as err:  # e.g. mismatched dtype or device
        print(f"❌ {err}")
        return
    if not same_values:
        print("❌ Embedding and LM layer are not equal")
    elif emb_w.data_ptr() != lm_w.data_ptr():
        # Equal values but separate storage: copies, not tied parameters.
        print("❌ Embedding and LM layer do not have the same memory address")
    else:
        print("✅ Embedding and LM layer are equal")
def _train_on_dummy_data(model, vocab, device, *, steps, bsz, seq_len, lr=5e-4):
    """
    Run a short AdamW loop on random token ids so the trainable weights change.

    Puts ``model`` in train mode for the loop and back in eval mode afterwards.
    Only parameters with ``requires_grad`` set are optimized.

    Args:
        model: Causal-LM model (here a PEFT-wrapped one) returning ``.loss``.
        vocab: Exclusive upper bound for the sampled token ids.
        device: Device on which the dummy batches are created.
        steps: Number of optimizer steps.
        bsz: Batch size.
        seq_len: Sequence length.
        lr: AdamW learning rate.
    """
    model.train()
    optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)
    for _ in tqdm(range(steps), desc="training"):
        input_ids = torch.randint(low=0, high=vocab, size=(bsz, seq_len), device=device)
        attn = torch.ones_like(input_ids)  # inherits device from input_ids
        # Causal-LM objective on the inputs themselves (labels == input_ids).
        out = model(input_ids=input_ids, attention_mask=attn, labels=input_ids)
        out.loss.backward()
        optim.step()
        optim.zero_grad()
    model.eval()


def main(args):
    """
    End-to-end check of embedding/LM-head weight tying under PEFT LoRA.

    Loads a tiny Gemma2 model with tied weights, wraps it with LoRA on
    ``q_proj`` plus the requested ``modules_to_save``, trains briefly on random
    data, saves the adapter, reloads it onto a fresh base model, and merges it.
    After each stage ``print_means`` reports whether the input and output
    embeddings are still tied.

    Args:
        args: Parsed CLI namespace providing ``modules_to_save`` (list of module
            names, or None) and ``ensure_weight_tying`` (bool); both are
            forwarded to ``LoraConfig``.
    """
    model_name = "trl-internal-testing/tiny-Gemma2ForCausalLM"
    adapter_name = "default"
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    save_dir = "tmp"
    seq_len = 64  # sequence length
    bsz = 4  # batch size
    steps = 10  # total train steps

    print("\nStep 1: Loading a model with tied weights")
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
    # Sample ids from the model's own embedding table. The tokenizer's
    # `vocab_size` is not guaranteed to match it and could yield
    # out-of-range ids.
    vocab = model.get_input_embeddings().num_embeddings
    print(f"Is weight tying enabled?: {model.config.tie_word_embeddings}")
    print_means(model, adapter_name)

    lora_cfg = LoraConfig(
        modules_to_save=args.modules_to_save,
        target_modules=["q_proj"],
        task_type="CAUSAL_LM",
        ensure_weight_tying=args.ensure_weight_tying,
    )

    print("\nStep 2: Adding LoRA adapters and trainable modules to the model")
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    print_means(model, adapter_name)

    print("\nStep 3: Training the model with some dummy data")
    # Small training for demonstration
    _train_on_dummy_data(model, vocab, device, steps=steps, bsz=bsz, seq_len=seq_len)

    print("\nStep 4: Post training, saving the model to a directory")
    print_means(model, adapter_name)
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    print(f"\nSaved the model to {save_dir}")

    print("\nStep 5: Loading the saved model")
    reload_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
    reload_peft = PeftModel.from_pretrained(reload_model, save_dir)
    print(f"Is weight tying enabled?: {reload_model.config.tie_word_embeddings}")
    print_means(reload_peft, adapter_name)

    print("\nStep 6: Merging the adapters")
    merged_peft = reload_peft.merge_and_unload()  # type: ignore
    merged_peft.eval()
    print_means(merged_peft, adapter_name)
    print("\nDone")
if __name__ == "__main__":
    import argparse

    # CLI: which modules PEFT should keep fully trainable, and whether PEFT
    # should keep tied embeddings tied.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--modules_to_save",
        nargs="+",
        type=str,
        help="List of module names to save",
    )
    cli.add_argument(
        "--ensure_weight_tying",
        action="store_true",
        help="Enable weight tying",
    )
    main(cli.parse_args())
@romitjain
Copy link
Copy Markdown
Author

romitjain commented Nov 14, 2025

Run commands (requires PEFT installed from its `main` branch):

To target only `embed_tokens`:

python lora_modules_to_save.py --modules_to_save embed_tokens

To target both `embed_tokens` and `lm_head`:

python lora_modules_to_save.py --modules_to_save embed_tokens lm_head

To ensure weight tying, add the `--ensure_weight_tying` flag:

python lora_modules_to_save.py --modules_to_save embed_tokens --ensure_weight_tying

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment