Created
August 10, 2022 16:45
-
-
Save jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test_fusion_pattern_with_multiple_inputs(self):
    """Test two keys in backend_config: root_node_getter and extra_inputs_getter.

    root_node_getter is used to identify a "root" module in the node pattern,
    the node that we'll keep after fusion.
    extra_inputs_getter returns a list of nodes that need to be added to the
    fused node as extra inputs.

    The pattern registered below is written as nested tuples of the form
    ``(relu, (add, (bn, conv), extra_input))``; every helper in this test
    unpacks that same shape.
    """

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.bn = torch.nn.BatchNorm2d(3)
            self.relu = torch.nn.ReLU()
            self.maxpool = torch.nn.MaxPool2d(3)

        def forward(self, x):
            # The maxpool branch is the second ("extra") input of the add.
            y = self.maxpool(x)
            x = self.conv(x)
            x = self.bn(x)
            x = torch.add(x, y)
            x = self.relu(x)
            return x

    m = M().eval()

    def fuse_conv_bn_relu(is_qat, relu, add_pattern):
        """Fuser method: replace the whole matched pattern with the conv."""
        _, bn_pattern, _ = add_pattern
        _, conv = bn_pattern
        return conv

    def conv_bn_res_relu_root_node_getter(pattern):
        """Return the conv entry of the pattern as the root node."""
        _, add_pattern = pattern
        _, bn_pattern, _ = add_pattern
        _, conv = bn_pattern
        return conv

    def conv_bn_res_relu_extra_inputs_getter(pattern):
        """Get inputs pattern for extra inputs.

        Inputs for the root node are assumed to be copied over from the
        root node to the fused node.
        """
        _, add_pattern = pattern
        _, _, extra_input = add_pattern
        return [extra_input]

    conv_bn_res_relu_config = (
        BackendPatternConfig(
            (nn.ReLU, (torch.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
        )
        .set_fuser_method(fuse_conv_bn_relu)
        ._set_root_node_getter(conv_bn_res_relu_root_node_getter)
        ._set_extra_inputs_getter(conv_bn_res_relu_extra_inputs_getter)
    )
    backend_config = BackendConfig().set_backend_pattern_config(conv_bn_res_relu_config)
    m = fuse_fx(m, backend_config=backend_config)

    self.assertEqual(type(m.conv), torch.nn.Conv2d)
    # check bn and relu are gone since we replaced the whole pattern to conv
    self.assertFalse(hasattr(m, "bn"))
    self.assertFalse(hasattr(m, "relu"))

    # check conv module has two inputs
    named_modules = dict(m.named_modules())
    conv_nodes = [
        node
        for node in m.graph.nodes
        if node.op == "call_module"
        and type(named_modules[node.target]) is torch.nn.Conv2d
    ]
    # Guard against the per-node check below passing vacuously.
    self.assertTrue(
        conv_nodes,
        "Expecting at least one Conv2d call_module node in the fused graph",
    )
    for node in conv_nodes:
        # The message must be passed to the assertion itself; placed after the
        # call it only forms a discarded tuple and is never reported.
        self.assertEqual(
            len(node.args), 2, "Expecting the fused op to have two arguments"
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment