@renesax14
Last active June 18, 2020 18:08
Checking that the grads are not zero when copy_initial_weights=False
def test_training_initial_weights():
    import torch
    import torch.optim as optim
    import torch.nn as nn
    from collections import OrderedDict
    import higher
    ## training config
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    episodes = 5
    nb_inner_train_steps = 5
    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1, 1, bias=False)),
        ('relu', nn.ReLU())
    ]))
    ## get outer optimizer (neither differentiable nor trainable)
    outer_opt = optim.Adam(base_mdl.parameters(), lr=0.01)
    for episode in range(episodes):
        ## sample a toy task: support set for the inner loop, query set for the outer loss
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)
        inner_opt = torch.optim.SGD(base_mdl.parameters(), lr=1e-1)
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=False, track_higher_grads=False) as (fmodel, diffopt):
            for i_inner in range(nb_inner_train_steps):  # full gradient descent on the k-shot examples (k is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5 * (fmodel(spt_x) - spt_y)**2
                # inner-opt update
                diffopt.step(inner_loss)
            ## evaluate on the query set for the current task
            qry_loss = 0.5 * (fmodel(qry_x) - qry_y)**2
            qry_loss.backward()  # for memory-efficient computation
        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()

if __name__ == '__main__':
    test_training_initial_weights()
    print('Done \a')
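
Since the inner loop above runs with track_higher_grads=False, higher detaches the fast weights at every diffopt.step(), so qry_loss.backward() deposits gradients on fmodel's final fast weights rather than on base_mdl. If the goal is a first-order (FOMAML-style) outer update, those gradients can be copied across by hand before outer_opt.step(). A minimal sketch, not part of the original gist, to splice in right after qry_loss.backward(); it assumes fmodel.parameters() yields the fast weights in the same order as base_mdl.parameters():

    # (inside the `with higher.innerloop_ctx(...)` block, right after
    # `qry_loss.backward()`)
    # with track_higher_grads=False the grads land on the final fast
    # weights, so hand them to the base parameters for outer_opt.step()
    for p, fp in zip(base_mdl.parameters(), fmodel.parameters()):
        if fp.grad is not None:
            p.grad = fp.grad.clone() if p.grad is None else p.grad + fp.grad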
renesax14 commented Jun 18, 2020

Output of the code:

episode = 0
base_mdl.grad = None
episode = 1
base_mdl.grad = None
episode = 2
base_mdl.grad = None
episode = 3
base_mdl.grad = None
episode = 4
base_mdl.grad = None
Done 
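
The grads print as None because track_higher_grads=False tells higher not to tape the inner-loop steps: each diffopt.step() produces detached fast weights, so qry_loss.backward() never reaches the initial weights, even though copy_initial_weights=False makes those initial weights the base model's actual parameters. Re-running a single episode with track_higher_grads=True should make the gradient show up on the base model. A minimal self-contained sketch (my own contrast example, not from the gist; the expected output is an assumption based on how higher tracks the inner loop):

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict
    import higher

    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1, 1, bias=False)),
        ('relu', nn.ReLU())
    ]))
    inner_opt = optim.SGD(base_mdl.parameters(), lr=1e-1)
    spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)

    # track_higher_grads=True keeps the inner-loop graph alive, so the
    # query loss can backprop through diffopt.step() to the initial
    # weights (which are base_mdl's own parameters here because
    # copy_initial_weights=False)
    with higher.innerloop_ctx(base_mdl, inner_opt,
                              copy_initial_weights=False,
                              track_higher_grads=True) as (fmodel, diffopt):
        inner_loss = 0.5 * (fmodel(spt_x) - spt_y)**2
        diffopt.step(inner_loss)
        qry_loss = 0.5 * (fmodel(qry_x) - qry_y)**2
        qry_loss.backward()

    print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')  # expected: a 1x1 tensor, not None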
