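# This script demonstrates a torch.compile pitfall: a model compiled in eval
# mode can keep producing outputs based on the parameters it saw at compile
# time, even after the underlying model has been trained and its weights
# updated in place.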
import torch
from torch import nn
import copy


class BackboneModel(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(16, 16, 6)
        self.bn1 = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        return x
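# Note: a Conv2d immediately followed by BatchNorm2d is the classic candidate
# for conv-bn folding, which presumably lets the compiler bake the batch-norm
# statistics and weights into the convolution as constants at trace time.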
model = BackboneModel().eval()
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
old_model = copy.deepcopy(model)  # remember the old model
opt_model = torch.compile(model)  # compile the model
a = torch.rand(64, 16, 32, 32)
with torch.no_grad():
    model.eval()
    output1 = model(a)
    output2 = opt_model(a)
    print("diff of raw model and optimized model at initialization:")
    print((output1 - output2).abs().max().item())
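# At this point both models share the same parameters, so the printed diff
# should be zero (or tiny floating-point noise from fused kernels).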
# train the model and update the weights
model.train()
for i in range(10):
    optim.zero_grad()
    output = opt_model(a)
    output.sum().backward()
    optim.step()
with torch.no_grad():
    model.eval()
    output1 = model(a)
    output2 = opt_model(a)
    output3 = old_model(a)
    print("diff of raw model and optimized model after training:")
    print((output1 - output2).abs().max().item())
    print("diff of old model and optimized model after training:")
    print((output3 - output2).abs().max().item())  # opt_model remembers the old parameters and does not use the new weights!
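# A possible workaround (a sketch, assuming the divergence comes from the old
# parameters being baked into the compiled graph at trace time): clear the
# dynamo cache and re-compile, so the next call re-traces against the current
# weights.
import torch._dynamo

torch._dynamo.reset()
opt_model = torch.compile(model)  # re-compile against the updated parameters

with torch.no_grad():
    model.eval()
    print("diff of raw model and re-compiled model:")
    print((model(a) - opt_model(a)).abs().max().item())  # expected to be ~0 again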