Skip to content

Instantly share code, notes, and snippets.

@InnovArul
Last active November 20, 2021 21:16
Show Gist options
  • Save InnovArul/ee2986feff215fed94401ec83367408b to your computer and use it in GitHub Desktop.
Save InnovArul/ee2986feff215fed94401ec83367408b to your computer and use it in GitHub Desktop.
Freeze selected weights of a linear layer with a mask, and keep weight decay from altering the frozen weights.
import torch, torch.nn as nn
import torch.optim as optim, torch.nn.functional as F
class CustomLinearNoWeightDecay(nn.Module):
    """Linear layer whose masked weights are frozen and shielded from weight decay.

    Entries where ``mask`` is 1 take their value from a ``fixed_weight``
    buffer snapshotted at construction time. Entries where ``mask`` is 0 come
    from the trainable ``weight`` parameter. The frozen values live in a
    buffer rather than in ``weight``, so an optimizer can change the masked
    entries of ``weight`` (e.g. through weight decay) without affecting the
    layer's output.

    Args:
        mask: 2-D tensor of shape ``(out_features, in_features)``. 1 marks a
            frozen weight and 0 a trainable one. Non-floating masks (bool or
            integer) are converted to the default floating dtype.

    Raises:
        TypeError: if ``mask`` is not a tensor.
        ValueError: if ``mask`` is not 2-D.
    """

    def __init__(self, mask):
        super().__init__()
        if not isinstance(mask, torch.Tensor):
            raise TypeError("mask must be a torch.Tensor")
        if mask.dim() != 2:
            raise ValueError(
                "mask must be 2-D (out_features, in_features), "
                f"got shape {tuple(mask.shape)}"
            )
        # `1 - mask` in forward() is not supported for bool tensors, so
        # promote any non-floating mask to a float dtype up front.
        if not torch.is_floating_point(mask):
            mask = mask.to(torch.get_default_dtype())
        self.register_buffer("mask", mask)
        out_channels, in_channels = mask.shape
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels))
        # Snapshot of the frozen entries, stored as a buffer (not a parameter)
        # so optimizer updates never reach it.
        fixed_weight = (mask * self.weight).detach()
        self.register_buffer("fixed_weight", fixed_weight)
        self.bias = nn.Parameter(torch.randn(out_channels))

    def forward(self, x):
        """Apply ``x @ W.T + bias``, with frozen entries of ``W`` taken from the snapshot."""
        # Masked positions read the buffer; unmasked positions read the
        # parameter, so only the unmasked entries of `weight` receive gradient.
        weight = (self.mask * self.fixed_weight) + (1 - self.mask) * self.weight
        return F.linear(x, weight, self.bias)
if __name__ == '__main__':
    # Random binary mask: 1 marks a frozen weight, 0 a trainable one.
    frozen_mask = (torch.rand(3, 4) > 0.5).float()
    print("mask", frozen_mask)
    layer = CustomLinearNoWeightDecay(frozen_mask)
    for _ in range(100):
        batch = torch.randn(10, 4)
        prediction = layer(batch)
        prediction.sum().backward()
        # Positions where the mask is 1 should always show a zero gradient.
        print(layer.weight.grad)
        layer.weight.grad = None
        input()  # wait for Enter before the next iteration
import torch, torch.nn as nn
import torch.optim as optim, torch.nn.functional as F
class CustomLinearWithWeightDecay(nn.Module):
    """Linear layer whose masked weights receive no gradient but remain exposed to decay.

    Entries where ``mask`` is 1 are read from the ``weight`` parameter through
    ``detach()``, so backpropagation leaves them a zero gradient. They are
    still read from the parameter itself, however. Any optimizer update that
    does not depend on the gradient, such as weight decay, therefore changes
    the layer's output. ``CustomLinearNoWeightDecay`` avoids this.

    Args:
        mask: 2-D tensor of shape ``(out_features, in_features)``. 1 marks a
            frozen weight and 0 a trainable one. Non-floating masks (bool or
            integer) are converted to the default floating dtype.

    Raises:
        TypeError: if ``mask`` is not a tensor.
        ValueError: if ``mask`` is not 2-D.
    """

    def __init__(self, mask):
        super().__init__()
        if not isinstance(mask, torch.Tensor):
            raise TypeError("mask must be a torch.Tensor")
        if mask.dim() != 2:
            raise ValueError(
                "mask must be 2-D (out_features, in_features), "
                f"got shape {tuple(mask.shape)}"
            )
        # `1 - mask` in forward() is not supported for bool tensors, so
        # promote any non-floating mask to a float dtype up front.
        if not torch.is_floating_point(mask):
            mask = mask.to(torch.get_default_dtype())
        self.register_buffer("mask", mask)
        out_channels, in_channels = mask.shape
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels))
        self.bias = nn.Parameter(torch.randn(out_channels))

    def forward(self, x):
        """Apply ``x @ W.T + bias``, blocking gradient flow into the masked entries of ``W``."""
        # detach() cuts the masked entries out of the autograd graph while
        # keeping their current values in the effective weight.
        weight = (self.mask * self.weight).detach() + (1 - self.mask) * self.weight
        return F.linear(x, weight, self.bias)
if __name__ == '__main__':
    # Random binary mask: 1 marks a frozen weight, 0 a trainable one.
    frozen_mask = (torch.rand(3, 4) > 0.5).float()
    print("mask", frozen_mask)
    layer = CustomLinearWithWeightDecay(frozen_mask)
    for _ in range(100):
        batch = torch.randn(10, 4)
        prediction = layer(batch)
        prediction.sum().backward()
        # Positions where the mask is 1 should always show a zero gradient.
        print(layer.weight.grad)
        layer.weight.grad = None
        input()  # wait for Enter before the next iteration
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment