
@Coderx7
Forked from mingfeima/pytorch_cpu_perf_bkm.md
Created June 9, 2020 02:15
BKM for PyTorch CPU Performance

General guidelines for CPU performance on PyTorch

This file serves as a BKM (best known method) for getting better PyTorch performance on CPU, focusing mostly on inference and deployment. A Chinese version is available here.

1. Use mkldnn layout

Layout refers to how data is organized in a tensor. The default layout in PyTorch is NCHW. For performance reasons, the MKL-DNN library (recently renamed DNNL) may choose a different layout, sometimes referred to as the internal layout or primitive layout. This is a normal technique in acceleration libraries: for example, NHWC is commonly known to run faster than NCHW for convolution, and converting the default NCHW to NHWC is called a reorder. MKL-DNN may choose different internal layouts depending on the input pattern and the selected algorithm, e.g. nChw16c, which reorders a 4-dim tensor into a 5-dim one by splitting the C dimension into blocks of 16 for vectorization (an AVX-512 register holds sixteen 32-bit floats).
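To make the layout names concrete, here is a minimal pure-Python sketch (no PyTorch needed) that computes where element (n, c, h, w) lives in a flat buffer under NCHW, NHWC and nChw16c, and performs a reorder by copying every element to its new offset. It assumes C is a multiple of 16; real libraries pad the channel dimension otherwise. The function names are illustrative only and are not MKL-DNN APIs.

```python
def offset_nchw(n, c, h, w, C, H, W):
    # Default PyTorch layout: W varies fastest, then H, then C, then N.
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n, c, h, w, C, H, W):
    # Channels-last: all C values of one pixel are contiguous.
    return ((n * H + h) * W + w) * C + c


def offset_nChw16c(n, c, h, w, C, H, W, block=16):
    # Blocked layout [N][C/16][H][W][16]: 16 consecutive channels of one pixel
    # are contiguous, i.e. one AVX-512 register's worth of fp32 values.
    assert C % block == 0, "this sketch assumes C is a multiple of the block size"
    cb, ci = divmod(c, block)
    return (((n * (C // block) + cb) * H + h) * W + w) * block + ci


def reorder(src, N, C, H, W, src_offset, dst_offset):
    """Copy every element of `src` to its position in the destination layout."""
    dst = [None] * len(src)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    dst[dst_offset(n, c, h, w, C, H, W)] = src[src_offset(n, c, h, w, C, H, W)]
    return dst


# A 1x32x2x2 tensor stored in NCHW order holds the values 0..127.
N, C, H, W = 1, 32, 2, 2
nchw = list(range(N * C * H * W))
blocked = reorder(nchw, N, C, H, W, offset_nchw, offset_nChw16c)
print(blocked[:16])  # channels 0..15 of pixel (0, 0), now contiguous
```

The reorder touches every element once, which is why a reorder is a real memory copy rather than a free relabeling.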

By default on CPU, conv2d runs on MKL-DNN but pays a reorder overhead: the input and weight are reordered from the default layout to the mkldnn layout, and the output is reordered from the mkldnn layout back to the default layout.

To achieve better performance, we need to let the mkldnn layout flow through consecutive operators, which involves two aspects:

  • change the input to mkldnn layout, so that the output is also in mkldnn layout and the input/output reorders are removed;
  • change the model to mkldnn, so that the weight reorders are removed.

More specifically:

  • The tensor method to_mkldnn() changes the layout from default to mkldnn. This signals that the tensor is now only valid for MKL-DNN operators and its data is no longer directly readable by users. The method to_dense() changes the layout back from mkldnn to default so that users can read it again.
  • The function torch.utils.mkldnn.to_mkldnn() converts modules to their MKL-DNN counterparts, e.g. Conv2d to MkldnnConv2d, Linear to MkldnnLinear, etc. At the same time, the weights are converted to mkldnn layout.
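The effect of these two changes can be counted with a simplified model that follows the description above: on the default path each conv layer reorders its input, its weight and its output once per forward pass. The function name and the model are illustrative only; the real library may skip a reorder when layouts already match.

```python
def reorders_per_forward(num_conv_layers, mkldnn_input=False, mkldnn_model=False):
    """Count layout reorders in one forward pass of a chain of conv layers.

    Simplified model: with dense activations each layer reorders its input in
    and its output out; with a dense model each layer also reorders its weight.
    """
    if mkldnn_input:
        # One to_mkldnn() before the first layer and one to_dense() after the last.
        activation_reorders = 2
    else:
        activation_reorders = 2 * num_conv_layers
    # Per the text, converting the model with to_mkldnn() removes the weight reorder.
    weight_reorders = 0 if mkldnn_model else num_conv_layers
    return activation_reorders + weight_reorders


for flags in [(False, False), (True, False), (True, True)]:
    print(flags, reorders_per_forward(100, *flags))
```

Under this model a 100-layer chain drops from 300 reorders per forward pass to 2 once both the input and the model are converted.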

A problem is that only about a dozen operators support the mkldnn layout, e.g. Conv2d, BatchNorm, ReLU, etc.

If the model consists only of mkldnn-supported operators, all you have to do is:

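from torch.utils import mkldnn as mkldnn_utils
model.eval()  # the mkldnn module conversions target inference
input_ = input.to_mkldnn()
model_ = mkldnn_utils.to_mkldnn(model)
output_ = model_(input_)
output = output_.to_dense()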

If the model is a combination of mkldnn-supported and unsupported operators, you need to insert to_dense() and to_mkldnn() in between:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(10, 10, 3)
        # nn.X() is a placeholder for any operator without mkldnn support
        self.unsupported_mod = nn.X()
        self.linear1 = nn.Linear(10, 20)

    def forward(self, x):
        # x is expected in mkldnn layout, i.e. input.to_mkldnn()
        x = self.conv1(x)
        # use the default layout for the module without mkldnn support
        x = x.to_dense()
        x = self.unsupported_mod(x)
        # switch back so the next mkldnn module can consume it
        x = x.to_mkldnn()
        x = self.linear1(x)
        return x

Notes:

  • to_mkldnn() and to_dense() are not free: each one is a memory copy that adds overhead.
  • The mkldnn layout doesn't support view(); trying to view an mkldnn tensor raises a runtime error, so use reshape() instead. Again, reshape() is not free.
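The to_dense()/to_mkldnn() bookkeeping in MyModel can be generalized. The sketch below uses a hypothetical helper, plan_conversions, with op names given as plain strings: it walks a linear chain of operators and inserts a conversion only where the layout has to change, so counting the inserted steps gives a rough idea of the copy overhead described in the notes. Which ops count as supported is an input you must supply; the set used here is only an example.

```python
def plan_conversions(ops, supported):
    """Return the op sequence with 'to_mkldnn'/'to_dense' inserted at layout boundaries.

    Assumes the input arrives in the default (dense) layout and the caller
    expects a dense output.
    """
    plan = []
    in_mkldnn = False
    for op in ops:
        if op in supported and not in_mkldnn:
            plan.append("to_mkldnn")
            in_mkldnn = True
        elif op not in supported and in_mkldnn:
            plan.append("to_dense")
            in_mkldnn = False
        plan.append(op)
    if in_mkldnn:
        plan.append("to_dense")
    return plan


supported = {"conv1", "linear1"}  # example only
plan = plan_conversions(["conv1", "X", "linear1"], supported)
print(plan)
print("conversions:", sum(step in ("to_mkldnn", "to_dense") for step in plan))
```

Grouping supported operators next to each other in the model reduces the number of boundaries, and therefore the number of copies.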

2. Example of deploying ResNext101

Here is an example of resnext101_32x8d inference on CPU, using convnet-benchmark-py with batch size 1.

2.a default run:

./run.sh --inference --single

The output is 92 ms per image:

ModelType: resnext101, Kernels: nn Input shape: 1x3x224x224
nn                              :forward:      92.52 (ms)      10.81 (imgs/s)
nn                             :backward:       0.00 (ms)
nn                               :update:       0.00 (ms)
nn                                :total:      92.52 (ms)      10.81 (imgs/s)

2.b change input and model to mkldnn

./run.sh --inference --single --mkldnn

This does the following:

from torch.utils import mkldnn as mkldnn_utils
input = input.to_mkldnn() # input will be _mkldnn layout
model = mkldnn_utils.to_mkldnn(model) # weight will be _mkldnn layout

The output is 43 ms per image:

ModelType: resnext101, Kernels: nn Input shape: 1x3x224x224
nn                              :forward:      43.27 (ms)      23.11 (imgs/s)
nn                             :backward:       0.00 (ms)
nn                               :update:       0.00 (ms)
nn                                :total:      43.27 (ms)      23.11 (imgs/s)

2.c cache reordered mkldnn weight

./run.sh --inference --single --mkldnn --cache-weight

This traces the model into a script module, saves it to a .pt file and loads it back. The weight cache is created during the save.

import torch
traced = torch.jit.trace(net, data, check_trace=False)
traced.save('model.pt')  # returns None; mkldnn reordered weight is registered as a module parameter
model = torch.jit.load('model.pt')

The output is 32 ms per image:

nn                              :forward:      32.35 (ms)      30.91 (imgs/s)
nn                             :backward:       0.00 (ms)
nn                               :update:       0.00 (ms)
nn                                :total:      32.35 (ms)      30.91 (imgs/s)

This also applies to libtorch, which means you can save the script module in Python and load the .pt file from C++.
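Because the benchmark runs with batch size 1, throughput in imgs/s is simply 1000 divided by the per-image latency in ms. The short script below recomputes the reported imgs/s figures from the three forward latencies above and derives the speedup of each step over the default run (about 2.14x and 2.86x).

```python
# Forward latency per image (ms) from the three runs above, batch size 1.
latency_ms = {
    "default": 92.52,
    "mkldnn": 43.27,
    "mkldnn + cached weight": 32.35,
}


def summarize(latencies, baseline="default"):
    """Return (name, imgs/s, speedup vs baseline) rows, rounded to two decimals."""
    base = latencies[baseline]
    return [
        (name, round(1000.0 / ms, 2), round(base / ms, 2))
        for name, ms in latencies.items()
    ]


for name, imgs_per_s, speedup in summarize(latency_ms):
    print(f"{name:>24}: {imgs_per_s:6.2f} imgs/s, {speedup:.2f}x")
```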

3. Set correct environment variables

For a single-instance run, set the OpenMP thread count and core binding as:

export OMP_NUM_THREADS=[number_of_physical_cores]
export KMP_AFFINITY=granularity=fine,compact,1,0
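When the workload is launched from a Python driver, the same two variables can be set for the child process only. This is a sketch: single_instance_env is a hypothetical helper, and you must supply the physical core count yourself, since os.cpu_count() reports logical CPUs (including hyper-threads).

```python
import os


def single_instance_env(num_physical_cores, base_env=None):
    """Return a copy of the environment tuned for one instance using every physical core."""
    env = dict(os.environ if base_env is None else base_env)
    # One OpenMP thread per physical core (not per hyper-thread).
    env["OMP_NUM_THREADS"] = str(num_physical_cores)
    # KMP_AFFINITY is read by the Intel OpenMP runtime (libiomp5).
    env["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
    return env


# Usage sketch (28 physical cores and the script name are placeholders):
# import subprocess
# subprocess.run(["python", "your_script.py"], env=single_instance_env(28), check=True)
```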

For a single-socket run, avoid remote memory access with numactl:

numactl --physcpubind=0-$LAST_CORE --membind=0 ./your_script.sh

For a multi-instance run where each instance spawns its own OpenMP thread pool, set OMP_NUM_THREADS per instance. Make sure omp_threads * num_instances does not exceed the number of physical cores, to prevent oversubscription.
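A minimal sketch of that rule for one socket, assuming physical cores are numbered contiguously from 0 and all instances share NUMA node 0: the cores are split into equal, non-overlapping ranges, each instance gets OMP_NUM_THREADS equal to its range size, and any leftover cores stay idle so the total never exceeds the core count. The helper names and the script path are placeholders, and the generated commands assume a POSIX shell.

```python
def partition_cores(num_physical_cores, num_instances):
    """Split cores 0..num_physical_cores-1 into equal, non-overlapping (first, last) ranges."""
    if num_instances < 1 or num_physical_cores < num_instances:
        raise ValueError("need at least one physical core per instance")
    per_instance = num_physical_cores // num_instances
    # Leftover cores (num_physical_cores % num_instances) are left idle.
    return [(i * per_instance, (i + 1) * per_instance - 1) for i in range(num_instances)]


def instance_commands(num_physical_cores, num_instances, script, numa_node=0):
    """Build one shell command per instance with its own thread count and core binding."""
    commands = []
    for first, last in partition_cores(num_physical_cores, num_instances):
        threads = last - first + 1
        commands.append(
            f"OMP_NUM_THREADS={threads} "
            f"numactl --physcpubind={first}-{last} --membind={numa_node} {script}"
        )
    return commands


for cmd in instance_commands(28, 4, "./your_script.sh"):
    print(cmd)
```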

The multi-instance case is much more complicated than the single-instance case, since there are several possible upper-level threading models: you may use torch.multiprocessing, std::thread, TBB, etc. Be careful with oversubscription, which causes a dramatic performance drop on CPU. The easiest way to diagnose such issues on Intel CPUs is VTune.

4. Use Intel OpenMP library

Currently, PyTorch is compiled with the GNU OpenMP library by default. You may use the Intel OpenMP library (which has better performance) by preloading it:

LD_PRELOAD=/opt/intel/compilers_and_libraries/linux/lib/intel64/libiomp5.so ./your_script.sh
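The same preload can be applied from a Python launcher. The hypothetical with_preload helper below prepends the given libraries to LD_PRELOAD without dropping entries that are already set, so it can also be used to preload an allocator. If a listed path does not exist, the dynamic loader prints an error and continues without it, so check the path carefully.

```python
import os


def with_preload(libraries, base_env=None):
    """Return a copy of the environment with `libraries` prepended to LD_PRELOAD."""
    env = dict(os.environ if base_env is None else base_env)
    # ld.so accepts colon- or space-separated entries; keep any existing ones.
    existing = [p for p in env.get("LD_PRELOAD", "").split(":") if p]
    env["LD_PRELOAD"] = ":".join(list(libraries) + existing)
    return env


# Usage sketch: run the workload with Intel OpenMP preloaded.
# import subprocess
# iomp = "/opt/intel/compilers_and_libraries/linux/lib/intel64/libiomp5.so"
# subprocess.run(["./your_script.sh"], env=with_preload([iomp]), check=True)
```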

5. Use jemalloc to reduce memory allocation overhead

PyTorch uses a dynamic graph, which has the drawback that the output of each operator must be allocated on every execution. This increases the memory allocation burden and triggers page clearing (clear_page) for large buffers. jemalloc can alleviate this issue to some extent.

LD_PRELOAD=/home/mingfeim/packages/jemalloc-5.2.0/lib/libjemalloc.so ./your_script.sh

In my experience, it may work like a charm or have no effect at all. Either way, it is worth a try.

[TODO]: recent experiments show tbbmalloc gives a ~25% performance improvement. I will do more profiling and tuning and update this section.
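Because a preloaded library can have anywhere from a large effect to none, it is worth confirming that it was actually loaded before drawing conclusions. On Linux, every shared object mapped into a process appears in /proc/<pid>/maps; the hypothetical helpers below parse that file. This is Linux-specific.

```python
def mapped_libraries(maps_text):
    """Return the shared-object paths found in the text of /proc/<pid>/maps."""
    libs = set()
    for line in maps_text.splitlines():
        # Columns: address perms offset dev inode [pathname]; pathname may be missing.
        fields = line.split(None, 5)
        if len(fields) == 6 and ".so" in fields[5]:
            libs.add(fields[5].strip())
    return libs


def is_loaded(name_fragment, pid="self"):
    """True if any mapped shared object path contains `name_fragment` (Linux only)."""
    with open(f"/proc/{pid}/maps") as maps:
        return any(name_fragment in path for path in mapped_libraries(maps.read()))


# Usage sketch, e.g. at the start of a Python script launched with LD_PRELOAD:
# print("jemalloc loaded:", is_loaded("libjemalloc"))
# print("Intel OpenMP loaded:", is_loaded("libiomp5"))
```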

6. Use icc

PyTorch compiles with GCC by default. However, GCC is weak at automatic vectorization, which leads to worse CPU performance. Older PyTorch versions did compile with ICC, and I used to ship intel/pytorch built with ICC as the default compiler. After the PyTorch and Caffe2 merge, an ICC build triggers ~2K errors and warnings.

So if you intend to build PyTorch with ICC, disable the Caffe2 build with BUILD_CAFFE2_OPS=0:

BUILD_CAFFE2_OPS=0 CC=icc CXX=icpc python setup.py build

7. Use single process DataLoader

torch.utils.data.DataLoader may be slower with num_workers > 0; try comparing against num_workers = 0.
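Which setting wins depends on the dataset and the machine, so it is worth measuring. A minimal, framework-agnostic timing helper is sketched below; time_iteration is hypothetical, and the commented usage assumes you have your own dataset object to pass to DataLoader.

```python
import time


def time_iteration(make_iterable, repeats=3):
    """Return (best wall-clock seconds, items seen) over `repeats` full passes."""
    best = float("inf")
    count = 0
    for _ in range(repeats):
        count = 0
        start = time.perf_counter()
        for _batch in make_iterable():
            count += 1
        best = min(best, time.perf_counter() - start)
    return best, count


# Usage sketch with PyTorch (`dataset` is a placeholder for your own Dataset):
# from torch.utils.data import DataLoader
# for workers in (0, 2, 4):
#     secs, n = time_iteration(lambda: DataLoader(dataset, batch_size=1, num_workers=workers))
#     print(f"num_workers={workers}: {secs:.3f}s for {n} batches")
```

A fresh loader is built for every pass so that worker start-up cost, which is part of what num_workers > 0 pays for, is included in each measurement.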

8. Profile PyTorch

Use torch.autograd.profiler to identify the hotspots of your workload; additional info is available in pytorch_profiler_parser.
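At its core, hotspot analysis is aggregation: sum the self time of each operator across all of its calls and rank the results. The sketch below shows that idea on plain (name, self_time_ms) records with a hypothetical top_hotspots helper. In PyTorch itself, prof.key_averages().table(sort_by="self_cpu_time_total") produces a comparable ranked table, as in the commented usage.

```python
from collections import defaultdict


def top_hotspots(events, k=5):
    """Aggregate (op_name, self_time_ms) records and return the k most expensive ops."""
    totals = defaultdict(float)
    calls = defaultdict(int)
    for name, self_ms in events:
        totals[name] += self_ms
        calls[name] += 1
    ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    return [(name, total, calls[name]) for name, total in ranked[:k]]


# In PyTorch the profiler performs this aggregation itself, e.g.:
# import torch
# with torch.autograd.profiler.profile() as prof:
#     model(input)
# print(prof.key_averages().table(sort_by="self_cpu_time_total"))
```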
