@ptrblck
Created November 26, 2018 22:13
"""
Script to update parameters and the optimizer on the fly.
It is not recommended to use this approach.
This script just shows the disadvantages using this approach.
@author: ptrblck
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Weight of a bias-free linear layer: 10 output features, 2 input features
        self.weight = nn.Parameter(torch.randn(10, 2))

    def forward(self, x):
        x = F.linear(x, self.weight)
        return x

# Create model and initialize all params
model = MyModel()
optimizer = optim.Adam(model.parameters(), lr=1e-1)
print(optimizer.state_dict()) # state is empty
criterion = nn.MSELoss()
x = torch.randn(1, 2)
target = torch.randn(1, 10)
# Train for a few epochs
for epoch in range(10):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}'.format(epoch, loss.item()))
# Store old id of parameters
old_id = id(model.weight)
# Add another input feature
with torch.no_grad():
    model.weight = nn.Parameter(
        torch.cat((model.weight, torch.randn(10, 1)), 1)
    )
# Store new id
new_id = id(model.weight)
# Get old state_dict and store all internals
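# NOTE: the manipulation below assumes the optimizer state_dict is keyed by
# id(param), which was the behavior when this gist was written; newer PyTorch
# releases key the packed state by parameter index instead, so the id-based
# lookup may need adjusting there.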
opt_state_dict = optimizer.state_dict()
step = opt_state_dict['state'][old_id]['step']
exp_avg = opt_state_dict['state'][old_id]['exp_avg']
exp_avg_sq = opt_state_dict['state'][old_id]['exp_avg_sq']
# Extend exp_avg_* with zeros to match the new shape; the running estimates
# for the newly added input feature start from scratch
exp_avg = torch.cat((exp_avg, torch.zeros(10, 1)), 1)
exp_avg_sq = torch.cat((exp_avg_sq, torch.zeros(10, 1)), 1)
# Delete old id from state_dict and update with new params and new id
del opt_state_dict['state'][old_id]
opt_state_dict['state'] = {
    new_id: {
        'step': step,
        'exp_avg': exp_avg,
        'exp_avg_sq': exp_avg_sq
    }
}
opt_state_dict['param_groups'][0]['params'].remove(old_id)
opt_state_dict['param_groups'][0]['params'].append(new_id)
# Create new optimizer and load state_dict with running estimates for old
# parameters
optimizer = optim.Adam(model.parameters(), lr=1e-1)
optimizer.load_state_dict(opt_state_dict)
# Continue training
x = torch.randn(1, 3)  # input now has 3 features to match the extended weight
for epoch in range(10):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}'.format(epoch, loss.item()))
@JoeHarrison

Small correction if you're planning on doing this for multiple parameters: instead of replacing the whole dict with opt_state_dict['state'] = {...}, assign to opt_state_dict['state'][new_id] so the state entries for the other parameters are kept.
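
A minimal sketch of that correction, reusing the gist's variable names: delete the old entry and assign only the new parameter's entry, so state for any other parameters in the dict is preserved.

# Keep the existing state entries; only swap the entry for the changed parameter
del opt_state_dict['state'][old_id]
opt_state_dict['state'][new_id] = {
    'step': step,
    'exp_avg': exp_avg,
    'exp_avg_sq': exp_avg_sq
}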
