PyTorch CPU Performance Optimization Tutorial - Section III

Part III: Vectorization Techniques

(Training material on PyTorch CPU performance optimization)

A Chinese version of this chapter is available: link.

This section contains the following subjects:

  • Basic Knowledge
  • Example I: Prefix Sum
  • Example II: Horizontal Reduce
  • Special Case I: ChannelShuffle CF & CL Kernel
  • Special Case II: ShuffleNet Fusion

1. Basic Knowledge

Vectorization is a technique to operate on a set of values at one time based on SIMD instruction sets. There are multiple ways to vectorize code, such as compiler auto vectorization or manual vectorization via intrinsics; this section focuses on the latter approach.
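To make the contrast concrete, below is a minimal sketch of both approaches (not taken from PyTorch; add_scalar and add_avx2 are made-up names for illustration): a plain scalar loop that a compiler may auto vectorize, and a manually vectorized AVX2 loop processing 8 floats per iteration.

#include <immintrin.h>
#include <cstdint>

// Scalar loop: a compiler may auto-vectorize this under -O3 with -mavx2.
void add_scalar(const float* a, const float* b, float* c, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    c[i] = a[i] + b[i];
  }
}

// Manual vectorization with AVX2 intrinsics: 8 floats per iteration,
// plus a scalar tail for the remainder.
void add_avx2(const float* a, const float* b, float* c, int64_t n) {
  int64_t i = 0;
  for (; i <= n - 8; i += 8) {
    __m256 x = _mm256_loadu_ps(a + i);
    __m256 y = _mm256_loadu_ps(b + i);
    _mm256_storeu_ps(c + i, _mm256_add_ps(x, y));
  }
  for (; i < n; ++i) {
    c[i] = a[i] + b[i];
  }
}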

1.1 Data Types and Naming Conventions

All intrinsics on Intel platforms are listed in the Intel Intrinsics Guide; the following shows the vector data types on AVX2/AVX512.

__m256    // 256-bit vector containing 8 floats
__m256d   // 256-bit vector containing 4 doubles
__m256i   // 256-bit vector containing integers
__m512    // 512-bit vector containing 16 floats
__m512d   // 512-bit vector containing 8 doubles
__m512i   // 512-bit vector containing integers

Generally, intrinsic functions follow the naming convention:

_mm<bit_width>_<operator_name>_<dtype>

Note: <dtype> may have the following values (a few decoded examples follow the list below):

  • ps - packed single precision
  • pd - packed double precision
  • epi8/epi16/epi32/epi64 - extend packed signed integer 8-bit/16-bit/32-bit/64-bit
  • epu8/epu16/epu32/epu64 - extend packed unsigned integer 8-bit/16-bit/32-bit/64-bit
  • si128/si256/si512 - unspecified 128-bit/256-bit/512-bit vector (casting)
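
The snippet below decodes a few names with this convention (these are all real intrinsics; the wrapper function naming_demo is made up purely for illustration):

#include <immintrin.h>

__m256 naming_demo(__m256 x, __m256 y, __m256i p, __m256i q) {
  __m256  a = _mm256_add_ps(x, y);     // _mm256 + add + ps    : add 8 packed floats
  __m256i b = _mm256_add_epi32(p, q);  // _mm256 + add + epi32 : add 8 packed signed 32-bit ints
  __m256  c = _mm256_castsi256_ps(b);  // si256 : reinterpret the 256-bit vector (no data movement)
  return _mm256_add_ps(a, c);
}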

1.2 Intrinsics and PyTorch Vectorized Wrapper

PyTorch CPU native kernels under ATen mostly use manual vectorization with the help of at::vec::Vectorized<T>, abbreviated as Vec in the rest of this tutorial.

  • Vec is a struct that wraps up the SIMD intrinsics on different archs, e.g. AVX2, AVX512 or mobile archs.
  • By default, kernels written with Vec will be compiled multiple times, once for each available arch. With GCC 9 on Intel platforms, this generates 3 sets of kernels: <kernel_name>.DEFAULT (the scalar version), <kernel_name>.AVX2 and <kernel_name>.AVX512. With GCC 8 or lower, AVX512 kernels will not be generated.
  • At runtime, the operator will pick the kernel binary for the highest arch level available on the platform (AVX512 > AVX2 > scalar).

Fig-1 shows some commonly used intrinsics: (a) initialization; (b) load/store; (c) gather/scatter; (d) arithmetic.

[Fig-1: commonly used intrinsics]

The above usage can also be expressed with Vec, for example:

  using Vec = at::vec::Vectorized<float>;
  Vec a = Vec(1.0f);        // initialization (broadcast)
  Vec x = Vec::loadu(addr); // load
  x.store(addr);            // store
  Vec y = x * Vec(2.0f);    // multiply

Note that gather/scatter are used for non-contiguous memory access, which touches multiple cache lines and is therefore pretty slow. When the access has a constant stride, we can usually replace them with load/store plus permute/shuffle, which is much faster, e.g. in 'matrix transpose'.
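
As a sketch of the cost difference (the function names here are made up for illustration): a strided gather of 8 floats may touch up to 8 cache lines, while a contiguous unaligned load touches at most 2.

#include <immintrin.h>

// Strided read of 8 floats via gather (stride is in elements, scale = 4 bytes per float).
__m256 strided_load_gather(const float* base, int stride) {
  const __m256i idx = _mm256_setr_epi32(
      0, stride, 2 * stride, 3 * stride,
      4 * stride, 5 * stride, 6 * stride, 7 * stride);
  return _mm256_i32gather_ps(base, idx, /* scale */ 4);
}

// Contiguous read of 8 floats: a single unaligned load.
__m256 contiguous_load(const float* base) {
  return _mm256_loadu_ps(base);
}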

Fig-2 shows more intrinsics that rearrange vector elements: permute and shuffle.

[Fig-2: permute and shuffle intrinsics]

Note that ctrl is an 8-bit integer which controls the data movement pattern per lane (128-bit).
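
For example (a minimal sketch), the following call reverses the 4 elements within each 128-bit lane; the _MM_SHUFFLE macro packs four 2-bit element selectors into the 8-bit ctrl:

#include <immintrin.h>

// _MM_SHUFFLE(0, 1, 2, 3) == 0x1B: pick elements 3, 2, 1, 0 of each lane, i.e.
// {a0, a1, a2, a3 | a4, a5, a6, a7} -> {a3, a2, a1, a0 | a7, a6, a5, a4}
__m256 reverse_within_lanes(__m256 x) {
  return _mm256_permute_ps(x, _MM_SHUFFLE(0, 1, 2, 3));
}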

Example I: Prefix Sum

Continuing with the example of Prefix Sum: Tutorial-2 uses this example to show how to parallelize a sequentially dependent operation with blocking. This chapter shows how to vectorize it with intrinsics, as illustrated in Fig-3 below. Together they form a complete recipe for optimizing Prefix Sum on a multi-core CPU.

[Fig-3: vectorized prefix sum]

It takes 3 rounds of 'shift' and 'add' to do the job on AVX2 (one more round is needed on AVX512):

template <>
inline void cumsum<float>(float base, const float* src, float* dst, int64_t n) {
  __m256 offset = _mm256_set1_ps(base);
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    __m256 x = _mm256_loadu_ps(src + i);

    // shift 32 bit
    // x = {a0, a1, a2, a3, a4, a5, a6, a7}
    // y = { 0, a0, a1, a2, a3, a4, a5, a6}
    __m256 t0 = _mm256_permute_ps(x, 0x93);
    __m256 t1 = _mm256_permute2f128_ps(t0, t0, 0x29);
    __m256 y = _mm256_blend_ps(t0, t1, 0x11);
    x = _mm256_add_ps(x, y);

    // shift 64 bit
    // x = {a0, a01, a12, a23, a34, a45, a56, a67}
    // y = { 0,  0,   a0, a01, a12, a23, a34, a45}
    t0 = _mm256_permute_ps(x, 0x4E);
    t1 = _mm256_permute2f128_ps(t0, t0, 0x29);
    y = _mm256_blend_ps(t0, t1, 0x33);
    x = _mm256_add_ps(x, y);

    // shift 128 bit
    // x = {a0, a01, a012, a0123, a1234, a2345, a3456, a4567}
    // y = { 0,   0,    0,     0,    a0,   a01,  a012, a0123}
    y = _mm256_permute2f128_ps(x, x, 0x29);
    x = _mm256_add_ps(x, y);
    x = _mm256_add_ps(x, offset);

    _mm256_storeu_ps(dst + i, x);

    // broadcast the offset
    t0 = _mm256_permute2f128_ps(x, x, 0x11);
    offset = _mm256_permute_ps(t0, 0xFF);
  }
  float offset_val = _mm256_cvtss_f32(offset);
#pragma unroll
  for (; i < n; ++i) {
    offset_val += src[i];
    dst[i] = offset_val;
  }
}
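
For reference, the vectorized kernel above computes the same result as the scalar loop below (a minimal sketch of its semantics, an inclusive prefix sum starting from base; cumsum_scalar_ref is a made-up name):

#include <cstdint>

// dst[i] = base + src[0] + src[1] + ... + src[i]
inline void cumsum_scalar_ref(float base, const float* src, float* dst, int64_t n) {
  float acc = base;
  for (int64_t i = 0; i < n; ++i) {
    acc += src[i];
    dst[i] = acc;
  }
}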

Example II: Horizontal Reduce

Horizontal reduce refers to reducing a vector to a scalar. When reducing a row of data, we usually take 2 steps: a) accumulate vector by vector; b) reduce the final vector to a scalar. #73953 uses this to optimize nn.Softmax and nn.LogSoftmax when dim = -1. In the MultiheadAttention module of Transformer models, the dimension size for reduction in Softmax is usually not very big, so the final vector-to-scalar reduction becomes a hotspot.

[Fig-4: horizontal sum]

Similar to the previous example, it takes 3 rounds of 'shuffle' and 'add' to do the job on AVX2 (one more round is needed on AVX512):

template <typename scalar_t=float, typename Op>
inline float vec_reduce_all(
    const Op& vec_fun,
    vec::Vectorized<float> acc_vec) {
  using Vec = vec::Vectorized<float>;
  Vec v = acc_vec;

  // 128-bit shuffle
  Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
  v = vec_fun(v, v1);
  // 64-bit shuffle
  v1 = _mm256_shuffle_ps(v, v, 0x4E);
  v = vec_fun(v, v1);
  // 32-bit shuffle
  v1 = _mm256_shuffle_ps(v, v, 0xB1);
  v = vec_fun(v, v1);

  return _mm256_cvtss_f32(v);
}

A note on the kernel above: vec_fun is the lambda function used for reduction; for 'horizontal sum' it is simply 'add'.
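
Putting the two steps together, a full row reduction for 'horizontal sum' might look like the sketch below (row_sum is a made-up name, assuming the definitions above are in scope): step a accumulates vector by vector, step b calls vec_reduce_all with 'add' as vec_fun.

using Vec = vec::Vectorized<float>;

inline float row_sum(const float* src, int64_t n) {
  // step a: accumulate vector by vector
  Vec acc_vec = Vec(0.0f);
  int64_t d = 0;
  for (; d < n - (n % Vec::size()); d += Vec::size()) {
    acc_vec = acc_vec + Vec::loadu(src + d);
  }
  // step b: reduce the final vector to a scalar
  float sum = vec_reduce_all([](Vec& x, Vec& y) { return x + y; }, acc_vec);
  // scalar tail
  for (; d < n; ++d) {
    sum += src[d];
  }
  return sum;
}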

Special Case I: ChannelShuffle CF & CL Kernel

ChannelShuffle is used to divide and rearrange the channel dimension of the input tensor. The parallel scheme is mapped to the output shape of this operator, as shown in Fig-5 (suppose G = 2 and C = 4):

[Fig-5: channel shuffle with G = 2 and C = 4]
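
Before looking at the kernels, a scalar reference of the index mapping may help (a hedged sketch on channels first layout; channel_shuffle_ref is a made-up name): output channel (c, g) reads input channel (g, c), i.e. a per-image (G, C) -> (C, G) transpose of the channel dimension.

#include <cstdint>

// groups = G, channels_per_group = C, channels = G * C
void channel_shuffle_ref(const float* input, float* output,
                         int64_t nbatch, int64_t groups,
                         int64_t channels_per_group, int64_t image_size) {
  int64_t channels = groups * channels_per_group;
  for (int64_t n = 0; n < nbatch; ++n) {
    for (int64_t c = 0; c < channels_per_group; ++c) {
      for (int64_t g = 0; g < groups; ++g) {
        const float* src = input + (n * channels + g * channels_per_group + c) * image_size;
        float* dst = output + (n * channels + c * groups + g) * image_size;
        for (int64_t k = 0; k < image_size; ++k) {
          dst[k] = src[k];  // copy one {H * W} plane
        }
      }
    }
  }
}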

On the channels first memory format, we parallelize on {N * C * G}, compute the corresponding input offset and do a memory copy over {H * W}:

  using Vec = vec::Vectorized<scalar_t>;
  int64_t inner_size = image_size - (image_size % Vec::size());
  at::parallel_for (0, nbatch * /* oc*g */channels, 0, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t oc = 0;
    int64_t g = 0;
    data_index_init(begin, n, nbatch, oc, channels_per_group, g, groups);

    for (const auto i : c10::irange(begin, end)) {
      scalar_t* output_ptr = output_data + i * image_size;
      scalar_t* input_ptr = input_data + n * channels * image_size +
          g * channels_per_group * image_size + oc * image_size;

      int64_t d = 0;
      for (; d < inner_size; d += Vec::size()) {
        Vec data_vec = Vec::loadu(input_ptr + d);
        data_vec.store(output_ptr + d);
      }
      for (; d < image_size; d++) {
        output_ptr[d] = input_ptr[d];
      }

      // move on to next output index
      data_index_step(n, nbatch, oc, channels_per_group, g, groups);
    }
  });

On the channels last memory format, we parallelize on {N * H * W} and do a transposition from {G, C} to {C, G} per pixel:

  at::parallel_for(0, nbatch * image_size, 0, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      scalar_t* output_ptr = output_data + i * channels;
      scalar_t* input_ptr = input_data + i * channels;

      // transpose each channel lane:
      // from [groups, channels_per_group] to [channels_per_group, groups]
      utils::transpose(groups, channels_per_group, input_ptr, channels_per_group, output_ptr, groups);
    }
  });

It's actually very simple to construct a parallelized kernel with dimension collapse. utils::transpose eventually dispatches to fbgemm, which provides fully vectorized matrix transposition.
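
For intuition, the transpose call above, utils::transpose(M, N, src, ld_src, dst, ld_dst), behaves like the scalar reference below (a hedged sketch of its semantics only; transpose_ref is a made-up name and the real fbgemm kernel is fully vectorized):

#include <cstdint>

// Treat src as an M x N matrix with leading dimension ld_src and write its
// transpose into dst (N x M) with leading dimension ld_dst.
template <typename scalar_t>
void transpose_ref(int64_t M, int64_t N,
                   const scalar_t* src, int64_t ld_src,
                   scalar_t* dst, int64_t ld_dst) {
  for (int64_t i = 0; i < M; ++i) {
    for (int64_t j = 0; j < N; ++j) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}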

Special Case II: ShuffleNet Fusion

Now that we are here, let's take one more step and see how to improve the performance of ShuffleNet a little bit. In ShuffleNet, the depthwise_conv module has the pattern 'cat' + 'channel_shuffle', which can be fused together, as shown in Fig-6:

[Fig-6: ShuffleNet 'cat' + 'channel_shuffle' fusion]

Suppose the memory format is channels last; it's fairly simple to construct the fused kernel with the helpers from PyTorch (pseudocode below). We parallelize on {N, H, W} and do an interleaved copy on {C}:

  // x1_stride/x2_stride may be C or 2C
  // out stride is 2C
  at::parallel_for(0, nbatch * height * width, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      scalar_t* x1_ptr = x1_data + i * x1_stride;
      scalar_t* x2_ptr = x2_data + i * x2_stride;
      scalar_t* out_ptr = out_data + i * 2 * channels;
      int64_t d = 0;
      for (; d < channels - (channels % Vec::size()); d += Vec::size()) {
        Vec x1 = Vec::loadu(x1_ptr + d);
        Vec x2 = Vec::loadu(x2_ptr + d);
        Vec out1, out2;
        std::tie(out1, out2) = vec::interleave2(x1, x2);
        out1.store(out_ptr + 2 * d);
        out2.store(out_ptr + 2 * d + Vec::size());
      }
      for (; d < channels; ++d) {
        out_ptr[2 * d] = x1_ptr[d];
        out_ptr[2 * d + 1] = x2_ptr[d];
      }
    }
  });
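
For reference, the per-pixel copy above is equivalent to the scalar loop below (a minimal sketch; interleave_copy_ref is a made-up name): output channel 2*d comes from x1 and output channel 2*d + 1 comes from x2, which is exactly 'cat' followed by 'channel_shuffle' with groups = 2.

#include <cstdint>

inline void interleave_copy_ref(const float* x1, const float* x2,
                                float* out, int64_t channels) {
  for (int64_t d = 0; d < channels; ++d) {
    out[2 * d] = x1[d];
    out[2 * d + 1] = x2[d];
  }
}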

In this way we can save almost all of the time spent on 'channel shuffle', since the payload of the 'fused kernel' is only slightly bigger than that of a plain 'cat'.
