Skip to content

Instantly share code, notes, and snippets.

@dvruette
Created January 18, 2024 08:54
Show Gist options
  • Save dvruette/9858bd870d4d23f963a519aabb2e048e to your computer and use it in GitHub Desktop.
Weight decay to model initialization
import copy
import torch
import torch.nn as nn
class DecayToInit(nn.Module):
    """Parametrization that adds a trainable delta to a frozen snapshot.

    The snapshot is registered as a buffer named ``param`` so it travels
    with the module's state dict but is not a trainable parameter; the
    effective tensor exposed to the host module is ``param + delta``.
    """

    def __init__(self, param: torch.Tensor):
        super().__init__()
        # Buffer, not Parameter: the stored initial value must stay fixed
        # while only the delta is optimized.
        self.register_buffer("param", param)

    def forward(self, delta: torch.Tensor) -> torch.Tensor:
        # Effective tensor seen by the parent module.
        return self.param + delta
def add_decay_to_init(model: nn.Module):
    """Reparametrize every parameter of *model* as ``init + delta``.

    For each parameter, a frozen copy of its current value is captured and
    the parameter itself is zeroed, becoming the trainable delta.  After this
    call, the effective weight is ``init + delta`` with ``delta`` starting at
    zero, so the model's forward behavior is unchanged — but weight decay on
    the trainable tensor now pulls toward the captured init, not toward zero.
    """
    for module in list(model.modules()):
        # Snapshot the (name, param) pairs first: registering a
        # parametrization renames the underlying parameter.
        for pname, p in list(module.named_parameters(recurse=False)):
            frozen_init = p.detach().clone()
            with torch.no_grad():
                p.zero_()
            nn.utils.parametrize.register_parametrization(
                module, pname, DecayToInit(frozen_init)
            )
def merge_decay_to_init(model: nn.Module):
    """Bake all parametrizations of *model* back into plain parameters.

    Each parametrized tensor is replaced by its current effective value
    (``leave_parametrized=True``), so the model keeps its learned weights
    but no longer carries parametrization machinery.

    Fix over the original: iterate ``module.parametrizations`` keys directly
    instead of reconstructing tensor names by string-replacing the mangled
    ``parametrizations.<name>.original`` parameter names — the old approach
    depended on internal name mangling and raised ValueError for any
    parameter that was never parametrized (e.g. a plain bias).
    """
    for module in list(model.modules()):
        if nn.utils.parametrize.is_parametrized(module):
            # Snapshot the keys: removing a parametrization mutates the dict.
            for tensor_name in list(module.parametrizations):
                nn.utils.parametrize.remove_parametrizations(
                    module, tensor_name, leave_parametrized=True
                )
def train(model, seed=0, epochs=10):
    """Fit *model* on a synthetic 10-class classification task with AdamW.

    Args:
        model: classifier mapping (B, 768) float inputs to (B, 10) logits;
            its existing device is respected.
        seed: seed for the synthetic data and any stochastic ops.
        epochs: number of passes over the 4096-sample dataset.  Default 10
            preserves the original hard-coded behavior.
    """
    torch.manual_seed(seed)
    # weight_decay acts on the *trainable* tensors — with the DecayToInit
    # parametrization installed those are the deltas, not the raw weights.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-1)

    # Synthetic data: random features, random labels.
    N = 4096
    xs = torch.randn(N, 768)
    ys = torch.randint(0, 10, (N,))
    dataset = torch.utils.data.TensorDataset(xs, ys)
    dl = torch.utils.data.DataLoader(dataset, batch_size=32)

    device = next(model.parameters()).device
    # Hoisted out of the training loop: the loss module is stateless, so
    # there is no need to construct it anew each step as the original did.
    criterion = nn.CrossEntropyLoss()

    step = 0
    for epoch in range(epochs):
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            step += 1
        # NOTE(review): the source's indentation was lost; this print is
        # placed at epoch granularity, consistent with the 4-digit step
        # width — confirm against the original gist.
        print(f"step {step:4d}: {loss.item():.3f}")
# ---- script: compare training with vs. without decay-to-init ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(42)
model_init = nn.Sequential(
    nn.Linear(768, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
).to(device)

# Re-initialize uniformly in [0, 1) so the starting point is far from any
# standard init, making "distance to init" a meaningful comparison.
for p in model_init.parameters():
    torch.nn.init.uniform_(p, 0, 1)

model_a = copy.deepcopy(model_init)  # trained WITH decay-to-init
model_b = copy.deepcopy(model_init)  # trained with plain weight decay

print("=== Training w/ decay to init ======")
add_decay_to_init(model_a)
train(model_a)
merge_decay_to_init(model_a)

print("=== Training w/o decay to init =====")
train(model_b)

print("====================================")
# For each parameter tensor, report the L2 distance from the shared init.
params = zip(model_init.parameters(), model_a.parameters(), model_b.parameters())
for init, a, b in params:
    diff_a = (init - a).norm(p=2)
    diff_b = (init - b).norm(p=2)
    verdict = "a is closer" if diff_a < diff_b else "b is closer"
    print(f"{diff_a.item():.3f}", f"{diff_b.item():.3f}", verdict)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment