@catid
Created February 21, 2024 16:24
Never Forget
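# Toy demonstration of an EWC-style (elastic weight consolidation) quadratic penalty
# against catastrophic forgetting: train a small MLP on a random dataset 0, snapshot
# its weights and a normalized diagonal Fisher information estimate (squared gradients),
# then train on a random dataset 1 while penalizing movement of the weights that
# mattered for dataset 0, and compare the final loss on both datasets.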
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
#torch.autograd.set_detect_anomaly(True)
class FeedForward(torch.nn.Module):
    def __init__(self, input_features, output_features):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        hidden_features = input_features * 4
        self.proj_in = nn.Linear(input_features, hidden_features)
        self.act = nn.GELU()
        self.proj_out = nn.Linear(hidden_features, output_features)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.act(x)
        x = self.proj_out(x)
        return x
if __name__ == "__main__":
    torch.manual_seed(5)

    input_features, output_features = 3, 2
    batch_size = 4

    model = FeedForward(input_features, output_features)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
    criterion = torch.nn.L1Loss()

    x0 = torch.randn(batch_size, input_features)
    y0 = torch.randn(batch_size, output_features)

    model.train()
    for epoch in range(1000):
        optimizer.zero_grad()
        output = model(x0)
        loss = criterion(output, y0)
        loss.backward()
        #print(f"model.weight.grad = {model.weight.grad}")
        optimizer.step()
    # Store model weights from first training run, and calculate normalized FIM coefficients
    original_params = {}
    fims = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            original_params[name] = param.data.clone().detach()
            fim = param.grad ** 2
            fim_norm = torch.norm(fim, p=1)
            if fim_norm > 0:  # Avoid division by zero
                normalized_fim = fim / fim_norm
            else:
                normalized_fim = fim
            fims[name] = normalized_fim
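    # These coefficients drive an EWC-style quadratic penalty added to the loss in the
    # second training run below:
    #     lambda * sum_i F_i * (theta_i - theta0_i)^2
    # where theta0 are the weights stored above, F_i are the squared dataset-0 gradients
    # normalized to unit L1 norm per tensor, and lambda = 1.0.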
    loss = criterion(model(x0), y0)
    print(f"dataset0: epoch {epoch} loss = {loss.item()}")

    x1 = torch.randn(batch_size, input_features)
    y1 = torch.randn(batch_size, output_features)

    model.train()
    for epoch in range(1000):
        optimizer.zero_grad()
        output = model(x1)
        loss = criterion(output, y1)
        preloss = loss.clone()

        for name, param in model.named_parameters():
            # Compute penalty as the sum of FIM values times squared parameter changes.
            # Use param rather than param.data here so the penalty stays in the autograd
            # graph and actually contributes gradients.
            fim_contribution = fims[name] * (param - original_params[name]).pow(2)
            #fim_contribution = (param - original_params[name]).pow(2)
            loss = loss + 1.0 * fim_contribution.sum()
        #print(f"loss delta={(loss-preloss)*100/preloss}%")

        loss.backward()

        if epoch == 10:
            for name, param in model.named_parameters():
                if param.requires_grad:
                    print(f"param {name}: param.grad={param.grad}")

        optimizer.step()
    loss = criterion(model(x1), y1)
    print(f"dataset1: epoch {epoch} loss = {loss.item()}")
    loss = criterion(model(x0), y0)
    print(f"dataset0: epoch {epoch} loss = {loss.item()} (forgotten)")