Skip to content

Instantly share code, notes, and snippets.

@youkaichao
Created August 25, 2023 16:40
Show Gist options
  • Save youkaichao/f2fab96fa121e197d40ad3726cff5c7f to your computer and use it in GitHub Desktop.
Save youkaichao/f2fab96fa121e197d40ad3726cff5c7f to your computer and use it in GitHub Desktop.
Demonstrates a highly dynamic use case of conv-bn pairs.
import torch
from torch import nn
import copy
from torch.fx.experimental.efficient_conv_bn_eval import turn_on_efficient_conv_bn_eval
class BackboneModel(nn.Module):
    """Toy backbone that reuses its conv and batch-norm layers in irregular ways.

    The module owns three ``Conv2d``/``BatchNorm2d`` pairs (16 channels in and
    out, kernel size 6). ``forward`` deliberately calls some layers more than
    once, and pairs ``conv1`` with ``bn2`` for the second input, so that only
    some of the conv -> bn hand-offs are direct.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Layers are created in conv/bn order, pair by pair.
        self.conv1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=6)
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=6)
        self.bn2 = nn.BatchNorm2d(num_features=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=6)
        self.bn3 = nn.BatchNorm2d(num_features=16)

    def forward(self, x, y):
        """Return a scalar: the sum of absolute activations of both branches.

        Args:
            x: input pushed through all three conv/bn stages.
            y: input pushed through ``conv1`` followed by ``bn2`` only.
        """
        # Pair 1: a plain conv -> bn, so this pair can use the
        # efficient_conv_bn_eval feature.
        x = self.conv1(x)
        x = self.bn1(x)

        # Pair 2: conv2 runs twice; only the second call feeds bn2 directly,
        # so the efficient_conv_bn_eval feature applies to that call only.
        x = self.conv2(x)
        x = self.conv2(x)
        x = self.bn2(x)

        # Pair 3: bn3 runs twice; only its first call consumes the conv3
        # output, so the efficient_conv_bn_eval feature applies to that call only.
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.bn3(x)

        # The main-branch total is computed before the second branch runs,
        # which keeps the order of module calls unchanged.
        main_total = x.abs().sum()
        # Second branch mixes layers across pairs: conv1 followed by bn2.
        side_total = self.bn2(self.conv1(y)).abs().sum()
        return main_total + side_total
# How to run (inductor freezing enabled):
#   TORCHINDUCTOR_FREEZING=1 python dynamic_conv_bn.py

# Reference model in eval mode; `eval()` returns the module itself.
model = BackboneModel().eval()

# Variant 1: the same module object wrapped by torch.compile.
opt_model = torch.compile(model)

# Variant 2: an independent copy with efficient conv-bn eval switched on.
efficient_model = copy.deepcopy(model)
turn_on_efficient_conv_bn_eval(efficient_model)

# Two random inputs of identical shape, one per forward argument.
a = torch.rand(64, 16, 32, 32)
b = torch.rand(64, 16, 32, 32)

# Run all three variants on the same inputs without tracking gradients.
with torch.no_grad():
    output1 = model(a, b)
    output2 = opt_model(a, b)
    output3 = efficient_model(a, b)

# Eager vs. compiled: the author reports 1.109375 on their machine,
# i.e. the compiled result is vastly wrong.
print((output1 - output2).abs().max().item())
# Eager vs. efficient conv-bn eval: the author reports 0 on their machine,
# i.e. it works fine for very dynamic cases.
print((output1 - output3).abs().max().item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment