Gist by @youkaichao, created September 8, 2023 (https://gist.github.com/youkaichao/7ed49dcb55b2e66dfd841b1a9b0bfeff)
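This script demonstrates that a model compiled with torch.compile can keep producing outputs that match the parameter values it saw at compile time: after ten SGD steps, the compiled model's eval output still matches a deep copy of the original (pre-training) model rather than the updated weights.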
import torch
from torch import nn
import copy
class BackboneModel(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(16, 16, 6)
        self.bn1 = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        return x
model = BackboneModel().eval()
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
old_model = copy.deepcopy(model) # remember the old model
opt_model = torch.compile(model) # compile the model
a = torch.rand(64, 16, 32, 32)
with torch.no_grad():
    model.eval()
    output1 = model(a)
    output2 = opt_model(a)

print("diff of raw model and optimized model at initialization:")
print((output1 - output2).abs().max().item())
# train the model for a few steps to update the weights
model.train()
for i in range(10):
    optim.zero_grad()
    output = opt_model(a)  # opt_model wraps model, so they share parameters
    output.sum().backward()
    optim.step()
with torch.no_grad():
    model.eval()
    output1 = model(a)
    output2 = opt_model(a)
    output3 = old_model(a)

print("diff of raw model and optimized model after training:")
print((output1 - output2).abs().max().item())
print("diff of old model and optimized model after training:")
# the opt_model remembers the old parameter values and does not use the new weights!
print((output3 - output2).abs().max().item())
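If the compiled model should pick up the updated weights, one option (a sketch, not part of the original gist; note that torch._dynamo is a private API and may change between releases) is to clear dynamo's compilation cache and compile again, so tracing happens against the current parameter values:

import torch._dynamo

torch._dynamo.reset()             # drop all cached compiled graphs
opt_model = torch.compile(model)  # recompile; tracing now sees the updated weights
with torch.no_grad():
    model.eval()
    print("diff of raw model and recompiled model after training:")
    print((model(a) - opt_model(a)).abs().max().item())  # expected to be ~0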