@mingfeima
Last active February 16, 2024 21:31
BKM for PyTorch CPU Performance

General guidelines for CPU performance on PyTorch

This file serves as a BKM (best known method) for getting better CPU performance with PyTorch, mostly focusing on inference or deployment. Chinese version available here.

1. Use channels last memory format

Right now, on the PyTorch CPU path, you may choose among 3 memory formats.

  • torch.contiguous_format: default memory format, also referred to as NCHW.
  • torch.channels_last: also referred to as NHWC.
  • torch._mkldnn: mkldnn blocked format.

The default NCHW format generally performs worse than the NHWC and MKLDNN blocked memory formats.

### 1. default (NCHW)
output = model(input)

### 2. channels last
input = input.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
output = model(input)

### Note: Most CV models have a convolution as the 1st layer, and channels last has higher priority in Conv2d.
###   So you can also convert only the weights to channels last and the input will be converted accordingly.
###   The channels last memory format is then propagated through the model (until an operator without
###   channels last support, if any).

### 3a. mkldnn blocked format (inference)
input = input.to_mkldnn()
model = torch.utils.mkldnn.to_mkldnn(model)
output = model(input)

### 3b. mkldnn blocked format (training)
input = input.to_mkldnn()
output = model(input)

If the model contains operators that don't support the channels last memory format, you might not get optimal performance: such operators treat an NHWC tensor as a non-contiguous NCHW tensor, and the rest of the model then propagates NCHW.
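To see why an NHWC tensor looks non-contiguous to an NCHW kernel, it can help to compare the element strides of the two layouts. The sketch below is plain Python and needs no PyTorch; `nchw_strides` and `nhwc_strides` are illustrative helper names, not PyTorch APIs. For a tensor with logical shape (N, C, H, W), the values should match what `Tensor.stride()` reports for the contiguous and channels last formats respectively.

```python
def nchw_strides(n, c, h, w):
    """Element strides for logical dims (N, C, H, W) stored as NCHW."""
    return (c * h * w, h * w, w, 1)


def nhwc_strides(n, c, h, w):
    """Element strides for the same logical dims stored as NHWC (channels last)."""
    return (h * w * c, 1, w * c, c)


shape = (1, 3, 224, 224)
print(nchw_strides(*shape))  # (150528, 50176, 224, 1)
print(nhwc_strides(*shape))  # (150528, 1, 672, 3)
```

In the NHWC case the stride of the W dimension is C rather than 1, so a kernel that assumes NCHW ordering sees the data as non-contiguous.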

If the model contains operators that don't support the mkldnn blocked memory format, you need to insert to_dense() and to_mkldnn() around them:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(10, 10, 3)
        # MyModel has mkldnn unsupported operators X()
        self.unsupported_mod = nn.X()
        self.linear1 = nn.Linear(10, 20)
        
    def forward(self, x):
        x = self.conv1(x)
        # use default layout for module without mkldnn support
        x = x.to_dense()
        x = self.unsupported_mod(x)
        x = x.to_mkldnn()
        x = self.linear1(x)
        return x
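
When several modules lack mkldnn support, the to_dense()/to_mkldnn() round trip can be factored into a small helper. The following is only a sketch: `call_dense` is a hypothetical helper, not part of PyTorch, and it relies solely on `is_mkldnn`, `to_dense()` and `to_mkldnn()`, which torch.Tensor provides. `FakeTensor` is a stand-in so the snippet runs without PyTorch installed.

```python
def call_dense(module, x):
    """Run `module` on a dense tensor and restore the mkldnn layout afterwards."""
    was_mkldnn = getattr(x, "is_mkldnn", False)
    if was_mkldnn:
        x = x.to_dense()
    out = module(x)
    return out.to_mkldnn() if was_mkldnn else out


class FakeTensor:
    """Stand-in for torch.Tensor (assumption: only the three members used above)."""

    def __init__(self, is_mkldnn):
        self.is_mkldnn = is_mkldnn

    def to_dense(self):
        return FakeTensor(False)

    def to_mkldnn(self):
        return FakeTensor(True)


def unsupported(x):
    # Mimics an operator that only accepts the default (dense) layout.
    assert not x.is_mkldnn
    return x


out = call_dense(unsupported, FakeTensor(is_mkldnn=True))
print(out.is_mkldnn)  # True
```

Inside forward() this would read `x = call_dense(self.unsupported_mod, x)`. Each round trip copies and reorders data, so it is worth grouping consecutive unsupported operators under a single conversion.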

For further explanation of the channels last memory format optimization in PyTorch, see Channels Last Memory Format Performance Optimization on CPU Path.

Results on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, single socket with 20 cores, are available here.

NHWC performance was collected with torch-opt-test and torchvision-opt-test. Upstreaming to public is ongoing.

### NCHW run
Running on torch: 1.8.1+cpu
Running on torchvision: 0.9.1+cpu
ModelType: resnet50, Kernels: nn Input shape: 1x3x224x224
nn                              :forward:      55.89 (ms)      17.89 (imgs/s)
nn                             :backward:       0.00 (ms)
nn                               :update:       0.00 (ms)
nn                                :total:      55.89 (ms)      17.89 (imgs/s)

### NHWC run
Running on torch: 1.9.0a0+git850a6bd
Running on torchvision: 0.10.0a0+4f34ae5
ModelType: resnet50, Kernels: nn Input shape: 1x3x224x224
nn                              :forward:      14.02 (ms)      71.31 (imgs/s)
nn                             :backward:       0.00 (ms)
nn                               :update:       0.00 (ms)
nn                                :total:      14.02 (ms)      71.31 (imgs/s)

2. TorchVision with channels last support

If your model needs csrc modules from torchvision, e.g. "ROIAlign", torchvision also needs to have channels last support. For example, MaskedRCNN related info has been placed here.

Results on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, single socket with 20 cores:

### with config "fast_rcnn_R_50_FPN_1x.yaml"
### NCHW (torch-1.8.1/vision-0.9.1): 300 iters in 326.0195782049559 seconds.
### NCHW (torch-opt/vision-0.9.1): 300 iters in 185.4384527085349 seconds.
### NCHW (torch-opt/vision-opt): 300 iters in 80.56146793198423 seconds.
### NHWC (torch-opt/vision-opt): 300 iters in 55.49435344198719 seconds.
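
The relative gain of each step can be read off these timings with a little arithmetic. The snippet below is a small sketch that uses only the four numbers listed above; `speedups` is an illustrative helper name.

```python
ITERS = 300

# Seconds for 300 iterations, copied from the results above.
RUNS = {
    "NCHW (torch-1.8.1/vision-0.9.1)": 326.0195782049559,
    "NCHW (torch-opt/vision-0.9.1)": 185.4384527085349,
    "NCHW (torch-opt/vision-opt)": 80.56146793198423,
    "NHWC (torch-opt/vision-opt)": 55.49435344198719,
}


def speedups(runs, baseline):
    """Speedup of every run relative to the run named `baseline`."""
    base = runs[baseline]
    return {name: base / seconds for name, seconds in runs.items()}


for name, factor in speedups(RUNS, "NCHW (torch-1.8.1/vision-0.9.1)").items():
    print(f"{name}: {RUNS[name] / ITERS:.3f} s/iter, {factor:.2f}x vs. baseline")
```

On these numbers the three steps compound to roughly a 5.9x end-to-end speedup over the stock torch-1.8.1/vision-0.9.1 build.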

Upstreaming to the public PyTorch repo is ongoing. Further optimization is also WIP.

3. Set correct environment variables

For a single-instance run, set the OMP thread count and core binding as follows:

export OMP_NUM_THREADS=[number_of_physical_cores]
export KMP_AFFINITY=granularity=fine,compact,1,0

For a single-socket run, avoid remote memory access with numactl:

# e.g. say each socket has 20 cores, to use the 1st socket:
numactl --physcpubind=0-19 --membind=0
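
The same prefix can be generated for any socket. This is a minimal sketch under two assumptions that should be checked on the target machine (for example with `numactl --hardware` or `lscpu`): physical cores are numbered contiguously per socket, and the NUMA node id equals the socket index. `numactl_prefix` is a hypothetical helper name.

```python
def numactl_prefix(socket, cores_per_socket):
    """Build the numactl arguments that pin a process to one socket.

    Assumes contiguous core numbering per socket and node id == socket index.
    """
    first = socket * cores_per_socket
    last = first + cores_per_socket - 1
    return ["numactl", f"--physcpubind={first}-{last}", f"--membind={socket}"]


print(" ".join(numactl_prefix(0, 20)))  # numactl --physcpubind=0-19 --membind=0
print(" ".join(numactl_prefix(1, 20)))  # numactl --physcpubind=20-39 --membind=1
```

The returned list is meant to be prepended to the actual command line of the workload.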

For a multi-instance run where each instance spawns its own OMP thread pool, set OMP_NUM_THREADS per instance. Make sure omp_threads * num_instances does not exceed the number of physical cores, so as to prevent oversubscription.
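
The rule above amounts to an integer division. A minimal sketch follows, with `omp_threads_per_instance` as a hypothetical helper; the physical core count is passed in explicitly because `os.cpu_count()` reports logical CPUs, which may include hyper-threads.

```python
def omp_threads_per_instance(physical_cores, num_instances):
    """Largest OMP_NUM_THREADS such that threads * instances <= physical cores."""
    if num_instances < 1:
        raise ValueError("num_instances must be at least 1")
    if num_instances > physical_cores:
        raise ValueError("more instances than physical cores: oversubscription")
    return physical_cores // num_instances


print(omp_threads_per_instance(20, 4))  # 5
print(omp_threads_per_instance(20, 3))  # 6
```

With 20 cores and 3 instances this leaves 2 cores idle, which is usually preferable to oversubscribing.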

The multi-instance case is much more complicated than the single-instance case, since several upper-level threading models may be involved: torch.multiprocessing, std::thread, TBB, etc. Be careful with oversubscription, which results in a dramatic performance drop on CPU. The easiest way to detect this issue on Intel CPUs is VTune.

4. Use jemalloc to reduce memory allocation consumption

PyTorch uses a dynamic graph, which has the drawback that the output of each operator must be allocated on every execution. This increases the burden of memory allocation and triggers page clearing for large buffers. The issue can be alleviated to some extent with jemalloc or tcmalloc.

### jemalloc
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export LD_PRELOAD=/home/mingfeim/packages/jemalloc-5.2.1/lib/libjemalloc.so

### tcmalloc
export LD_PRELOAD=/home/mingfeim/packages/gperftools-2.8/install/lib/libtcmalloc.so

If VTune shows clear_page from vmlinux.so consuming a lot of time, it is time to apply jemalloc.
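
LD_PRELOAD only takes effect when a process starts, so setting it from inside an already running Python interpreter does not change that interpreter's allocator. When the workload is launched from Python, one option is to build the environment for the child process. The sketch below assumes a hypothetical library path and script name; `env_with_allocator` is an illustrative helper.

```python
import os

# Same jemalloc settings as in the export line above.
JEMALLOC_CONF = (
    "oversize_threshold:1,background_thread:true,metadata_thp:auto,"
    "dirty_decay_ms:-1,muzzy_decay_ms:-1"
)


def env_with_allocator(lib_path, malloc_conf=None, base_env=None):
    """Return a copy of the environment with the allocator preloaded."""
    env = dict(os.environ if base_env is None else base_env)
    env["LD_PRELOAD"] = lib_path
    if malloc_conf is not None:
        env["MALLOC_CONF"] = malloc_conf
    return env


# Hypothetical path; adjust to your jemalloc installation.
env = env_with_allocator("/path/to/libjemalloc.so", JEMALLOC_CONF)
print(env["LD_PRELOAD"])

# Launching the workload (script name is hypothetical):
#   import subprocess, sys
#   subprocess.run([sys.executable, "inference.py"], env=env, check=True)
```

For tcmalloc, pass the tcmalloc library path and leave `malloc_conf` unset.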

5. Use single process DataLoader

torch.utils.data.DataLoader may be slower when num_workers > 0; compare against num_workers = 0.
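
One way to make that comparison is to time a full pass over each loader configuration. This is a rough sketch, not a rigorous benchmark: `seconds_per_pass` is a hypothetical helper that works with any iterable, and the DataLoader usage in the comments assumes you already have a `dataset` object of your own.

```python
import time


def seconds_per_pass(make_loader, passes=3):
    """Best wall-clock time for one full pass over a freshly built loader."""
    best = float("inf")
    for _ in range(passes):
        loader = make_loader()
        start = time.perf_counter()
        for _batch in loader:
            pass  # only data loading is measured, no model work
        best = min(best, time.perf_counter() - start)
    return best


# Runs without PyTorch, using a plain range as the "loader".
print(f"{seconds_per_pass(lambda: range(1000)):.6f} s")

# With PyTorch (dataset and batch size are placeholders):
#   from torch.utils.data import DataLoader
#   t0 = seconds_per_pass(lambda: DataLoader(dataset, batch_size=32, num_workers=0))
#   t4 = seconds_per_pass(lambda: DataLoader(dataset, batch_size=32, num_workers=4))
```

Building the loader inside the timed helper's factory means worker start-up cost is included in each pass, which is part of what num_workers > 0 costs in short runs.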

@nexgus

nexgus commented Mar 10, 2020

Hi MingFei,

I followed your guideline to use MKL-DNN to accelerate CPU inference. It works, especially on CPUs with AVX-512 support, such as Xeon.

I want to try densenet121 since it is useful for medical image analysis. In fact, we use it in some hospitals and it works fine. I think my next step is to use MKL-DNN with PyTorch for acceleration so that customers save money (on GPU cards).

The basic concept is:

import torch
import torchvision
from torch.utils import mkldnn as mkldnn_utils

net = torchvision.models.densenet121()
net.eval()
net = mkldnn_utils.to_mkldnn(net)
data = torch.randn(1, 3, 224, 224)
data = data.to_mkldnn()
inference = net(data)

However, when the last line (to perform inference) is executed, an exception is raised:

RuntimeError: Could not run 'aten::_cat' with arguments from the 'MkldnnCPUTensorId' backend. 'aten::_cat' is only available for these backends: [CPUTensorId, VariableTensorId].

Does it mean that this neural network, which is derived from torchvision, is a combination of MKL-DNN supported and unsupported operators?

So what should I do to solve this problem (based on the torchvision model)? Could you give me a hint or some advice, please?

By the way, once DNNL (1.x) is implemented in PyTorch, is it possible to use Intel GPU to accelerate inference?

B.Rds,
Augustus

@mingfeima
Author

@nexgus torch.cat() is treated as an unsupported operator by MKL-DNN, which means you need to convert the inputs with to_dense() first. This is illustrated in the 2nd example in the 1st section. I'm afraid you need to manually update the model file under torchvision.

We are trying to improve the user experience here by avoiding manual modification of the model file.

@nexgus

nexgus commented Mar 10, 2020

@mingfeima Thank you very much. It works.

@nexgus

nexgus commented Mar 25, 2020

@mingfeima,

I tried to build PyTorch with ICC but failed. Could you give me some advice?

Here is my procedure:

  1. Build a Docker image named pytorch-dev:icc. PyTorch source code is from https://github.com/pytorch/pytorch.
    Here is the content of the Dockerfile:
    FROM ubuntu:18.04
    LABEL maintainer "Augustus Chen <augustuschen@nexcom.com.tw>"
    ARG DEBIAN_FRONTEND=noninteractive
    
    COPY nvidia/cuda-repo-ubuntu1804-10-1-local-10.1.105-418.39_1.0-1_amd64.deb \
         nvidia/libcudnn7_7.6.3.30-1+cuda10.1_amd64.deb \
         nvidia/libcudnn7-dev_7.6.3.30-1+cuda10.1_amd64.deb \
         /root/
    
    # Install required libs, CUDA, and cuDNN (both runtime and dev). The version of Python is v3.7.
    # Also install the Python packages required for building PyTorch. However, MKL is not installed since
    # we'll install Intel Parallel Studio XE.
    RUN apt-get update && \
        apt-get install -y build-essential ca-certificates cpio curl git \
                           libasound2 libgtk2.0 libgtk-3-0 libnss3 libpango-1.0-0 libssl-dev libxss1 \
                           linux-headers-4.15.0-88-generic \
                           python3.7 python3.7-dev python3-pip vim wget && \
        dpkg -i /root/cuda-repo-ubuntu1804-10-1-local-10.1.105-418.39_1.0-1_amd64.deb && \
        apt-key add /var/cuda-repo-10-1-local-10.1.105-418.39/7fa2af80.pub && \
        apt-get update && \
        apt-get install -y cuda && \
        dpkg -i /root/libcudnn7_7.6.3.30-1+cuda10.1_amd64.deb && \
        dpkg -i /root/libcudnn7-dev_7.6.3.30-1+cuda10.1_amd64.deb && \
        apt-get clean && \
        rm -rf /var/lib/apt/lists/* && \
        rm /root/*.deb && \
        rm /usr/bin/python && \
        rm /usr/bin/python3 && \
        ln -s /usr/bin/python3.7 /usr/bin/python3 && \
        ln -s /usr/bin/python3 /usr/bin/python && \
        python3 -m pip install --upgrade pip && \
        pip install --no-cache numpy pyyaml setuptools cmake cffi typing
    
    # Clone the PyTorch source code. You can override it with the -v option.
    RUN git clone --recursive https://github.com/pytorch/pytorch && \
        cd pytorch && \
        git reset --hard v1.4.0 && \
        git submodule sync && \
        git submodule update --init --recursive
    
    # Install Intel Parallel Studio XE (2020 first release) without GUI. We'll set up environment variables to access ICC.
    ADD icc/parallel_studio_xe_2020_cluster_edition.tgz root/
    COPY icc/silent.cfg /root/parallel_studio_xe_2020_cluster_edition/silent.cfg
    RUN cd /root/parallel_studio_xe_2020_cluster_edition/ && \
        ./install.sh --silent=silent.cfg && \
        cd /root && \
        rm -r parallel_studio_xe_2020_cluster_edition
    ENV LD_LIBRARY_PATH=/opt/intel/compilers_and_libraries_2020.0.166/linux/compiler/lib/intel64_lin:/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/intel64/libfabric/lib:/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/intel64/lib/release:/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/intel64/lib:/opt/intel/compilers_and_libraries_2020.0.166/linux/ipp/lib/intel64:/opt/intel/compilers_and_libraries_2020.0.166/linux/compiler/lib/intel64_lin:/opt/intel/compilers_and_libraries_2020.0.166/linux/mkl/lib/intel64_lin:/opt/intel/compilers_and_libraries_2020.0.166/linux/tbb/lib/intel64/gcc4.8:/opt/intel/compilers_and_libraries_2020.0.166/linux/tbb/lib/intel64/gcc4.8:/opt/intel/debugger_2020/python/intel64/lib:/opt/intel/debugger_2020/libipt/intel64/lib:/opt/intel/compilers_and_libraries_2020.0.166/linux/daal/lib/intel64_lin:/opt/intel/compilers_and_libraries_2020.0.166/linux/daal/../tbb/lib/intel64_lin/gcc4.4:/opt/intel/compilers_and_libraries_2020.0.166/linux/daal/../tbb/lib/intel64_lin/gcc4.8 \
        IPPROOT=/opt/intel/compilers_and_libraries_2020.0.166/linux/ipp \
        FI_PROVIDER_PATH=/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/intel64/libfabric/lib/prov \
        INTEL_PYTHONHOME=/opt/intel/debugger_2020/python/intel64/ \
        CLASSPATH=/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/intel64/lib/mpi.jar:/opt/intel/compilers_and_libraries_2020.0.166/linux/daal/lib/daal.jar \
        CPATH=/opt/intel/compilers_and_libraries_2020.0.166/linux/ipp/include:/opt/intel/compilers_and_libraries_2020.0.166/linux/mkl/include:/opt/intel/compilers_and_libraries_2020.0.166/linux/pstl/include:/opt/intel/compilers_and_libraries_2020.0.166/linux/pstl/stdlib:/opt/intel/compilers_and_libraries_2020.0.166/linux/tbb/include:/opt/intel/compilers_and_libraries_2020.0.166/linux/tbb/include:/opt/intel/compilers_and_libraries_2020.0.166/linux/daal/include \
        NLSPATH=/opt/intel/compilers_and_libraries_2020.0.166/linux/compiler/lib/intel64/locale/%l_%t/%N:/opt/intel/compilers_and_libraries_2020.0.166/linux/mkl/lib/intel64_lin/locale/%l_%t/%N:/opt/intel/debugger_2020/gdb/intel64/share/locale/%l_%t/%N \
        LIBRARY_PATH=/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/intel64/libfabric/lib:/opt/intel/compilers_and_libraries_2020.0.166/linux/ipp/lib/intel64:/opt/intel/compilers_and_libraries_2020.0.166/linux/compiler/lib/intel64_lin:/opt/intel/compilers_and_libraries_2020.0.166/linux/mkl/lib/intel64_lin:/opt/intel/compilers_and_libraries_2020.0.166/linux/tbb/lib/intel64/gcc4.8:/opt/intel/compilers_and_libraries_2020.0.166/linux/tbb/lib/intel64/gcc4.8:/opt/intel/compilers_and_libraries_2020.0.166/linux/daal/lib/intel64_lin:/opt/intel/compilers_and_libraries_2020.0.166/linux/daal/../tbb/lib/intel64_lin/gcc4.4:/opt/intel/compilers_and_libraries_2020.0.166/linux/daal/../tbb/lib/intel64_lin/gcc4.8 \
        DAALROOT=/opt/intel/compilers_and_libraries_2020.0.166/linux/daal \
        INTEL_LICENSE_FILE=/opt/intel/compilers_and_libraries_2020.0.166/linux/licenses:/opt/intel/licenses:/root/intel/licenses \
        MANPATH=/opt/intel/man/common:/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/man:/opt/intel/documentation_2020/en/debugger/gdb-ia/man/:/usr/local/man:/usr/local/share/man:/usr/share/man \
        MKLROOT=/opt/intel/compilers_and_libraries_2020.0.166/linux/mkl \
        PSTLROOT=/opt/intel/compilers_and_libraries_2020.0.166/linux/pstl \
        PATH=/opt/intel/compilers_and_libraries_2020.0.166/linux/bin/intel64:/opt/intel/compilers_and_libraries_2020.0.166/linux/bin:/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/intel64/libfabric/bin:/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi/intel64/bin:/opt/intel/debugger_2020/gdb/intel64/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin \
        TBBROOT=/opt/intel/compilers_and_libraries_2020.0.166/linux/tbb \
        PKG_CONFIG_PATH=/opt/intel/compilers_and_libraries_2020.0.166/linux/mkl/bin/pkgconfig \
        INFOPATH=/opt/intel/documentation_2020/en/debugger/gdb-ia/info/ \
        I_MPI_ROOT=/opt/intel/compilers_and_libraries_2020.0.166/linux/mpi
    WORKDIR /pytorch
  2. Run the container to build the wheel package.
    docker run --rm -it -v $(pwd)/src/1.5.0:/pytorch \
                        -v $(pwd)/dist:/pytorch/dist \
                        -e CC=icc \
                        -e CXX=icpc \
                        -e USE_CUDA=0 \
                        -e USE_CUDNN=0 \
                        -e BUILD_CAFFE2_OPS=0 \
                        pytorch-dev:icc \
                        python3 setup.py bdist_wheel &> dist/build_log.txt
    However, it failed, and the message is always the same (see the following block), whether or not I build with CUDA. When I use gcc, everything is fine.
    .
    .
    .
    [ 12%] Building CXX object third_party/ideep/mkl-dnn/src/CMakeFiles/mkldnn.dir/cpu/simple_concat.cpp.o
    [ 12%] Building CXX object third_party/ideep/mkl-dnn/src/CMakeFiles/mkldnn.dir/cpu/simple_sum.cpp.o
    [ 12%] Linking CXX static library ../../../../lib/libmkldnn.a
    [ 12%] Built target mkldnn
    [ 12%] Built target fbgemm_avx2
    Makefile:140: recipe for target 'all' failed
    make: *** [all] Error 2
    Traceback (most recent call last):
      File "setup.py", line 745, in <module>
        build_deps()
      File "setup.py", line 316, in build_deps
        cmake=cmake)
      File "/pytorch/tools/build_pytorch_libs.py", line 62, in build_caffe2
        cmake.build(my_env)
      File "/pytorch/tools/setup_helpers/cmake.py", line 339, in build
        self.run(build_args, my_env)
      File "/pytorch/tools/setup_helpers/cmake.py", line 141, in run
        check_call(command, cwd=self.build_dir, env=env)
      File "/usr/lib/python3.7/subprocess.py", line 363, in check_call
        raise CalledProcessError(retcode, cmd)
    subprocess.CalledProcessError: Command '['cmake', '--build', '.', '--target', 'install', '--config', 'Release', '--', '-j', '12']' returned non-zero exit status 2.
    

Is it necessary to modify the source code (e.g. some #pragma) when using ICC?

@Coderx7

Coderx7 commented Jun 7, 2020

Thanks a lot, very good insights. By the way, did you profile with the latest master, now that DNNL 1.2 is the default?

@owoshch

owoshch commented Mar 2, 2021

@nexgus have you succeeded on building pytorch with those specs from source?

@mingfeima
Author

@nexgus have you succeeded on building pytorch with those specs from source?

Sorry, some recipes in this doc are out of date. I will update it with the latest optimization techniques: channels last, bf16, etc.

icc is not expected to work for now. Also, the performance benefit from icc is actually very limited, as PyTorch prefers explicit vectorization (not via the compiler).

I have been working on channels last (NHWC) path optimization lately and will include that part as well. The good thing about NHWC is that the overall performance is much higher than the default NCHW. It is also a plain memory format, so it supports all the generic torch tensor shape-manipulating operators (e.g. cat, chunk, etc.), and you don't have to inject to_mkldnn()/to_dense() into your model. So the usage would be much simpler...

@owoshch

owoshch commented Mar 3, 2021

@mingfeima
Thank you for a very detailed reply! When do you plan to release a new guide?
