Skip to content

Instantly share code, notes, and snippets.

@romitjain
Last active February 25, 2026 07:00
Show Gist options
  • Select an option

  • Save romitjain/e6e0902942d68d6e8a97e078a4524eba to your computer and use it in GitHub Desktop.

Select an option

Save romitjain/e6e0902942d68d6e8a97e078a4524eba to your computer and use it in GitHub Desktop.
PEFT LoRA weight tying in target_modules
import os
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftModel, get_peft_model
from peft.tuners.lora import LoraLayer
# Fixed seed: makes the random token batches drawn in main() and any
# torch-based weight initialization reproducible from run to run.
_SEED = 42
torch.manual_seed(_SEED)
def _allclose(a, b, **tolerances):
    """Return True if tensors ``a`` and ``b`` have the same shape and are elementwise close.

    A ``RuntimeError`` raised by ``torch.allclose`` (e.g. on a dtype or device
    mismatch) is treated as "not close" so the caller can report it instead of
    crashing.
    """
    if a.shape != b.shape:
        return False
    try:
        return torch.allclose(a, b, **tolerances)
    except RuntimeError:
        return False


def _report(failure, success_msg):
    """Print ``✅ success_msg`` when ``failure`` is None, otherwise ``❌ failure``."""
    print(f"✅ {success_msg}" if failure is None else f"❌ {failure}")


@torch.no_grad()
def print_means(m, adapter_name):
    """
    Utility function to print and compare the means of the
    input/output embedding layers.

    Prints the mean of the embedding and LM-head weights, plus the means of
    their LoRA adapter matrices when a layer is a ``LoraLayer``. It then runs
    two checks and prints a ✅/❌ line for each. Failures are reported, never
    raised:

    1. The raw weights are numerically equal and share the same storage
       (i.e. they are tied).
    2. The effective weights (raw weight plus the ``adapter_name`` LoRA delta,
       where present) are numerically equal.

    Args:
        m: model exposing ``get_input_embeddings()`` / ``get_output_embeddings()``
            (a transformers model or a PEFT model wrapping one).
        adapter_name: name of the LoRA adapter whose matrices/delta are inspected.
    """
    emb = m.get_input_embeddings()
    lm = m.get_output_embeddings()
    emb_is_lora = isinstance(emb, LoraLayer)
    lm_is_lora = isinstance(lm, LoraLayer)

    print(f"Embedding layer mean: {emb.weight.mean().item():.2e}")
    print(f"LM Head layer mean: {lm.weight.mean().item():.2e}")

    # Adapter means, only for layers that had LoRA adapters added. Embedding
    # adapters are indexed as tensors directly; linear adapters via `.weight`.
    if emb_is_lora:
        print(f"Embedding layer Lora A adapter mean: {emb.lora_embedding_A[adapter_name].mean().item():.2e}")
        print(f"Embedding layer Lora B adapter mean: {emb.lora_embedding_B[adapter_name].mean().item():.2e}")
    if lm_is_lora:
        print(f"LM head Lora A adapter mean: {lm.lora_A[adapter_name].weight.mean().item():.2e}")  # type: ignore
        print(f"LM head Lora B adapter mean: {lm.lora_B[adapter_name].weight.mean().item():.2e}")  # type: ignore

    # Check 1: raw weights are tied (equal values AND same memory address).
    # Explicit comparisons instead of `assert`: `python -O` strips asserts,
    # which would make this report print a false ✅.
    if not _allclose(emb.weight, lm.weight):
        tie_failure = "Embedding and LM layer are not equal"
    elif emb.weight.data_ptr() != lm.weight.data_ptr():
        tie_failure = "Embedding and LM layer do not have the same memory address"
    else:
        tie_failure = None
    _report(tie_failure, "Embedding and LM layer are equal")

    # Check 2: effective weights = raw weight + this adapter's LoRA delta (if any).
    e_eff = emb.weight
    lm_eff = lm.weight
    if emb_is_lora:
        e_eff = e_eff + emb.get_delta_weight(adapter_name)  # type: ignore
    if lm_is_lora:
        lm_eff = lm_eff + lm.get_delta_weight(adapter_name)  # type: ignore
    # Looser atol: presumably to absorb float error between the two delta computations.
    if _allclose(e_eff, lm_eff, atol=1e-4):
        eff_failure = None
    else:
        eff_failure = "Embedding and LM layer effective weights are not equal"
    _report(eff_failure, "Effective weights from lora adapters are same for Embedding and the LM layer")
def main(args):
    """
    Demonstrate and check LoRA weight tying between the input and output embeddings.

    Steps:
    1. Load a tiny model with tied embeddings.
    2. Wrap it with LoRA adapters on ``args.target_modules``.
    3. Train a few steps on random token ids.
    4. Save the adapter.
    5. Reload it on a fresh base model.
    6. Merge it.

    After each step ``print_means`` reports whether the embeddings are still tied.

    Args:
        args: parsed CLI namespace with ``target_modules`` (list of module names,
            or None) and ``ensure_weight_tying`` (bool), both forwarded to
            ``LoraConfig``.
    """
    model_name = "trl-internal-testing/tiny-Gemma2ForCausalLM"
    adapter_name = "default"
    # Prefer the first GPU, but fall back to CPU so the demo also runs on CPU-only hosts.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    save_dir = "tmp"

    print("\nStep 1: Loading a model with tied weights")
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
    tok = AutoTokenizer.from_pretrained(model_name)
    # Upper bound (exclusive) for random token ids: clamp to the embedding table
    # size so every sampled id is a valid row even if the tokenizer reports more.
    vocab = min(tok.vocab_size, model.get_input_embeddings().weight.shape[0])
    seq_len = 64  # sequence length
    bsz = 4  # batch size
    steps = 10  # total train steps
    print(f"Is weight tying enabled?: {model.config.tie_word_embeddings}")
    print_means(model, adapter_name)

    lora_cfg = LoraConfig(
        target_modules=args.target_modules,
        task_type="CAUSAL_LM",
        ensure_weight_tying=args.ensure_weight_tying,
    )
    print("\nStep 2: Adding LoRA adapters and trainable modules to the model")
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    print_means(model, adapter_name)

    print("\nStep 3: Training the model with some dummy data")
    # Small training run for demonstration: random ids serve as inputs and labels.
    model.train()
    optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-4)
    for _ in tqdm(range(steps), desc="training"):
        input_ids = torch.randint(low=0, high=vocab, size=(bsz, seq_len), device=device)
        attn = torch.ones_like(input_ids)  # inherits input_ids' device
        out = model(input_ids=input_ids, attention_mask=attn, labels=input_ids)
        out.loss.backward()
        optim.step()
        optim.zero_grad()
    model.eval()

    print("\nStep 4: Post training, saving the model to a directory")
    print_means(model, adapter_name)
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    print(f"\nSaved the model to {save_dir}")

    print("\nStep 5: Loading the saved model")
    reload_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device)
    reload_peft = PeftModel.from_pretrained(reload_model, save_dir)
    print(f"Is weight tying enabled?: {reload_model.config.tie_word_embeddings}")
    print_means(reload_peft, adapter_name)

    print("\nStep 6: Merging the adapters")
    merged_peft = reload_peft.merge_and_unload()  # type: ignore
    merged_peft.eval()
    print_means(merged_peft, adapter_name)
    print("\nDone")
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument("--target_modules", nargs="+", type=str, help="List of module names to save")
parser.add_argument("--ensure_weight_tying", action="store_true", help="Enable weight tying")
args = parser.parse_args()
main(args)
@romitjain
Copy link
Copy Markdown
Author

Run commands (after installing PEFT from its `main` branch):

To target only `embed_tokens`:

python lora_target_modules.py --target_modules embed_tokens

To target both `embed_tokens` and `lm_head`:

python lora_target_modules.py --target_modules embed_tokens lm_head

To ensure weight tying, add the `--ensure_weight_tying` flag:

python lora_target_modules.py --target_modules embed_tokens --ensure_weight_tying

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment