Gist by @jerryzh168, created August 10, 2022.
# Imports assumed for this snippet (FX graph mode quantization APIs as of
# mid-2022, around PyTorch 1.12/1.13; exact module paths may differ in other
# versions):
import unittest

import torch
import torch.nn as nn
from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig
from torch.ao.quantization.quantize_fx import fuse_fx
from torch.ao.quantization.utils import MatchAllNode


# Wrapped in a plain unittest.TestCase so the snippet runs standalone; in the
# PyTorch test suite this method lives inside a larger test class.
class TestFuseFx(unittest.TestCase):
    def test_fusion_pattern_with_multiple_inputs(self):
        """Test two keys in backend_config: root_node_getter and
        extra_inputs_getter.

        root_node_getter is used to identify the "root" module in the node
        pattern, i.e. the node that is kept after fusion.
        extra_inputs_getter returns a list of nodes that need to be added to
        the fused node as extra inputs.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(3)

            def forward(self, x):
                # conv -> bn -> add(residual from maxpool) -> relu
                y = self.maxpool(x)
                x = self.conv(x)
                x = self.bn(x)
                x = torch.add(x, y)
                x = self.relu(x)
                return x

        m = M().eval()
        def fuse_conv_bn_relu(is_qat, relu, add_pattern):
            # Fuser method: receives the matched modules (unpacked in the same
            # nested structure as the pattern) and returns the module that
            # replaces the whole pattern; here, just the conv.
            _, bn_pattern, _ = add_pattern
            bn, conv = bn_pattern
            return conv

        def conv_bn_res_relu_root_node_getter(pattern):
            # Pick the conv as the "root" node of the match; its inputs are
            # carried over to the fused node.
            relu, add_pattern = pattern
            _, bn_pattern, _ = add_pattern
            bn, conv = bn_pattern
            return conv
        def conv_bn_res_relu_extra_inputs_getter(pattern):
            """Get the extra-input nodes for the fused node; the inputs of the
            root node itself are assumed to be copied over from the root node
            to the fused node.
            """
            relu, add_pattern = pattern
            _, bn_pattern, extra_input = add_pattern
            bn, conv = bn_pattern
            return [extra_input]
        conv_bn_res_relu_config = (
            BackendPatternConfig(
                (nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)))
            .set_fuser_method(fuse_conv_bn_relu)
            ._set_root_node_getter(conv_bn_res_relu_root_node_getter)
            ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
        )
        backend_config = BackendConfig().set_backend_pattern_config(conv_bn_res_relu_config)
        m = fuse_fx(m, backend_config=backend_config)
        self.assertEqual(type(m.conv), torch.nn.Conv2d)
        # Check that bn and relu are gone, since we replaced the whole pattern
        # with the conv.
        self.assertFalse(hasattr(m, "bn"))
        self.assertFalse(hasattr(m, "relu"))
        # Check that the conv node has two inputs: its original input plus the
        # extra input returned by the extra_inputs_getter.
        named_modules = dict(m.named_modules())
        for node in m.graph.nodes:
            if node.op == "call_module" and type(named_modules[node.target]) == torch.nn.Conv2d:
                self.assertTrue(len(node.args) == 2,
                                "Expecting the fused op to have two arguments")
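
# Conceptually (a sketch of the matching flow, not PyTorch's internal API),
# the fusion pass hands the matched nodes to the getters like this:
#
#   matched = (relu_node, (add_node, (bn_node, conv_node), maxpool_node))
#   root = conv_bn_res_relu_root_node_getter(matched)      # -> conv_node
#   extra = conv_bn_res_relu_extra_inputs_getter(matched)  # -> [maxpool_node]
#
# so the fused module replaces the match rooted at conv_node and takes the
# maxpool output as a second argument.

if __name__ == "__main__":
    unittest.main()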