BKM for PyTorch CPU Performance

General guidelines for CPU performance on PyTorch

This file serves as a BKM (Best Known Methods) to get better performance on CPU for PyTorch, mostly focusing on inference or deployment. A Chinese version is available here.

1. Use channels last memory format

Right now, on the PyTorch CPU path, you may choose to use 3 types of memory formats.

  • torch.contiguous_format: default memory format, also referred to as NCHW.
  • torch.channels_last: also referred to as NHWC.
  • torch._mkldnn: mkldnn blocked format.

The default NCHW has worse performance compared with NHWC and the MKLDNN blocked memory format.

### 1. default (NCHW)
output = model(input)

### 2. channels last
input = input.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
output = model(input)

### Note: Most CV models have convolution as the 1st layer, and channels last has higher priority in Conv2d.
###   So you can also convert only the weights to channels last; the input will be converted accordingly,
###   and the channels last memory format will be propagated through the model (until an operator without
###   channels last support is hit, if any).

### 3a. mkldnn blocked format (inference)
input = input.to_mkldnn()
model = torch.utils.mkldnn.to_mkldnn(model)
output = model(input)

### 3b. mkldnn blocked format (training)
### Note: for training, convert only the input; torch.utils.mkldnn.to_mkldnn() is intended
###   for inference, and the weights stay in the default format so the optimizer can update them.
input = input.to_mkldnn()
output = model(input)

In case the model has operators which don't support the channels last memory format, you might not be able to get optimal performance: an NHWC tensor will be treated as a non-contiguous NCHW tensor, and the rest of the model will propagate NCHW.
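To verify that channels last actually propagated end to end, you can check the memory format of the output. A minimal sketch, assuming a torchvision resnet50 (the model and input shape are illustrative):

import torch
import torchvision

model = torchvision.models.resnet50().eval()
model = model.to(memory_format=torch.channels_last)
input = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)
with torch.no_grad():
    output = model(input)
# True only if every operator on the path preserved channels last
print(output.is_contiguous(memory_format=torch.channels_last))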

In case the model has operators which don't support the mkldnn blocked memory format, you need to insert to_dense() and to_mkldnn() around them:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(10, 10, 3)
        # nn.X() is a placeholder for an operator without mkldnn support
        self.unsupported_mod = nn.X()
        self.linear1 = nn.Linear(10, 20)

    def forward(self, x):
        x = self.conv1(x)
        # fall back to the default (dense) layout for the module without mkldnn support
        x = x.to_dense()
        x = self.unsupported_mod(x)
        # switch back to the mkldnn blocked format for the rest of the model
        x = x.to_mkldnn()
        x = self.linear1(x)
        return x

Further explanation of the channels last memory format optimization can be found in PyTorch Channels Last Memory Format Performance Optimization on CPU Path.

Results on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, single socket with 20 cores available here.

NHWC performance is collected with torch-opt-test and torchvision-opt-test. Upstreaming to the public repos is ongoing.

### NCHW run
Running on torch: 1.8.1+cpu
Running on torchvision: 0.9.1+cpu
ModelType: resnet50, Kernels: nn Input shape: 1x3x224x224
nn                              :forward:      55.89 (ms)      17.89 (imgs/s)
nn                             :backward:       0.00 (ms)
nn                               :update:       0.00 (ms)
nn                                :total:      55.89 (ms)      17.89 (imgs/s)

### NHWC run
Running on torch: 1.9.0a0+git850a6bd
Running on torchvision: 0.10.0a0+4f34ae5
ModelType: resnet50, Kernels: nn Input shape: 1x3x224x224
nn                              :forward:      14.02 (ms)      71.31 (imgs/s)
nn                             :backward:       0.00 (ms)
nn                               :update:       0.00 (ms)
nn                                :total:      14.02 (ms)      71.31 (imgs/s)

2. TorchVision with channels last support

If your model needs csrc modules from torchvision, e.g. ROIAlign, torchvision also needs to have channels last support. For example, MaskedRCNN-related info has been placed here.
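As an illustration, the following sketch feeds a channels last feature map through torchvision's roi_align (the shapes and box coordinates are made up); once the op supports channels last, it can consume such input directly:

import torch
from torchvision.ops import roi_align

# illustrative feature map; one box per row in (batch_idx, x1, y1, x2, y2) format
feats = torch.randn(1, 256, 56, 56).to(memory_format=torch.channels_last)
boxes = torch.tensor([[0.0, 0.0, 0.0, 10.0, 10.0]])
out = roi_align(feats, boxes, output_size=(7, 7))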

Results on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, single socket with 20 cores:

### with config "fast_rcnn_R_50_FPN_1x.yaml"
### NCHW (torch-1.8.1/vision-0.9.1): 300 iters in 326.0195782049559 seconds.
### NCHW (torch-opt/vision-0.9.1): 300 iters in 185.4384527085349 seconds.
### NCHW (torch-opt/vision-opt): 300 iters in 80.56146793198423 seconds.
### NHWC (torch-opt/vision-opt): 300 iters in 55.49435344198719 seconds.

Upstreaming to the public pytorch repo is ongoing. Further optimization is also WIP.

3. Set correct environment variables

For a single instance run, regulate the omp thread count and core binding as:

export OMP_NUM_THREADS=[number_of_physical_cores]
export KMP_AFFINITY=granularity=fine,compact,1,0
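
To confirm the settings took effect, a quick check from Python (run after exporting the variables above):

import torch
# should match OMP_NUM_THREADS
print(torch.get_num_threads())
# prints the OpenMP / MKL threading configuration torch was built with
print(torch.__config__.parallel_info())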

For a single socket run, avoid remote memory access by binding to one socket with numactl:

# e.g. say each socket has 20 cores, to use the 1st socket:
numactl --physcpubind=0-19 --membind=0

For a multi instance run, in case each instance spawns its own omp thread pool, regulate OMP_NUM_THREADS per instance. Make sure omp_threads * num_instances does not exceed the number of physical cores, so as to prevent over subscription.

The multi instance case is much more complicated than single instance, since there are a number of upper-level threading models to choose from: torch.multiprocessing, std::thread, TBB, etc. Be careful with over subscription; it will result in a dramatic performance drop on CPU. The easiest way to identify such an issue on Intel CPUs is VTune.
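
A minimal multi-instance sketch with torch.multiprocessing, assuming a Linux box with one 20-core socket split into 4 instances (the model, input, and core counts are made up):

import os
import torch
import torch.multiprocessing as mp

CORES_PER_INSTANCE = 5  # 20 physical cores / 4 instances

def worker(rank):
    # bind each instance to a disjoint core range to prevent over subscription
    cores = range(rank * CORES_PER_INSTANCE, (rank + 1) * CORES_PER_INSTANCE)
    os.sched_setaffinity(0, cores)
    torch.set_num_threads(CORES_PER_INSTANCE)
    model = torch.nn.Linear(1024, 1024).eval()
    with torch.no_grad():
        model(torch.randn(32, 1024))

if __name__ == '__main__':
    mp.spawn(worker, nprocs=4)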

4. Use jemalloc to reduce memory allocation overhead

PyTorch uses a dynamic graph, which has the drawback that the output of each operator must be allocated on every execution; this increases the memory allocation burden and triggers clear_page for large buffers. The issue can be alleviated with jemalloc or tcmalloc to some extent.

### jemalloc
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export LD_PRELOAD=/home/mingfeim/packages/jemalloc-5.2.1/lib/libjemalloc.so

### tcmalloc
export LD_PRELOAD=/home/mingfeim/packages/gperftools-2.8/install/lib/libtcmalloc.so

If you see clear_page from vmlinux consuming a lot of time in VTune, it is time to apply jemalloc.

5. Use single process DataLoader

torch.utils.data.DataLoader may be slower with num_workers > 0; try comparing against num_workers = 0.
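
A quick comparison sketch (the dataset and batch size are made up; time both settings on your own workload):

import time
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1000, 3, 224, 224))
for num_workers in (0, 4):
    loader = DataLoader(dataset, batch_size=32, num_workers=num_workers)
    start = time.time()
    for batch in loader:
        pass
    print(f'num_workers={num_workers}: {time.time() - start:.3f}s')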

@owoshch commented Mar 2, 2021

@nexgus have you succeeded in building pytorch with those specs from source?

@mingfeima (Author) commented

> @nexgus have you succeeded in building pytorch with those specs from source?

Sorry that some recipes in this doc are out of date; I will renew it with the latest optimization techniques: channels last, bf16, etc.

icc should not work for now; also, the performance benefit from icc is actually very limited, as pytorch prefers explicit vectorization (not via the compiler).

I have been working on channels last (nhwc) path optimization lately and will include this part as well. The good thing about nhwc is that the overall performance is much higher than the default nchw, and since it is a plain memory format it supports all generic torch tensor shape manipulating operators (e.g. cat, chunk, etc.), so you don't have to inject to_mkldnn()/to_dense() into your model. The usage is much simpler...

@owoshch commented Mar 3, 2021

@mingfeima Thank you for a very detailed reply! When do you plan to release the new guide?
