PyTorch CPU Performance Optimization Tutorial - Section I

Part I: Memory Formats and Channels Last Optimization

(Training material on PyTorch CPU performance optimization)

A Chinese version of this chapter is also available (link).

Notes:

  • The general idea is to share basic CPU performance optimization techniques.
  • The majority of the material is about PyTorch's 'native' kernels, a.k.a. 'ATen'.
  • Code snippets are shared in the material. Some of the code has not been upstreamed yet; in that case, the PR link is shared.
  • Kernel code written with the PyTorch parallelization wrapper (at::parallel_for) and vectorization wrapper can easily be transformed back into a plain version using OpenMP pragmas or intrinsics.
  • mkldnn is the old name of oneDNN.
  • CL refers to 'channels last' while CF refers to 'channels first'.

This section contains the following subjects:

  • memory format: logical and physical order
  • memory format propagation rules in convolution
  • MaxPool2d kernel on Channels First and Channels Last
  • Special Case I: Upsampling kernel on Channels First
  • Special Case II: AvgPool3d from VGGM

1. Memory Format: Physical and Logical Order

Let's start with a very important concept, Memory Format, which is fundamental to optimizing CV-related operators.

Memory format refers to the data representation that describes how a multidimensional (nD) array is stored in linear (1D) memory address space. It is sometimes also called data format or layout. Be aware that in PyTorch, layout has different semantics: torch.strided (dense tensors) or torch.sparse_coo (sparse tensors). Another type of layout is torch._mkldnn, which refers to mkldnn's blocked memory format. In this article, only dense tensors are involved.

Memory format has two aspects: i) how data is stored in memory; ii) how you 'view' that memory logically.

Physical Order is about how data is stored in memory. In the CV domain, NCHW and NHWC describe the physical memory layout order; they are also referred to as Channels First and Channels Last. Performance is the primary concern when choosing the physical order, since it determines the pattern in which individual elements are accessed.

Logical Order is the convention used to describe the tensor shape and stride. In PyTorch, this convention is NCHW. So no matter what the physical order is, shape and stride are always in the order N, C, H, W, and indexing always follows NCHW order as well. This logical order is not necessarily related to the physical order (i.e. the actual data storage).

1.1 Channels First and Channels Last

In the CV domain, channels first (NCHW) and channels last (NHWC) are the most commonly used memory formats. (Others exist too: I used to work on a framework named 'Neon', which used the CHWN memory format; it is performance friendly for training with batch sizes such as N = 64, 128, 256.)

Fig-1 shows how the element at index [1, 1, 2, 3] of a tensor 'A' with shape [2, 3, 4, 4] is accessed under channels first (a) and channels last (b):

[Fig-1: memory format, (a) channels first, (b) channels last]
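To make Fig-1 concrete, here is a small C++ sketch that computes the physical offset of the logical index [1, 1, 2, 3] of a [2, 3, 4, 4] tensor from its strides. The helper names (offset_of, cf_strides, cl_strides) are hypothetical, not PyTorch API. The index and the strides are both given in the logical NCHW order; only the stride values differ between the two memory formats.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Sizes, strides and indices are always stored in logical (N, C, H, W) order.
using Dims = std::array<int64_t, 4>;

// Physical offset of a logical index: dot product of index and strides.
int64_t offset_of(const Dims& index, const Dims& strides) {
  int64_t off = 0;
  for (int d = 0; d < 4; ++d) off += index[d] * strides[d];
  return off;
}

// Strides of a dense channels first (NCHW) tensor: (CHW, HW, W, 1).
Dims cf_strides(const Dims& s) {
  return {s[1] * s[2] * s[3], s[2] * s[3], s[3], 1};
}

// Strides of a dense channels last (NHWC) tensor: (HWC, 1, WC, C).
Dims cl_strides(const Dims& s) {
  return {s[2] * s[3] * s[1], 1, s[3] * s[1], s[1]};
}
```

With these sizes the element lands at offset 1*48 + 1*16 + 2*4 + 3 = 75 under channels first, and at offset 1*48 + 1*1 + 2*12 + 3*3 = 82 under channels last.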

Note that dimensions with identical semantics can sometimes be collapsed; e.g. BatchNorm2d can treat height and width as one single dimension, with the collapsed index ii = h * width + w as shown above.
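A minimal sketch of dimension collapse, assuming a BatchNorm-style per-channel reduction (the function names are hypothetical). On NCHW, H and W collapse into one contiguous dimension of size HW indexed by ii; on NHWC, N, H and W collapse instead, and the C values of each pixel are contiguous. Both variants compute the same per-channel sums.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Channels first: for each (n, c), the plane of HW values is contiguous
// and is walked with the collapsed index ii = h * W + w.
std::vector<float> channel_sum_cf(const std::vector<float>& x,
                                  int64_t N, int64_t C, int64_t HW) {
  std::vector<float> sum(C, 0.f);
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c) {
      const float* plane = x.data() + (n * C + c) * HW;
      for (int64_t ii = 0; ii < HW; ++ii) sum[c] += plane[ii];
    }
  return sum;
}

// Channels last: N, H and W collapse into one dimension of size N * HW,
// and each row holds C contiguous channel values.
std::vector<float> channel_sum_cl(const std::vector<float>& x,
                                  int64_t N, int64_t C, int64_t HW) {
  std::vector<float> sum(C, 0.f);
  for (int64_t i = 0; i < N * HW; ++i) {
    const float* row = x.data() + i * C;
    for (int64_t c = 0; c < C; ++c) sum[c] += row[c];
  }
  return sum;
}
```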

1.2 Data Indexing with Strides

As mentioned previously, the logical order in PyTorch is NCHW, which is the conventional order for shape, stride and index. The following shows the relationship between strides and the corresponding memory format:

/*
 * (n, c, h, w) is the index
 * (N, C, H, W) is size of dimension
 *
 * Channels First (NCHW) strides:  (CHW, HW, W, 1)
 * Channels Last  (NHWC) strides:  (HWC, 1, WC, C)
 *                (CHWN) strides:  (1, HWN, WN, N)
 */

/* value for index n,c,h,w under memory format of NCHW */
scalar_t v = X[n * C * H * W + c * H * W + h * W + w];

/* value for index n,c,h,w under memory format of NHWC */
scalar_t v = X[n * H * W * C + h * W * C + w * C + c];

/* value for index n,c,h,w under memory format of CHWN */
scalar_t v = X[c * H * W * N + h * W * N + w * N + n];

In fact, the PyTorch tensor implementation doesn't keep an attribute recording the memory format. It records the stride of each dimension, and the memory format is derived from the strides using the rules above. Keeping strides also enables more functionality, such as memory views and non-contiguous tensor storage.
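The sketch below illustrates recovering a memory format from strides with the rules above. The enum and infer_format are hypothetical, and the logic is simplified: when C == 1 or H == W == 1 the two stride patterns can coincide and this sketch just reports channels first, whereas PyTorch's own checks handle such degenerate sizes more carefully. Any other stride pattern, such as a view with H and W transposed, is reported as neither format.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using Dims = std::array<int64_t, 4>;  // logical N, C, H, W order

enum class MemoryFormat { ChannelsFirst, ChannelsLast, Other };

// Compare the actual strides against the dense NCHW and NHWC patterns.
MemoryFormat infer_format(const Dims& sizes, const Dims& strides) {
  const int64_t C = sizes[1], H = sizes[2], W = sizes[3];
  const Dims cf = {C * H * W, H * W, W, 1};  // (CHW, HW, W, 1)
  const Dims cl = {H * W * C, 1, W * C, C};  // (HWC, 1, WC, C)
  if (strides == cf) return MemoryFormat::ChannelsFirst;
  if (strides == cl) return MemoryFormat::ChannelsLast;
  return MemoryFormat::Other;
}
```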

2. Memory Format Propagation Rules in Convolution

On channels first, it is impossible to achieve highly optimized performance directly, because the dimension that needs to be vectorized (i.e. channels) is NOT the innermost dimension (except for Conv that can be treated as GEMM). So on channels first, both input and weight are reordered into a blocked format (e.g. nChw16c, OIhw16i16o). Once oneDNN finishes computing the convolution primitive, the output needs to be reordered back to NCHW as well. These reorders degrade overall performance.
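As an illustration of what such a reorder does, here is a plain scalar sketch of NCHW to nChw16c (the function name is hypothetical; oneDNN's real reorders are far more optimized). Channels are split into blocks of 16, and each block becomes the innermost dimension, so 16 channels of one pixel are contiguous and fill one 512-bit register of floats. Channels are zero-padded up to a multiple of 16.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Scalar NCHW -> nChw16c reorder. Destination layout is
// [N][ceil(C / 16)][H][W][16], with padded channels set to zero.
std::vector<float> reorder_nchw_to_nChw16c(const std::vector<float>& src,
                                           int64_t N, int64_t C,
                                           int64_t H, int64_t W) {
  constexpr int64_t kBlock = 16;
  const int64_t Cb = (C + kBlock - 1) / kBlock;  // number of channel blocks
  std::vector<float> dst(N * Cb * H * W * kBlock, 0.f);
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t h = 0; h < H; ++h)
        for (int64_t w = 0; w < W; ++w) {
          const int64_t cb = c / kBlock;  // which channel block
          const int64_t ci = c % kBlock;  // position inside the block
          const int64_t d = (((n * Cb + cb) * H + h) * W + w) * kBlock + ci;
          dst[d] = src[((n * C + c) * H + h) * W + w];
        }
  return dst;
}
```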

On channels last, the input and output tensors don't need to be reordered into a blocked format to achieve high performance, so they are passed to the oneDNN primitive as a 'memory view'. However, the weight still needs to be reordered into a blocked format. For inference, the weight can be prepacked (pre-reordered and cached) in the jitted model to reduce this overhead.

Fig-2 shows how memory format is propagated in Conv2d on the PyTorch CPU path.

[Fig-2: Conv2d memory format propagation]

Generally, channels last performs better than channels first on convolution, since it skips the reorders on activations.

In PyTorch, the default memory format is channels first (NCHW). If a particular operator doesn't have explicit channels last (NHWC) support, a channels last input is treated as a non-contiguous NCHW tensor and produces an NCHW output, so the memory format propagation chain is broken. It is therefore important to make sure ALL memory format aware operators have explicit channels last support; the corresponding work is tracked in PyTorch Channels Last Memory Format Performance Optimization on CPU Path.
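The following sketch shows how the chain breaks. relu_cf_only is a hypothetical operator without channels last support: it reads its input through the strides, so an NHWC input is handled correctly as a non-contiguous NCHW tensor, but it always allocates a contiguous NCHW output, and every downstream operator then sees channels first.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

using Dims = std::array<int64_t, 4>;  // logical N, C, H, W order

// Reads the input through arbitrary strides, but always writes a
// contiguous NCHW output and reports NCHW strides for it.
std::vector<float> relu_cf_only(const std::vector<float>& in,
                                const Dims& sizes, const Dims& in_strides,
                                Dims& out_strides) {
  const int64_t N = sizes[0], C = sizes[1], H = sizes[2], W = sizes[3];
  std::vector<float> out(N * C * H * W);
  out_strides = {C * H * W, H * W, W, 1};  // channels first: chain broken
  int64_t o = 0;
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c)
      for (int64_t h = 0; h < H; ++h)
        for (int64_t w = 0; w < W; ++w) {
          const int64_t i = n * in_strides[0] + c * in_strides[1] +
                            h * in_strides[2] + w * in_strides[3];
          out[o++] = std::max(in[i], 0.f);
        }
  return out;
}
```

An operator with explicit channels last support would instead allocate its output with the input's NHWC strides and loop over C innermost.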

3. MaxPool2d Kernel on Channels First and Channels Last

For ALL the memory format aware operators, channels first and channels last have different kernel implementations (if they can share the same implementation, the operator would not be 'memory format aware').

Depending on the semantics of the operator:

  • CL has better performance than CF: Conv2d, ConvTranspose2d, MaxPool2d, UpsampleNearest2d, etc.
  • CL has equal performance to CF: BatchNorm2d, etc.
  • CL has slightly worse performance than CF: GroupNorm, ChannelShuffle, PixelShuffle, etc. (they manipulate the 'channels' dimension, which makes vectorization more difficult)

3.1 Example: MaxPool2d

MaxPool2d is a kernel-based (sliding window) operator. On channels first, the kernel cannot be vectorized (the memory access pattern on the innermost dimension, i.e. 'width', is non-contiguous), so we parallelize on all the available dimensions of NCHW, and the kernel looks like:

  // parallel on dim N, C, H, W
  at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
    int64_t c = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, c, channels, oh, output_height, ow, output_width);

    for (const auto i : c10::irange(begin, end)) {
      int64_t ih0 = oh * dH - padH;
      int64_t iw0 = ow * dW - padW;
      int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
      int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
      while(ih0 < 0) { ih0 += dilationH; }
      while(iw0 < 0) { iw0 += dilationW; }

      // local pointers
      scalar_t* input_ptr = input_data + c * input_height * input_width;

      // compute local max
      int64_t maxindex = ih0 * input_width + iw0;
      accscalar_t maxval = -std::numeric_limits<accscalar_t>::infinity();
      for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
        for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
          int64_t index = ih * input_width + iw;
          accscalar_t val = accscalar_t(input_ptr[index]);
          if ((val > maxval) || std::isnan(val)) {
            maxval = val;
            maxindex = index;
          }
        }
      }

      // set output to local max and store location of max
      output_data[i] = scalar_t(maxval);
      indices_data[i] = maxindex;

      // move on to next output index
      data_index_step(c, channels, oh, output_height, ow, output_width);
    }
  });

Notes on the kernel above:

  • the kernel above treats nbatch and channels as one dimension (i.e. channels = nbatch * channels), since logically they are interchangeable here;
  • at::parallel_for is a wrapper on the parallelization runtime (OpenMP or TBB, OpenMP by default); you can simply take it as #pragma omp parallel;
  • at::parallel_for accepts 4 params: param_0 & param_1 define the problem size; param_2 defines the grain size, which is the minimal amount of payload each thread takes; param_3 defines the 'task' for each thread, and '[begin, end)' is the global index range of the corresponding thread;
  • data_index_init and data_index_step are utility templates for incrementally indexing data; they will be explained in a later section on the concept of dimension collapse. For now you can just take them as a parallel version of for_each.
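A simplified sketch of the idea behind data_index_init and data_index_step, fixed to three dimensions and with hypothetical names (the ATen helpers are variadic, as the differing call sites in this document suggest). The init function decomposes the thread's starting offset into multi-dimensional indices once, using divisions and modulos; the step function then advances them by one with a carry, so the per-element loop body needs no division.

```cpp
#include <cassert>
#include <cstdint>

// Decompose a flat offset into (x, y, z) for dimensions of sizes (X, Y, Z),
// where z is the innermost dimension.
inline void data_index_init_3d(int64_t offset, int64_t& x, int64_t X,
                               int64_t& y, int64_t Y, int64_t& z, int64_t Z) {
  z = offset % Z; offset /= Z;
  y = offset % Y; offset /= Y;
  x = offset % X;
}

// Advance (x, y, z) by one element, carrying into the outer dimensions.
inline void data_index_step_3d(int64_t& x, int64_t X, int64_t& y, int64_t Y,
                               int64_t& z, int64_t Z) {
  if (++z == Z) {
    z = 0;
    if (++y == Y) {
      y = 0;
      if (++x == X) x = 0;
    }
  }
}
```

Inside a parallel_for task, init would be called once with begin, and step once per iteration of the [begin, end) loop, which is exactly how the CF MaxPool2d kernel above uses the ATen versions.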

On channels last, the memory access pattern on the innermost dimension, i.e. 'channels', is contiguous, so we can vectorize on channels and parallelize on the remaining dimensions N, H, W. The kernel looks like:

  // parallel on dim N, H, W
  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oh = 0;
    int64_t ow = 0;
    data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);

    int64_t size = channels;
    int64_t len = size - (size % Vec::size());
    // temp buffer holding index with integer_t
    std::unique_ptr<integer_t []> index_buffer(new integer_t[len]);

    for (const auto i : c10::irange(begin, end)) {
      int64_t ih0 = oh * dH - padH;
      int64_t iw0 = ow * dW - padW;
      int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
      int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
      while(ih0 < 0) { ih0 += dilationH; }
      while(iw0 < 0) { iw0 += dilationW; }

      scalar_t* out = output_data + i * channels;
      int64_t* ind = indices_data + i * channels;

      // Pass I: init out lane
      iVec index0_vec = iVec(ih0 * input_width + iw0);
      Vec out_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
      int64_t d1 = 0;
      for (; d1 < len; d1 += Vec::size()) {
        index0_vec.store(index_buffer.get() + d1);
        out_vec.store(out + d1);
      }
      for (; d1 < size; d1++) {
        ind[d1] = ih0 * input_width + iw0;
        out[d1] = -std::numeric_limits<scalar_t>::infinity();
      }
      // Pass II: compute local max
      for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
        for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
          scalar_t* in = input_data + n * input_height * input_width * channels +
              ih * input_width * channels + iw * channels;

          int64_t d2 = 0;
          for (; d2 < len; d2 += Vec::size()) {
            iVec index_vec = iVec(ih * input_width + iw);
            Vec val_vec = Vec::loadu(in + d2);
            iVec maxindex_vec = iVec::loadu(index_buffer.get() + d2);
            Vec maxval_vec = Vec::loadu(out + d2);

            // true = all ones, false = all zeros
            Vec mask = (val_vec > maxval_vec) | val_vec.isnan();
            iVec imask = vec::cast<integer_t>(mask);
            Vec out_vec = Vec::blendv(maxval_vec, val_vec, mask);
            iVec ind_vec = iVec::blendv(maxindex_vec, index_vec, imask);

            out_vec.store(out + d2);
            ind_vec.store(index_buffer.get() + d2);
          }
          for (; d2 < size; d2++) {
            int64_t index = ih * input_width + iw;
            scalar_t val = in[d2];
            int64_t maxindex = ind[d2];
            scalar_t maxval = out[d2];

            bool mask = (val > maxval) || std::isnan(val);
            out[d2] = mask ? val : maxval;
            ind[d2] = mask ? index : maxindex;
          }
        }
      }
      // convert indice data type
      vec::convert<integer_t, int64_t>(index_buffer.get(), ind, len);

      // move on to next output index
      data_index_step(n, nbatch, oh, output_height, ow, output_width);
    }
  });

Notes on the kernel above:

  • The PyTorch index tensor has dtype int64_t, which does not have the same vector length as float, so I used a proxy index dtype of int32_t here (integer_t in the code) to vectorize the entire kernel.
  • Vec = at::vec::Vectorized<scalar_t> is a wrapper on SIMD vectorized logic; it is compiled for different architectures, e.g. AVX2, AVX512, or the Arm instruction set on mobile.
  • Aside from the PyTorch Vec wrapper, we can also write plain C++ code and add #pragma omp simd on the loop, and it will be auto-vectorized by ICC (GCC won't vectorize the code because of the blending; more tricks are needed).
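A scalar sketch of that plain C++ alternative, assuming float input: the running max for one window position is updated with ternary blends under #pragma omp simd, with indices tracked in an int32_t proxy buffer that is widened to int64_t once per output pixel. The function names are hypothetical, and whether the loop is actually vectorized depends on the compiler and flags (the pragma is ignored when OpenMP support is not enabled).

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Update the running max of `size` channels with the input row `in` taken at
// one window position, whose flattened input index is `index`. The int32_t
// proxy indices have the same lane width as float.
void update_max_row(const float* in, float* out, int32_t* ind32,
                    int64_t size, int32_t index) {
#pragma omp simd
  for (int64_t d = 0; d < size; ++d) {
    const float val = in[d];
    const bool take = (val > out[d]) || std::isnan(val);  // NaN propagates
    out[d] = take ? val : out[d];
    ind32[d] = take ? index : ind32[d];
  }
}

// Widen the proxy indices to the int64_t dtype of the PyTorch index tensor.
void widen_indices(const int32_t* ind32, int64_t* ind, int64_t size) {
  for (int64_t d = 0; d < size; ++d) ind[d] = static_cast<int64_t>(ind32[d]);
}
```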

Without the PyTorch parallel and Vec wrappers, the kernels above are actually fairly simple, nothing more than the sketch below:

[Fig-3: MaxPool2d on channels first and channels last]
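For reference, here is one way to write both MaxPool2d kernels as plain serial loops. It is a simplified sketch without dilation, ceil_mode or the indices output, and Pool2d plus the function names are hypothetical. The only structural difference is which loop is innermost: a spatial window loop for channels first, a contiguous loop over C for channels last.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

struct Pool2d { int64_t kH, kW, dH, dW, padH, padW; };

// Channels first: NC folds batch and channels, as in the CF kernel above.
void max_pool_cf(const float* in, float* out, int64_t NC, int64_t IH,
                 int64_t IW, int64_t OH, int64_t OW, const Pool2d& p) {
  for (int64_t c = 0; c < NC; ++c)
    for (int64_t oh = 0; oh < OH; ++oh)
      for (int64_t ow = 0; ow < OW; ++ow) {
        const int64_t ih0 = std::max<int64_t>(oh * p.dH - p.padH, 0);
        const int64_t iw0 = std::max<int64_t>(ow * p.dW - p.padW, 0);
        const int64_t ih1 = std::min(oh * p.dH - p.padH + p.kH, IH);
        const int64_t iw1 = std::min(ow * p.dW - p.padW + p.kW, IW);
        float m = -std::numeric_limits<float>::infinity();
        for (int64_t ih = ih0; ih < ih1; ++ih)    // innermost loops walk
          for (int64_t iw = iw0; iw < iw1; ++iw) {  // the spatial window
            const float v = in[(c * IH + ih) * IW + iw];
            if (v > m || std::isnan(v)) m = v;
          }
        out[(c * OH + oh) * OW + ow] = m;
      }
}

// Channels last: the innermost loop runs over C contiguous values.
void max_pool_cl(const float* in, float* out, int64_t N, int64_t C,
                 int64_t IH, int64_t IW, int64_t OH, int64_t OW,
                 const Pool2d& p) {
  for (int64_t n = 0; n < N; ++n)
    for (int64_t oh = 0; oh < OH; ++oh)
      for (int64_t ow = 0; ow < OW; ++ow) {
        float* o = out + ((n * OH + oh) * OW + ow) * C;
        std::fill(o, o + C, -std::numeric_limits<float>::infinity());
        const int64_t ih0 = std::max<int64_t>(oh * p.dH - p.padH, 0);
        const int64_t iw0 = std::max<int64_t>(ow * p.dW - p.padW, 0);
        const int64_t ih1 = std::min(oh * p.dH - p.padH + p.kH, IH);
        const int64_t iw1 = std::min(ow * p.dW - p.padW + p.kW, IW);
        for (int64_t ih = ih0; ih < ih1; ++ih)
          for (int64_t iw = iw0; iw < iw1; ++iw) {
            const float* x = in + ((n * IH + ih) * IW + iw) * C;
            for (int64_t c = 0; c < C; ++c)  // vectorizable channel loop
              if (x[c] > o[c] || std::isnan(x[c])) o[c] = x[c];
          }
      }
}
```

Running both on the same data stored in the two formats should give identical results, which is a convenient correctness check when developing a channels last kernel.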

4. Special Case I: Upsampling

Here comes my personal favorite: special cases. Usually, as long as channels last can be used, I don't put too much effort into channels first optimization. But there are exceptions, and upsampling is one of them.

Two reasons why channels first upsampling also requires optimization:

  • it can be used via interpolate, where there is no way to ensure that the input is CL;
  • although quantized models take channels last by default (Conv outputs a CL tensor), GAN models run upsampling before Conv, which means the upsampling is still CF.

The major performance bottleneck is calculating the input window indices for a given output index (a lot of dtype conversion and scaling work), while the real payload is merely copying a single floating point value for each output index, as shown in Fig-4(a):

[Fig-4: upsampling on channels first]
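To see where the indexing cost comes from, here is a hypothetical nearest-neighbor source index helper. It assumes scale = input_size / output_size (no user-supplied scale factor) and floor-then-clamp rounding, assumed here to match the 'nearest' mode. Every output element pays for a float multiply, a floor, a cast and a clamp, just to copy one value.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Map an output index to its nearest input index along one dimension.
inline int64_t nearest_source_index(int64_t dst_index, int64_t input_size,
                                    int64_t output_size) {
  const float scale = static_cast<float>(input_size) / output_size;
  const auto src = static_cast<int64_t>(std::floor(dst_index * scale));
  return std::min(src, input_size - 1);  // clamp to the last valid index
}
```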

Compare the following implementations:

  • Impl-1: directly parallelize on NCHW (same as the MaxPool2d CF kernel above). The index calculation will be overwhelming, and the work is redundant, since every feature map has identical window indices across nbatch and channels.
  • Impl-2: an improved idea is to pre-calculate the window indices for the output feature map plane. This requires a thread-local buffer holding the indices, of size {output_height, output_width, 2} * sizeof(int64_t). This is good enough for commonly used CNN module shapes. But when interpolating an image of [1, 3, 768, 1024], it is even worse than Impl-1 because: i) the temp buffer is too big; ii) parallelizing on NC won't utilize all the CPU cores, as the problem size is not big enough.
  • Impl-3: a tradeoff between the previous two: cache the window indices only on dimension width, so the temp buffer size is {output_width} * sizeof(int64_t), and parallelize on NCH, as shown in Fig-4(b).
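To put numbers on Impl-2 versus Impl-3, assume the [1, 3, 768, 1024] input is upsampled by a factor of 2, i.e. the output is [1, 3, 1536, 2048] (the output size is an assumption for illustration). Using the buffer sizes given above with 8-byte int64_t indices:

```latex
\begin{aligned}
\text{Impl-2 buffer} &= 1536 \times 2048 \times 2 \times 8\,\text{B} = 50{,}331{,}648\,\text{B} = 48\,\text{MiB} \\
\text{Impl-3 buffer} &= 2048 \times 8\,\text{B} = 16{,}384\,\text{B} = 16\,\text{KiB} \\
\text{parallel tasks, Impl-2 (NC)} &= 1 \times 3 = 3 \\
\text{parallel tasks, Impl-3 (NCH)} &= 1 \times 3 \times 1536 = 4608
\end{aligned}
```

So under this assumption Impl-2 needs a 48 MiB buffer per thread yet exposes only 3 parallel tasks, while Impl-3 needs 16 KiB and exposes 4608 tasks.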

The code is at #69600; the kernel looks like:

  std::unique_ptr<int64_t []> input_offset_arr(new int64_t[output_width]);
  int64_t* input_offset = input_offset_arr.get();
    
  for (const auto w2 : c10::irange(output_width)) {
    const int64_t w1 = nn_compute_source_index_fn(width_scale, w2, input_width);
    input_offset[w2] = w1;
  }
    
  int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, output_width);
  at::parallel_for(0, channels * output_height, grain_size, [&](int64_t begin, int64_t end) {
    int64_t nc{0}, h2{0};
    data_index_init(begin, nc, channels, h2, output_height);
    
    for (const auto i : c10::irange(begin, end)) {
      const int64_t h1 = nn_compute_source_index_fn(height_scale, h2, input_height);
      const auto* pos1 = &i_p[nc * input_height * input_width + h1 * input_width];
      auto* pos2 = &o_p[i * output_width];
      
      for (const auto w2 : c10::irange(output_width)) {
        const int64_t w1 = input_offset[w2];
        pos2[w2] = pos1[w1];
      }
      
      data_index_step(nc, channels, h2, output_height);
    }
  });

We can go one step further with magic numbers: in GAN variants, upsampling usually has a scale factor of 2. In this situation, we can skip the input index computation and vectorize the entire logic; it simply requires vec::interleave2:

// interleave copy
Vec o1, o2;
std::tie(o1, o2) = interleave2(a, a);
 
// interleave2 
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}

  // swap lanes:
  //   a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  auto a_swapped = _mm256_permute2f128_ps(a, b, 0b0100000);  // 0, 2.   4 bits apart
  auto b_swapped = _mm256_permute2f128_ps(a, b, 0b0110001);  // 1, 3.   4 bits apart

  // group cols crossing lanes:
  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
  //          {a4, b4, a5, b5, a6, b6, a7, b7}
  const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  return std::make_pair(_mm256_permutevar8x32_ps(a_swapped, group_ctrl),
                        _mm256_permutevar8x32_ps(b_swapped, group_ctrl));
}

This is how to achieve optimal performance :)
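To make the data movement explicit, below is a portable scalar model of interleave2 on 8-float arrays, plus a row-level 2x nearest upsample that uses it. Vec8, interleave2_ref and upsample_row_x2 are hypothetical names; this models the semantics, not the AVX2 code above, and assumes the row length is a multiple of 8. Interleaving a vector with itself duplicates each element, which is exactly 2x nearest upsampling along width; along height, output row h2 is simply a copy of input row h2 / 2.

```cpp
#include <array>
#include <cassert>
#include <utility>

using Vec8 = std::array<float, 8>;

// Returns ({a0, b0, a1, b1, a2, b2, a3, b3}, {a4, b4, a5, b5, a6, b6, a7, b7}).
std::pair<Vec8, Vec8> interleave2_ref(const Vec8& a, const Vec8& b) {
  Vec8 lo{}, hi{};
  for (int i = 0; i < 4; ++i) {
    lo[2 * i] = a[i];
    lo[2 * i + 1] = b[i];
    hi[2 * i] = a[i + 4];
    hi[2 * i + 1] = b[i + 4];
  }
  return {lo, hi};
}

// 2x nearest upsampling of one row: 8 inputs produce 16 outputs with no
// per-element index computation. `width` must be a multiple of 8.
void upsample_row_x2(const float* in, float* out, int width) {
  for (int w = 0; w < width; w += 8) {
    Vec8 a;
    for (int i = 0; i < 8; ++i) a[i] = in[w + i];
    const std::pair<Vec8, Vec8> res = interleave2_ref(a, a);
    for (int i = 0; i < 8; ++i) {
      out[2 * w + i] = res.first[i];
      out[2 * w + 8 + i] = res.second[i];
    }
  }
}
```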

5. Special Case II: AvgPool3d from VGGM

Usually AvgPool3d on the channels first memory format suffers from bad performance because it can't be vectorized. VGGM has a very unusual use of AvgPool3d in the module SpatialCrossMapLRN: it has a kernel of {K, 1, 1}. This magic number makes it possible to fully vectorize the kernel: since the height and width dimensions are untouched, we can vectorize on HW.

For parallelization, we have two choices:

  • parallelize on NC if NC is big enough, e.g. NC = 64;
  • parallelize on NCD if NC is not big enough, e.g. NC = 3.

The overall scheme is shown in Fig-5:

[Fig-5: VGGM AvgPool3d]
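A scalar sketch of the {K, 1, 1} case, assuming stride 1 and no padding along D (VGGM's actual padding setup is not modeled; the function name is hypothetical). Each output plane is the average of K whole input planes, so the innermost loop runs over HW contiguous elements, and the outer (nc, od) loops are the candidates for parallelizing on NC or on NCD.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// AvgPool3d with kernel {K, 1, 1}, stride 1, no padding, on a channels first
// tensor viewed as [NC][D][HW]. Output is [NC][D - K + 1][HW].
void avg_pool3d_k11(const float* in, float* out, int64_t NC, int64_t D,
                    int64_t HW, int64_t K) {
  const int64_t OD = D - K + 1;
  for (int64_t nc = 0; nc < NC; ++nc)      // parallelize here if NC is large,
    for (int64_t od = 0; od < OD; ++od) {  // or over (nc, od) if NC is small
      float* o = out + (nc * OD + od) * HW;
      for (int64_t i = 0; i < HW; ++i) o[i] = 0.f;
      for (int64_t k = 0; k < K; ++k) {
        const float* x = in + (nc * D + od + k) * HW;
#pragma omp simd
        for (int64_t i = 0; i < HW; ++i) o[i] += x[i];  // contiguous in HW
      }
#pragma omp simd
      for (int64_t i = 0; i < HW; ++i) o[i] /= static_cast<float>(K);
    }
}
```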

For this particular model, we can also fuse the other elementwise operators (pow, div, mul, and add) into AvgPool3d to achieve optimal performance.
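A minimal sketch of such a fusion, assuming the common cross-channel LRN formulation y = x / (k + alpha * avg(x^2))^beta, where the average is taken over a window of `size` neighboring channels (size / 2 before, (size - 1) / 2 after), zero-padded at the edges and always divided by `size`. lrn_fused and its parameters are hypothetical. The squaring, window average, scale, shift, power and division all happen in one pass per output plane, so the intermediate tensors of the unfused graph are never materialized.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Fused cross-channel LRN on NCHW viewed as [N][C][HW].
void lrn_fused(const float* x, float* y, int64_t N, int64_t C, int64_t HW,
               int64_t size, float alpha, float beta, float k) {
  std::vector<float> acc(HW);  // window sum of squares for one plane
  for (int64_t n = 0; n < N; ++n)
    for (int64_t c = 0; c < C; ++c) {
      std::fill(acc.begin(), acc.end(), 0.f);
      const int64_t c0 = std::max<int64_t>(c - size / 2, 0);
      const int64_t c1 = std::min<int64_t>(c + (size - 1) / 2, C - 1);
      for (int64_t cc = c0; cc <= c1; ++cc) {  // pow + AvgPool3d window sum
        const float* xp = x + (n * C + cc) * HW;
        for (int64_t i = 0; i < HW; ++i) acc[i] += xp[i] * xp[i];
      }
      const float* xc = x + (n * C + c) * HW;
      float* yc = y + (n * C + c) * HW;
      for (int64_t i = 0; i < HW; ++i)  // average, mul, add, pow and div
        yc[i] = xc[i] / std::pow(k + alpha * acc[i] / size, beta);
    }
}
```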
