  1. A simple example of how const prop works
import torch

def foo(x):
    a = 1 + 2
    b = a + 3
    c = b + 4
    return x + c

jit_foo = torch.jit.script(foo)


# `module.graph` refers to the unoptimized graph
print(jit_foo.graph)

# `module.graph_for(input)` (called under `torch.no_grad()`) returns the optimized graph
with torch.no_grad():
    print(jit_foo.graph_for(torch.rand(10)))
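
If you want to see constant propagation in isolation, you can also run the JIT pass directly on the scripted graph. A minimal sketch, reusing the `jit_foo` defined above and calling the internal PyTorch API torch._C._jit_pass_constant_propagation (internal, so it may change between PyTorch versions):

# Fold the constant chain 1 + 2 + 3 + 4 in place; afterwards the graph
# should contain a single prim::Constant with value 10 feeding the add.
torch._C._jit_pass_constant_propagation(jit_foo.graph)
print(jit_foo.graph)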
  2. A case that shows how const prop works in pyrys

First, we slightly modify test_conv_bn_relu() in test_pass.py:

# Conv + BN + ReLU
def test_conv_bn_relu():
    ConvBnRelu = ScriptedConv2dBnRelu(3, 32, kernel_size=3, stride=1)
    pyrys._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'weight')
    pyrys._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'bias')
    pyrys._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'running_mean')
    pyrys._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'running_var')
    pyrys._jit_pass_freeze_flags(ConvBnRelu._c, 'forward', 'training', False)
    x = torch.rand((1, 3, 8, 8))
    with torch.no_grad():
        print('Conv2d+BatchNorm2d+Relu Graph:\n', ConvBnRelu.graph_for(x))
        ConvBnRelu(x)

Then we can compare the graphs produced with constprop enabled and disabled in the fusion pass. With it disabled, the dnnl::fold_weight and dnnl::fold_bias nodes stay in the graph and run on every forward call; with it enabled, they are evaluated once during the pass and replaced by prim::Constant tensors.

// disable constprop in fusion_pass
Conv2d+BatchNorm2d+Relu Graph:
 graph(%self : ClassType<ScriptedConv2dBnRelu>,
      %x.1 : Float(*, *, *, *)):
  %328 : int[] = prim::Constant[value=[0, 0]]()
  %327 : int[] = prim::Constant[value=[1, 1]]()
  %3 : float = prim::Constant[value=0.001]() # /home/pinzhenxu/pytorch_llga/torch/nn/modules/batchnorm.py:79:40
  %4 : None = prim::Constant() # /home/pinzhenxu/pytorch_llga/torch/nn/modules/conv.py:339:36
  %23 : int = prim::Constant[value=1]() # /home/pinzhenxu/pytorch_llga/torch/nn/modules/conv.py:336:46
  %343.weight.1 : Float(32, 3, 3, 3) = prim::Constant[value=<Tensor>]()
  %346.running_mean : Float(32) = prim::Constant[value=<Tensor>]()
  %347.running_var : Float(32) = prim::Constant[value=<Tensor>]()
  %329 : int[] = prim::Constant[value=[1, 1]]()
  %343.weight.1.bn_folded : Float(32, 3, 3, 3) = dnnl::fold_weight(%343.weight.1, %347.running_var, %347.running_var, %3)
  %4.bn_folded : None = dnnl::fold_bias(%343.weight.1.bn_folded, %4, %347.running_var, %346.running_mean, %346.running_mean, %347.running_var, %3)
  %343 : Float(*, *, *, *) = dnnl::conv2d_relu[format_info=[1, 81, 1, 1, 1, 1, 1, 1]](%x.1, %343.weight.1.bn_folded, %4.bn_folded, %327, %328, %329, %23)
  %result.1.reorder : Tensor = dnnl::reorder[format_info=[1, 7], group_info=1](%343)
  return (%result.1.reorder)

// enable constprop in fusion_pass
Conv2d+BatchNorm2d+Relu Graph:
 graph(%self : ClassType<ScriptedConv2dBnRelu>,
      %x.1 : Float(*, *, *, *)):
  %328 : int[] = prim::Constant[value=[0, 0]]()
  %327 : int[] = prim::Constant[value=[1, 1]]()
  %23 : int = prim::Constant[value=1]() # /home/pinzhenxu/pytorch_llga/torch/nn/modules/conv.py:336:46
  %329 : int[] = prim::Constant[value=[1, 1]]()
  %344 : Float(32, 3, 3, 3) = prim::Constant[value=<Tensor>]()
  %345 : Float(32) = prim::Constant[value=<Tensor>]()
  %343 : Float(*, *, *, *) = dnnl::conv2d_relu[format_info=[1, 81, 1, 1, 1, 1, 1, 1]](%x.1, %344, %345, %327, %328, %329, %23)
  %result.1.reorder : Tensor = dnnl::reorder[format_info=[1, 7], group_info=1](%343)
  return (%result.1.reorder)
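
One quick way to confirm what constprop removed is to tally the op kinds in each graph. A minimal sketch using the TorchScript graph API (graph.nodes() / node.kind()); the helper name is just for illustration:

from collections import Counter

def node_kinds(graph):
    # Count op kinds in a TorchScript graph, e.g. to check whether
    # dnnl::fold_weight / dnnl::fold_bias survived the fusion pass.
    return Counter(node.kind() for node in graph.nodes())

# Usage with the module from the test above:
# with torch.no_grad():
#     print(node_kinds(ConvBnRelu.graph_for(x)))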
  3. Intro to conv-bn fusion

Here we fuse conv and batchnorm into one op. Conv-BN fusion is a commonly used technique in CNN inference optimization. Most CNNs use batchnorm to speed up training, but in inference-only scenarios it does us no good and actually hurts performance, because batchnorm is a memory-bound operator that requires a lot of I/O.

Thanks to its algebraic properties, we can fuse the parameters of batchnorm (running_mean & running_var) into the preceding convolution's parameters (weight, bias), i.e.

x -> conv -> y -> bn -> z

becomes

x -> conv -> z

If you're interested in the math details, refer to https://tehnokv.com/posts/fusing-batchnorm-and-conv/. The math in that article is correct; however, its Python code is problematic. See intel/webml-polyfill#240 (comment) for a corrected Python implementation.
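
For reference, here is a minimal sketch of the folding math itself (plain PyTorch, not the pyrys dnnl::fold_weight/fold_bias implementation; the function name and eps default are illustrative):

import torch

def fold_conv_bn(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b, eps=1e-5):
    # BN computes z = gamma * (y - mean) / sqrt(var + eps) + beta, where
    # y = conv(x, W, b). Folding that affine transform into the conv gives
    # W' = W * scale and b' = (b - mean) * scale + beta,
    # with scale = gamma / sqrt(var + eps).
    scale = bn_w / torch.sqrt(bn_rv + eps)
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    folded_w = conv_w * scale.reshape(-1, 1, 1, 1)
    folded_b = (conv_b - bn_rm) * scale + bn_b
    return folded_w, folded_b

You can sanity-check it by comparing F.conv2d(x, folded_w, folded_b) against bn(conv(x)) with both modules in eval() mode.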

Questions (you don't have to send me the answers):

  1. What do _jit_pass_freeze_params and _jit_pass_freeze_flags do? What if we do not call _jit_pass_freeze_flags at all?

  2. How do we fuse the conv and bn in pyrys? That is, how do we transform a verbose JIT graph from x -> conv -> y -> bn -> z into x -> conv -> z? You may inspect the intermediate graphs after each pass to understand the transformation process.
