@rkaplan
Created July 18, 2018 05:02
import torch


class DynamicNet(torch.nn.Module):
    """A net with a shared backbone and two interchangeable output heads."""

    def __init__(self, D_in, H, D_out):
        super(DynamicNet, self).__init__()
        self.backbone = torch.nn.Linear(D_in, H)
        self.head1 = torch.nn.Linear(H, D_out)
        self.head2 = torch.nn.Linear(H, D_out)

    def forward(self, x, use_head1=True):
        h = self.backbone(x).clamp(min=0)  # ReLU via clamp
        if use_head1:
            return self.head1(h)
        else:
            return self.head2(h)


N, D_in, H, D_out = 8, 10, 30, 2
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = DynamicNet(D_in, H, D_out)
criterion = torch.nn.MSELoss(reduction='sum')  # size_average=False is deprecated
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

# If True, head1's gradients are set to None (rather than left at zero)
# on every iteration after the first.
MAKE_ZERO_GRADS_NONE = True

for t in range(5):
    # head1 is used only on the first iteration; head2 on all later ones,
    # so head1 receives a zero (or no) gradient for t > 0.
    y_pred = model(x, use_head1=(t == 0))
    loss = criterion(y_pred, y)
    print('Iter {}:'.format(t), loss.item())

    optimizer.zero_grad()
    loss.backward()

    print('Grads norms:')
    for name, module in zip(['backbone', 'head1', 'head2'],
                            [model.backbone, model.head1, model.head2]):
        if MAKE_ZERO_GRADS_NONE and name == 'head1' and t > 0:
            module.weight.grad = None
            module.bias.grad = None
        print(name, module.weight.grad.norm() if module.weight.grad is not None else 'None')

    # Snapshot the head weights to measure how far optimizer.step() moves them.
    temp_weights_h1 = model.head1.weight.detach().clone()
    temp_weights_h2 = model.head2.weight.detach().clone()
    optimizer.step()
    print('Norm of the head1 update:', (model.head1.weight - temp_weights_h1).norm())
    print('Norm of the head2 update:', (model.head2.weight - temp_weights_h2).norm())
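
Aside: newer PyTorch releases (1.7 and later) build this behavior directly into the optimizer, so the manual grad = None assignments above can be replaced by a single call. A minimal sketch, assuming a recent PyTorch version:

optimizer.zero_grad(set_to_none=True)  # grads become None rather than zero tensors
loss.backward()
optimizer.step()  # parameters whose grad is None are skipped

(Since PyTorch 2.0, set_to_none=True is the default for zero_grad().)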
rkaplan commented Jul 18, 2018

Output with MAKE_ZERO_GRADS_NONE = True:

Iter 0: 23.538616180419922
Grads norms:
backbone tensor(13.9520)
head1 tensor(25.7743)
head2 None
Norm of the head1 update: tensor(1.00000e-03 *
       2.5774)
Norm of the head2 update: tensor(0.)
Iter 1: 21.148746490478516
Grads norms:
backbone tensor(11.9030)
head1 None
head2 tensor(23.8908)
Norm of the head1 update: tensor(0.)
Norm of the head2 update: tensor(1.00000e-03 *
       2.3891)
Iter 2: 21.064136505126953
Grads norms:
backbone tensor(11.8662)
head1 None
head2 tensor(23.7841)
Norm of the head1 update: tensor(0.)
Norm of the head2 update: tensor(1.00000e-03 *
       4.5286)
Iter 3: 20.904308319091797
Grads norms:
backbone tensor(11.7976)
head1 None
head2 tensor(23.5923)
Norm of the head1 update: tensor(0.)
Norm of the head2 update: tensor(1.00000e-03 *
       6.4349)
Iter 4: 20.679258346557617
Grads norms:
backbone tensor(11.7025)
head1 None
head2 tensor(23.3263)
Norm of the head1 update: tensor(0.)
Norm of the head2 update: tensor(1.00000e-03 *
       8.1239)
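
With the gradients set to None, optimizer.step() skips head1 entirely: its update norm is exactly zero on every iteration after the first, even though momentum SGD is in use.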

Output with MAKE_ZERO_GRADS_NONE = False:

Grads norms:
backbone tensor(9.3807)
head1 tensor(18.8098)
head2 None
Norm of the head1 update: tensor(1.00000e-03 *
       1.8810)
Norm of the head2 update: tensor(0.)
Iter 1: 19.70391845703125
Grads norms:
backbone tensor(12.5777)
head1 tensor(0.)
head2 tensor(29.5197)
Norm of the head1 update: tensor(1.00000e-03 *
       1.6929)
Norm of the head2 update: tensor(1.00000e-03 *
       2.9520)
Iter 2: 19.580974578857422
Grads norms:
backbone tensor(12.4958)
head1 tensor(0.)
head2 tensor(29.3514)
Norm of the head1 update: tensor(1.00000e-03 *
       1.5236)
Norm of the head2 update: tensor(1.00000e-03 *
       5.5919)
Iter 3: 19.343791961669922
Grads norms:
backbone tensor(12.3394)
head1 tensor(0.)
head2 tensor(29.0215)
Norm of the head1 update: tensor(1.00000e-03 *
       1.3712)
Norm of the head2 update: tensor(1.00000e-03 *
       7.9349)
Iter 4: 19.009075164794922
Grads norms:
backbone tensor(12.1208)
head1 tensor(0.)
head2 tensor(28.5505)
Norm of the head1 update: tensor(1.00000e-03 *
       1.2341)
Norm of the head2 update: tensor(1.00000e-03 *
       9.9964)
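
The difference between the two runs comes from SGD's momentum buffer. torch.optim.SGD (with dampening 0 and no weight decay) updates each parameter as v = mu*v + g, then p = p - lr*v. When head1's gradient is merely zeroed, v stays nonzero from iteration 0 onward, so head1 keeps moving, its update norm decaying by the factor mu = 0.9 per step (1.8810e-3, 1.6929e-3, 1.5236e-3, ...). When the gradient is set to None, the optimizer skips the parameter and the update is exactly zero. A minimal standalone sketch of the recurrence, using head1's printed gradient norm as a scalar stand-in:

mu, lr = 0.9, 1e-4
v = 0.0
for g in [18.8098, 0.0, 0.0, 0.0, 0.0]:  # one real gradient, then zeros
    v = mu * v + g   # the momentum buffer persists after g drops to 0
    print(lr * v)    # ~1.8810e-03, 1.6929e-03, 1.5236e-03, ...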
