# Gist ptrblck/e78b59052f4bea6f91ec35f9f15adfb4 by @ptrblck, created October 24, 2018 14:12.
# Checks whether a model with extra, unused layers trains identically to the
# same model without them (same seed, same data, Adam optimizer).
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Seed the global RNG so the input batch `x` and `target` drawn further down
# are reproducible across runs (defining the functions/classes in between
# does not consume any random numbers).
torch.manual_seed(2809)
def check_params(modelA, modelB):
    """Compare every entry of ``modelA``'s state dict against ``modelB``'s.

    Each key of ``modelA.state_dict()`` is looked up in ``modelB.state_dict()``
    and compared exactly: same shape and same values. Keys that exist only in
    ``modelB`` are ignored, so a model with additional (e.g. unused)
    submodules can be checked against a smaller reference.

    One line is printed per key. Checking stops at the first mismatch, after
    printing ``'ERROR!'``.

    Args:
        modelA: Reference ``nn.Module`` whose keys drive the comparison.
        modelB: ``nn.Module`` expected to contain all of ``modelA``'s keys.

    Returns:
        ``True`` if every key matched, ``False`` on the first mismatching or
        missing key.
    """
    # Build each state dict once; calling state_dict() inside the loop
    # rebuilds the whole dict for every single key.
    state_a = modelA.state_dict()
    state_b = modelB.state_dict()
    for key, value_a in state_a.items():
        if key not in state_b:
            # A key absent from modelB is a mismatch, not a KeyError crash.
            is_equal = False
        else:
            # torch.equal requires identical shapes, so unlike ``(a == b).all()``
            # a broadcastable shape mismatch cannot be reported as equal.
            is_equal = torch.equal(value_a, state_b[key])
        print('Checking {}, is equal = {}'.format(key, is_equal))
        if not is_equal:
            print('ERROR!')
            return False
    return True
def check_grads(modelA, modelB):
    """Compare the ``.grad`` of every parameter of ``modelA`` with ``modelB``'s.

    Parameters are matched by their fully qualified name from
    ``named_parameters()``, so nested submodules (e.g. ``'block.0.weight'``)
    are handled; splitting the name on ``'.'`` only worked for names with
    exactly one dot. Parameters present only in ``modelB`` are ignored.

    Two gradients are considered equal when both are ``None`` (never
    populated), or when they have the same shape and the same values.

    One line is printed per parameter. Checking stops at the first mismatch,
    after printing ``'ERROR!'``.

    Args:
        modelA: Reference ``nn.Module`` whose parameters drive the comparison.
        modelB: ``nn.Module`` expected to contain all of ``modelA``'s parameters.

    Returns:
        ``True`` if all gradients matched, ``False`` on the first mismatch or
        on a parameter missing from ``modelB``.
    """
    params_b = dict(modelB.named_parameters())
    for name, param_a in modelA.named_parameters():
        param_b = params_b.get(name)
        if param_b is None:
            is_equal = False
        else:
            grad_a = param_a.grad
            grad_b = param_b.grad
            if grad_a is None or grad_b is None:
                # A missing gradient only matches another missing gradient;
                # comparing None with a tensor via ``==`` is meaningless.
                is_equal = grad_a is None and grad_b is None
            else:
                is_equal = torch.equal(grad_a, grad_b)
        print('Gradient for {} is equal {}'.format(name, is_equal))
        if not is_equal:
            print('ERROR!')
            return False
    return True
class MyModel(nn.Module):
    """Small two-stage CNN that maps an image batch to 2 logits.

    The pipeline has three stages:

    * conv(3 -> 6) + ReLU + 2x2 max-pool;
    * conv(6 -> 12) + ReLU + 2x2 max-pool;
    * flatten, then a linear layer to 2 outputs.

    The ``12 * 6 * 6`` input width of ``fc`` is sized for 24x24 inputs
    (24 -> 12 -> 6 after the two pools).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order on purpose: with a fixed seed it
        # determines which random draws each layer's initial weights get.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        self.fc = nn.Linear(12 * 6 * 6, 2)

    def forward(self, x):
        features = self.pool1(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
class MyModelUnused(nn.Module):
    """The ``MyModel`` network plus two convolutions that ``forward`` never calls.

    ``conv_unused1`` and ``conv_unused2`` are registered, so they appear in
    ``parameters()`` and ``state_dict()``. Because ``forward`` never calls
    them, their ``.grad`` stays ``None`` after backward.

    They are created after the shared layers. With the same seed, the shared
    layers therefore draw the same initial weights as in ``MyModel``.
    """

    def __init__(self):
        super().__init__()
        # Shared layers first, in the same order as MyModel.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        self.fc = nn.Linear(12 * 6 * 6, 2)
        # Extra layers, registered but deliberately left out of forward().
        self.conv_unused1 = nn.Conv2d(12, 24, kernel_size=3, stride=1, padding=1)
        self.conv_unused2 = nn.Conv2d(24, 12, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        features = self.pool1(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
# Fixed input batch and binary class targets shared by both models.
x = torch.randn(10, 3, 24, 24)
target = torch.empty(10, dtype=torch.long).random_(2)
criterion = nn.CrossEntropyLoss()

# Re-seed before each construction so both models draw identical initial
# weights for their shared layers (conv1, conv2 and fc are created before the
# extra layers of MyModelUnused).
torch.manual_seed(2809)
modelA = MyModel()
torch.manual_seed(2809)
modelB = MyModelUnused()

# Check weights for equality
check_params(modelA, modelB)

optimizerA = optim.Adam(modelA.parameters(), lr=1e-3)
optimizerB = optim.Adam(modelB.parameters(), lr=1e-3)

for epoch in range(10):
    print('Checking epoch {}'.format(epoch))
    optimizerA.zero_grad()
    optimizerB.zero_grad()
    check_params(modelA, modelB)

    outputA = modelA(x)
    outputB = modelB(x)
    lossA = criterion(outputA, target)
    lossB = criterion(outputB, target)

    # Compare explicitly and report the result; a bare comparison expression
    # would be evaluated and silently thrown away.
    outputs_equal = torch.equal(outputA, outputB)
    losses_equal = torch.equal(lossA, lossB)
    print('Outputs are equal {}'.format(outputs_equal))
    print('Losses are equal {}'.format(losses_equal))
    if not (outputs_equal and losses_equal):
        print('ERROR!')

    lossA.backward()
    lossB.backward()
    check_grads(modelA, modelB)

    optimizerA.step()
    optimizerB.step()

# Parameters are checked before each step inside the loop, so the update made
# by the last step has to be verified separately.
print('Checking final parameters')
check_params(modelA, modelB)
# Discussion of this gist takes place in its GitHub comments.