Created
August 25, 2023 16:40
-
-
Save youkaichao/f2fab96fa121e197d40ad3726cff5c7f to your computer and use it in GitHub Desktop.
Demonstrates a highly dynamic use of conv-bn pairs (conv and bn modules that are called more than once), which conv-bn folding must handle correctly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import copy | |
from torch.fx.experimental.efficient_conv_bn_eval import turn_on_efficient_conv_bn_eval | |
class BackboneModel(nn.Module):
    """Toy backbone that reuses its conv and bn modules in irregular ways.

    Three ``Conv2d(16, 16, 6)`` / ``BatchNorm2d(16)`` pairs are created, but
    ``forward`` does not use them as simple one-to-one conv -> bn pairs:
    ``conv2`` runs twice before ``bn2``, ``bn3`` runs twice after ``conv3``,
    and ``bn2`` is reused after ``conv1`` on the second input.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Creation order matters: it decides how the global RNG is consumed
        # by conv weight initialisation and the order of state_dict keys.
        self.conv1 = nn.Conv2d(16, 16, 6)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 16, 6)
        self.bn2 = nn.BatchNorm2d(16)
        self.conv3 = nn.Conv2d(16, 16, 6)
        self.bn3 = nn.BatchNorm2d(16)

    def forward(self, x, y):
        """Run the main chain on ``x`` and a side branch on ``y``.

        Args:
            x: input to the main chain; must have 16 channels. Four
               ``Conv2d(kernel_size=6)`` calls (default stride 1, no padding)
               run on it, each shrinking H and W by 5.
            y: input to the side branch ``bn2(conv1(y))``; must have 16 channels.

        Returns:
            A 0-d tensor equal to ``sum(|main|) + sum(|side|)``.
        """
        # Pair 1: an ordinary conv -> bn; efficient_conv_bn_eval applies.
        feat = self.conv1(x)
        feat = self.bn1(feat)

        # Pair 2: conv2 is called twice; only the second (outer) call feeds
        # bn2 directly, so only that call can use efficient_conv_bn_eval.
        feat = self.conv2(feat)
        feat = self.conv2(feat)
        feat = self.bn2(feat)

        # Pair 3: bn3 is called twice; only its first call directly follows
        # conv3, so only that forward can use efficient_conv_bn_eval.
        feat = self.conv3(feat)
        feat = self.bn3(feat)
        feat = self.bn3(feat)

        # Side branch: bn2 is reused on top of conv1, computed after the
        # main chain's bn2 call (order matters for train-mode running stats).
        side = self.bn2(self.conv1(y))

        main_total = feat.abs().sum()
        side_total = side.abs().sum()
        return main_total + side_total
# Repro script. Run it with inductor freezing enabled, e.g.:
#     TORCHINDUCTOR_FREEZING=1 python dynamic_conv_bn.py
# It runs the same eval-mode model on the same inputs three ways:
#   output1 -- plain eager model (reference)
#   output2 -- torch.compile'd model
#   output3 -- eager deep copy with turn_on_efficient_conv_bn_eval applied


def _randomize_bn_stats(module: nn.Module) -> None:
    """Give every BatchNorm2d in ``module`` non-trivial affine params and stats.

    A freshly constructed BatchNorm2d has running_mean=0, running_var=1,
    weight=1 and bias=0. In eval mode it therefore only divides by
    sqrt(1 + eps) and is almost the identity. With those defaults, a wrongly
    folded conv-bn pair differs from the correct result only by a relative
    error on the order of eps, which looks like floating-point noise.
    Randomizing the statistics makes a wrong fold clearly visible.
    """
    with torch.no_grad():
        for sub in module.modules():
            if isinstance(sub, nn.BatchNorm2d):
                sub.running_mean.uniform_(-1.0, 1.0)
                sub.running_var.uniform_(0.5, 2.0)  # keep variances positive
                sub.weight.uniform_(0.5, 2.0)
                sub.bias.uniform_(-1.0, 1.0)


def _report(label: str, reference: torch.Tensor, candidate: torch.Tensor) -> None:
    """Print the absolute and relative difference of ``candidate`` vs ``reference``."""
    abs_diff = (reference - candidate).abs().max().item()
    # Relative error: the outputs are large sums, so the absolute diff alone
    # cannot separate rounding noise from a real bug.
    scale = max(reference.abs().max().item(), 1e-12)
    print(f"{label}: abs diff = {abs_diff:.6g}, rel diff = {abs_diff / scale:.3g}")


torch.manual_seed(0)  # reproducible weights, BN stats and inputs
model = BackboneModel()
_randomize_bn_stats(model)  # must happen before compiling / copying
model.eval()
opt_model = torch.compile(model)
efficient_model = copy.deepcopy(model)
turn_on_efficient_conv_bn_eval(efficient_model)
a = torch.rand(64, 16, 32, 32)
b = torch.rand(64, 16, 32, 32)
with torch.no_grad():
    output1 = model(a, b)
    output2 = opt_model(a, b)
    output3 = efficient_model(a, b)
# With freezing enabled, the original report saw torch.compile disagree with
# eager, while the efficient_conv_bn_eval copy matched eager exactly.
# Expectation (unverified): the efficient copy should match eager up to float
# rounding, and a large relative diff for torch.compile points to incorrect
# conv-bn folding of the reused modules.
_report("torch.compile vs eager", output1, output2)
_report("efficient_conv_bn_eval vs eager", output1, output3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment