PyTorch Channels Last Memory Format Performance Optimization on CPU Path

("mkldnn" has been renamed to "oneDNN", but exsiting PyTorch APIs still use "mkldnn", future work will align PyTorch user level APIs to "oneDNN")

Table of Contents

  • PyTorch Channels Last memory format introduction
  • oneDNN API for NHWC layout
  • Generic Channels Last memory format optimization with ATen native
  • oneDNN NHWC integration

NB: Memory format refers to the data representation that describes how a multidimensional (nD) array is stored in linear (1D) memory address space. Memory format has the same semantics as layout in oneDNN. Layout in PyTorch has a different semantic: it describes dense versus sparse with the attributes 'torch.strided' and 'torch.sparse_coo'.

What is Channels Last

On CNN models, the canonical order of tensor dimensions is assigned with semantic meaning. For example, the input tensor of a 2D convolution is NCHW by default on PyTorch - <batch_size, channels, height, width>. NHWC is an alternative way of describing the tensor dimensions - <batch_size, height, width, channels>.

Take a look at the following image illustrating NCHW and NHWC when N=1. Note that when N=1, NHWC has the same memory layout as a BMP image file.

fig-1-memory-layout

PyTorch refers to NCHW as torch.contiguous_format, which is the default memory format, and to NHWC as torch.channels_last, which is a new feature as of the 1.5 release.

TF takes NHWC as the default memory format, and from the performance point of view NHWC has an advantage over NCHW. On the CPU platform, we propose to optimize the Channels Last memory path for the following reasons:

  • Performance - NHWC performance is not as good as the blocked memory format (nChw16c), but it is close, and much better than NCHW.
  • User Experience - Operator coverage of NHWC would be higher than that of the blocked memory format (the to_mkldnn() method), so user experience is better. To be specific, it is very difficult to enable an operator that manipulates dims on the blocked format, such as sum(dim=?), so you need to convert the tensor from blocked memory format back to NHWC with to_dense() before feeding it into sum(); the Channels Last memory format supports this naturally (see the sketch after this list).
  • Upstream - Will be easier since the CPU path doesn't hold any secret ingredient, and both inference and training will be covered.
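A minimal sketch of that user-experience point, assuming a PyTorch build with MKL-DNN enabled (the exact failure mode of sum() on a blocked tensor may vary by version):

import torch

x = torch.randn(1, 10, 32, 32)

### Channels Last: sum(dim=1) works out of the box, the format is just a stride trick.
y = x.to(memory_format=torch.channels_last)
print(y.sum(dim=1).shape)                 # torch.Size([1, 32, 32])

### Blocked (opaque) format: dim-manipulating ops are generally not implemented,
### so the tensor has to be converted back with to_dense() first.
z = x.to_mkldnn()
try:
    z.sum(dim=1)
except Exception as e:
    print("sum on blocked tensor failed:", type(e).__name__)
print(z.to_dense().sum(dim=1).shape)      # torch.Size([1, 32, 32])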

Memory Format Is All That Matters

On CNN models, memory format is almost the foundation of any upper-level design. One important fact is that converting memory format can be very expensive, so in case multiple CNN operators are performed in a row, e.g. Conv2d -> ReLU -> Conv2d, it's beneficial to transform to the different memory format once, do the computation, and reorder back.

On PyTorch, you can use 3 types of memory formats on CNN models:

a. NCHW (default)

### NB: internally the blocked format will still be used,
###   aka. we do 'reorder' for 'input', 'weight' and 'output',
###   and believe me this is expensive, roughly 50% perf loss...
input = torch.randn(1, 10, 32, 32)
model = torch.nn.Conv2d(10, 20, 1, 1)
output = model(input)

b. NHWC (WIP for CPU)

input = torch.randn(1, 10, 32, 32)
model = torch.nn.Conv2d(10, 20, 1, 1)
### NB: convert to Channels Last memory format.
###   oneDNN supports NHWC for feature maps (input, output),
###   but weight still needs to be in blocked format.
###   Still we can save the reorders for the feature maps.
input = input.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
output = model(input)

c. Blocked (nChw16c)

from torch.utils import mkldnn as mkldnn_utils
input = torch.randn(1, 10, 32, 32)
model = torch.nn.Conv2d(10, 20, 1, 1)
### NB: convert to blocked memory format.
###   Note that 'output' is in blocked memory format;
###   in case the subsequent operator doesn't support blocked memory format,
###   you need to manually reorder it back to NCHW with output.to_dense().
### mkldnn_utils.to_mkldnn(model) is used to prepack the weight, which saves weight reorder time
###   for inference. For training, it is not needed.
input = input.to_mkldnn()
model = mkldnn_utils.to_mkldnn(model)
output = model(input)

It is better to explain the concepts here with a diagram; the dotted lines indicate a simple memory view, no hard copy.

fig-2(1)-pt-conv-layout-path-dispatch

The conclusion is that the NHWC path saves the reorders of the feature maps compared with the NCHW path, but a weight reorder is still necessary since oneDNN requires the weight to be in blocked memory format. From the performance perspective, when batch_size=N, the weight reorder cost is minimal compared with the feature map reorder. But when batch_size=1, the weight reorder is usually not negligible. So whether to enable weight prepacking on the Channels Last memory format needs further discussion.

PyTorch Strided Layout

Before moving on, I feel it necessary to explain how PyTorch organizes tensors in memory - the layout. Here we only focus on dense tensors and skip the 'coo' layout of sparse tensors.

The question itself can be reinterpreted as: for a tensor of size <N, C, H, W>, how does PyTorch access the element with index <n, c, h, w> from memory? The answer is stride:

tensor: <N, C, H, W>
index: <n, c, h, w>
strides: <CHW, HW, W, 1>
offset(n,c,h,w) = stride_n * n + stride_c * c + stride_h * h + stride_w * w
                = CHW * n + HW * c + W * h + 1 * w
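
A quick sanity check of the formula above (a minimal sketch; any recent PyTorch will do):

import torch

N, C, H, W = 2, 3, 4, 5
x = torch.empty(N, C, H, W)     # default torch.contiguous_format, i.e. NCHW
print(x.stride())               # (60, 20, 5, 1) == (C*H*W, H*W, W, 1)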

One merit of introducing strides is the ability to express noncontiguous tensors, e.g. a slice of a big tensor. For example, the 'Xs' in the following image have a stride of <n1+n2, 1>.

fig-3-pytorch-strided-layout

Keep in mind that a PyTorch Tensor does not have a dedicated attribute such as 'memory_format'. The memory format expression completely relies on size and stride; the design principle can be found in the reference: RFC: Memory format (aka layout aka NHWC) support. So no matter what the tensor's memory format is, we need a logical canonical order for the dimensions - that is NCHW on PyTorch. Thus size and stride are ALWAYS described in the order of NCHW. OK, let's take a look at the Channels Last case of the previous question:

tensor: <N, C, H, W>
index: <n, c, h, w>
strides: <HWC, 1, WC, C>
offset(n,c,h,w) = stride_n * n + stride_c * c + stride_h * h + stride_w * w
                = HWC * n + 1 * c + WC * h + C * w

Actually, this pattern applies to ALL other memory formats as long as it is 4-dim, e.g. strides for CHWN would be <1, HWN, WN, N>.
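
Again a quick sketch to verify (CHWN is emulated here with permute() purely for illustration, it is not an official memory format API):

import torch

N, C, H, W = 2, 3, 4, 5

### Channels Last: strides are <HWC, 1, WC, C>
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride())               # (60, 1, 15, 3)

### CHWN: allocate in CHWN order, then permute back to the logical NCHW order;
### strides are <1, HWN, WN, N>
y = torch.empty(C, H, W, N).permute(3, 0, 1, 2)
print(y.stride())               # (1, 40, 10, 2)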

PyTorch Channels Last Memory Format APIs

a. tensor creation

x = torch.empty(N, C, H, W, memory_format=torch.channels_last)

b. tensor conversion

### .contiguous() transforms an NHWC noncontiguous tensor to NHWC contiguous.
### .to() converts an NCHW tensor to NHWC; it is out-of-place.
x = x.contiguous(memory_format=torch.channels_last)
x = x.to(memory_format=torch.channels_last)

### contiguous check
x.is_contiguous(memory_format=torch.channels_last)

c. model conversion

### NB: tensor.to() is an out-of-place operation.
###   model.to() is inplace; it calls _apply() which is inplace.
model = model.to(memory_format=torch.channels_last)
input = input.to(memory_format=torch.channels_last)

d. operator coverage

Detailed operator coverage information is listed in the reference Operators-with-Channels-Last-support. In brief, ImageNet training topologies on GPU already have full support of the Channels Last memory format, while CPU doesn't yet.

Some spontaneous questions:

  • How to tell whether this model or operator supports Channels Last? - This requires a manual memory format check, aka. 'torch.channels_last' input and weight shall NOT generate 'torch.contiguous_format' output (see the sketch below).
  • What if the model contains an operator that does not support Channels Last? - No error message will be shown; the NHWC tensor will be handled by the operator as a non-contiguous NCHW tensor, so the result might not be correct depending on the algorithm of this operator.
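A minimal sketch of that manual check; preserves_channels_last is a hypothetical helper and torch.nn.Conv2d is used only as an example operator (the printed result depends on whether your build registers a Channels Last kernel for it on CPU):

import torch

def preserves_channels_last(module, input_size=(1, 3, 224, 224)):
    ### hypothetical helper: push a Channels Last input through a Channels Last
    ### module and report whether the output stays in Channels Last
    x = torch.randn(*input_size).to(memory_format=torch.channels_last)
    m = module.to(memory_format=torch.channels_last)
    with torch.no_grad():
        y = m(x)
    return y.is_contiguous(memory_format=torch.channels_last)

print(preserves_channels_last(torch.nn.Conv2d(3, 8, 3, padding=1)))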

Writing Channels Last Kernels

a. Status on CPU

  • No support - Requires registering a Channels Last kernel for the CPU path, e.g. Conv2d;
  • Explicit support - Already has a Channels Last kernel for the CPU path (in ATen native manner), need to compare against the oneDNN counterpart's performance, e.g. BatchNorm;
  • Implicit support - Supported via meta structures like 'TensorIterator', need to compare against the oneDNN counterpart's performance, e.g. ReLU (see the sketch below).
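For instance, a quick sketch of the implicit case (behavior as observed on recent PyTorch releases): a TensorIterator-based op such as ReLU propagates the input memory format without a dedicated Channels Last kernel:

import torch

x = torch.randn(1, 10, 32, 32).to(memory_format=torch.channels_last)
y = torch.relu(x)
print(y.is_contiguous(memory_format=torch.channels_last))   # True: format is propagated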

b. Register Channels Last Kernel in ATen Native Manner

The general guideline is listed under the reference Writing-memory-format-aware-operators, so it is not repeated here. You may take one of my recent PRs, optimize upsample performance linear mode on CPU, as an example, which also demonstrates the NHWC performance advantage over NCHW thanks to the ease of vectorization.

c. Register oneDNN Kernel on Channels Last

The essence of registering an oneDNN kernel under the Channels Last memory format on CPU is no different from cuDNN: only a few upper-level changes are needed, such as accommodating 'contiguous()' to 'contiguous(suggested_memory_format)'. The automatic reorder of the oneDNN weight shall be hidden in ideep.
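
A Python-level sketch of that upper-level change; the real change lives in ATen C++, and suggest_memory_format below is a hypothetical Python helper mirroring ATen's Tensor::suggest_memory_format():

import torch

def suggest_memory_format(t):
    ### hypothetical helper: keep Channels Last if the input already is,
    ### otherwise fall back to the default contiguous format
    if t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last
    return torch.contiguous_format

x = torch.randn(1, 10, 32, 32).to(memory_format=torch.channels_last)
### instead of x.contiguous(), the kernel entry point does:
x = x.contiguous(memory_format=suggest_memory_format(x))
print(x.is_contiguous(memory_format=torch.channels_last))   # True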

oneDNN NHWC APIs

Compared to the NCHW interfaces, 2 parts need to be addressed on the NHWC interfaces:

a. Create NHWC Memory

The logical size and stride description of oneDNN is always in NCHW; this is identical to PyTorch. Example code:

/* create md from memory::format_tag */
auto src_md = memory::desc(
        {N, C, H, W}, // logical dims, the order is defined by a primitive
        memory::data_type::f32, // tensor's data type
        memory::format_tag::nhwc // memory format, NHWC in this case
);

/* alternative: create md from strides */
auto src_md = memory::desc(
        {N, C, H, W}, // logical dims, the order is defined by a primitive
        memory::data_type::f32, // tensor's data type
        {stride_N, stride_C, stride_H, stride_W} // the strides
);

/* create memory */
auto src_mem = memory(src_md, src_data_ptr, engine);

b. Create Convolution Primitive

  • NCHW - create memory::desc with format_tag::any for 'input', 'output' and 'weight'; query the proposed memory::desc from the convolution primitive;
  • NHWC - create memory::desc with format_tag::nhwc for 'input' and 'output', use format_tag::any for 'weight'; if we use hwio for 'weight', the convolution primitive will be created with gemm rather than jit avx512.

CPU Channels Last Targets

  • User Experience - No special user level code change, only 'input' and 'model' conversion is required;
  • Scenarios - cover both training and inference;
  • Models - ResNet50 and ResNext101, extended targets: torchvision models, detectron2;
  • Performance Targets - training >0.8x blocked; inference throughput > 0.8x blocked; inference latency? (need further discussion)
  • Operator Coverage - No less than the GPU path;
  • BFloat16 - This part shall align with big picture of BFloat16 integration (need further discussion);
  • int8 - Need further discussion.

TODO List

  • oneDNN - upgrade to 1.5 or higher;
  • ideep - interface change: ideep::tensor, ideep::computation;
  • ATen integration - ConvNd shall be PRed directly; BatchNorm, Pooling, etc. need performance comparison against the native ATen Channels Last kernels; PR inference and training at the same time, one operator at a time; training first; inference weight prepacking is under discussion;
  • validation - oneDNN kernel level performance comparison between the NCHW and NHWC kernels; oneDNN NHWC kernel performance comparison against the native ATen Channels Last kernels; TTT measurement?
  • distributed - gloo backend or ccl backend? or do we compare with only 1 socket (1S) on CPU?

Upstreaming

References


pinzhenx commented Jun 23, 2020

Play around with this script that mimics framework integration:

#include <functional>
#include <chrono>
#include <iostream>
#include "dnnl.hpp"

using namespace dnnl;
using tag = memory::format_tag;
using dt = memory::data_type;
using dims = memory::dims;
engine eng(engine::kind::cpu, 0);
stream strm(eng);

memory reorder_if_necessary(memory& s, const memory::desc& desc) {
  if (s.get_desc() == desc)
    return s;
  memory d {desc, eng};
  reorder(s, d).execute(strm, s, d);
  return d;
}

void benchmark(std::function<void()> cb, size_t iters=1e4, size_t warmup=10) {
  using namespace std::chrono;
  for (size_t i = 0; i < warmup; i++) cb();

  auto start = high_resolution_clock::now();
  for (size_t i = 0; i < iters; i++) cb();
  auto elapsed = high_resolution_clock::now() - start;

  auto avg = 1.0 * duration_cast<milliseconds>(elapsed).count() / iters;
  std::cout << "avg: " << avg << " ms\n";
}

void test_conv(tag src_usr = tag::nhwc, tag src_query = tag::nhwc,
               tag wei_usr = tag::ohwi, tag wei_query = tag::any,
               tag dst_usr = tag::nhwc, tag dst_query = tag::nhwc) {
  dims x_sizes = {1, 64, 28, 28};
  dims w_sizes = {128, 64, 3, 3};
  dims y_sizes = {1, 128, 28, 28};

  memory x_usr({x_sizes, dt::f32, src_usr}, eng);
  memory w_usr({w_sizes, dt::f32, wei_usr}, eng);

  benchmark([&]() {
    auto pd = convolution_forward::primitive_desc({
        prop_kind::forward_training,
        algorithm::convolution_direct,
        {x_sizes, dt::f32, src_query},
        {w_sizes, dt::f32, wei_query},
        {y_sizes, dt::f32, dst_query},
        {1, 1},
        {1, 1},
        {1, 1}
    }, eng);

    auto x_opt = reorder_if_necessary(x_usr, pd.src_desc());
    auto w_opt = reorder_if_necessary(w_usr, pd.weights_desc());
    auto y_opt = memory(pd.dst_desc(), eng);

    convolution_forward(pd).execute(strm, {
      {DNNL_ARG_SRC, x_opt}, {DNNL_ARG_WEIGHTS, w_opt}, {DNNL_ARG_DST, y_opt},
    });

    auto y_usr = reorder_if_necessary(y_opt, {y_sizes, dt::f32, dst_usr});
  });
}

int main() {
  std::cout << "src: nchw (fixed) | wei: oihw (fixed) | dst: nchw\n";
  test_conv(tag::nchw, tag::nchw, tag::oihw, tag::oihw, tag::nchw, tag::nchw);

  std::cout << "src: any (nchw -> blocked) | wei: any (oihw -> blocked) | dst: any\n";
  test_conv(tag::nchw, tag::any, tag::oihw, tag::any, tag::nchw, tag::any);

  std::cout << "src: nhwc | wei: hwio (fixed) | dst: nhwc\n";
  test_conv(tag::nhwc, tag::nhwc, tag::hwio, tag::hwio, tag::nhwc, tag::nhwc);

  std::cout << "src: nhwc | wei: any (hwio -> blocked) | dst: nhwc\n";
  test_conv(tag::nhwc, tag::nhwc, tag::hwio, tag::any, tag::nhwc, tag::nhwc);

  std::cout << "src: nhwc | wei: any (ohwi -> blocked) | dst: nhwc\n";
  test_conv(tag::nhwc, tag::nhwc, tag::ohwi, tag::any, tag::nhwc, tag::nhwc);

  return 0;
}

Obtained on Xeon Platinum 8180 with a single thread:

src: nchw (fixed) | wei: oihw (fixed) | dst: nchw
avg: 1.0008 ms
src: any (nchw -> blocked) | wei: any (oihw -> blocked) | dst: any
avg: 0.7779 ms
src: nhwc | wei: hwio (fixed) | dst: nhwc
avg: 1.0359 ms
src: nhwc | wei: any (hwio -> blocked) | dst: nhwc
avg: 0.6743 ms
src: nhwc | wei: any (ohwi -> blocked) | dst: nhwc
avg: 0.6399 ms

Takeaway: for this specific shape, NHWC performs better than NCHW in the naive path, i.e. plain in, plain out.
