
@mingfeima
Last active July 8, 2022 06:04
PyTorch CPU Performance Optimization Tutorial - Section IV

Part IV: BFloat16 Kernel Optimization

(Training material on pytorch CPU performance optimization)

A Chinese version of this chapter is also available (link).

This section contains the following subjects:

  • Basic Knowledge
  • General Principle
  • Use Vectorization
  • Fundamentals
  • Kernel Optimization Techniques

1. Basic Knowledge

BFloat16 (Brain Floating Point)[1][2] is a 16-bit floating-point number format. It has the same dynamic range as float32 but lower precision. On the next-generation Xeon, Sapphire Rapids, the compute throughput of convolution and GEMM in BFloat16 is greatly improved with the help of the new AMX (Advanced Matrix Extensions) instruction set.
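BFloat16 is the upper half of an IEEE-754 float32: 1 sign bit, the same 8 exponent bits (hence the same dynamic range), and 7 mantissa bits instead of 23. The sketch below shows the bf16 -> fp32 direction, assuming bf16 values are held as raw uint16_t bit patterns. The helper names are hypothetical and are not PyTorch APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: widen a raw bf16 bit pattern to float32.
// The 16 bf16 bits become the upper half of the float32. The lower 16 mantissa
// bits are zero, so this direction is exact and needs no rounding.
inline float bf16_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));  // bit cast without violating strict aliasing
  return f;
}

// Field extraction showing the 1 / 8 / 7 bit layout.
inline int bf16_sign(uint16_t bits) { return bits >> 15; }
inline int bf16_exponent(uint16_t bits) { return (bits >> 7) & 0xFF; }  // biased by 127, like float32
inline int bf16_mantissa(uint16_t bits) { return bits & 0x7F; }
```

For example, 0x3F80 decodes to 1.0f and 0x7F7F decodes to roughly 3.39e38, which is close to float32's maximum.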

This chapter focuses on optimization techniques for BFloat16 kernels in PyTorch, not on tricks for tuning training accuracy.

2. General Principle

First of all, it's very important to understand that BFloat16 is more of a 'storage type' than a 'data type'. The only bfloat16 arithmetic the hardware supports is the dot product, exposed through the intrinsics _mm512_dpbf16_ps (AVX512-BF16) and _tile_dpbf16ps (AMX-BF16). As a result, bfloat16 is primarily used to speed up computation-intensive operators such as convolution and matmul. For all other arithmetic operations (e.g. +, -, *, /), bfloat16 needs to be converted back to float32 for the computation.
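Conceptually, these dot-product instructions multiply pairs of adjacent bf16 elements and add both products into one fp32 accumulator lane. Below is a scalar C++ model of that per-lane behaviour, written only for illustration. It does not reproduce the instructions' exact rounding or denormal handling, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Widen a raw bf16 bit pattern to float32 (exact).
inline float bf16_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Scalar model of a pair-wise bf16 dot product:
//   acc[i] += a[2i] * b[2i] + a[2i+1] * b[2i+1]
// `a` and `b` each hold 2 * lanes bf16 values, and `acc` holds `lanes` floats.
// The inputs are bf16, but all multiplication and accumulation happens in float32.
void dot_bf16_pairs(const uint16_t* a, const uint16_t* b, float* acc, int lanes) {
  for (int i = 0; i < lanes; ++i) {
    float p0 = bf16_to_float(a[2 * i]) * bf16_to_float(b[2 * i]);
    float p1 = bf16_to_float(a[2 * i + 1]) * bf16_to_float(b[2 * i + 1]);
    acc[i] += p0 + p1;
  }
}
```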

The general principles when enabling bfloat16 on PyTorch are:

  • nn.ConvNd and nn.Linear go to oneDNN.
  • All other nn OPs, and the tensor OPs under torch, are optimized as ATen native kernels.

Optimizations on native kernels include (but are not limited to):

  • nn.BatchNorm - supports mixed dtype
  • nn.LayerNorm - supports mixed dtype
  • nn.GroupNorm
  • nn.{Max|Avg}PoolNd
  • nn.Adaptive{Max|Avg}PoolNd
  • nn.ChannelShuffle
  • nn.PixelShuffle
  • nn.Upsample - 'nearest', 'bilinear', 'bicubic', 'trilinear'
  • Activations - ReLU, SiLU, PReLU, etc.
  • Advanced Indexing - gather, scatter, etc.
  • ROIAlign, ROIPool (TorchVision)

And many others...

In fact, most of the work lies in optimizing native ATen kernels rather than in integrating oneDNN functionality into PyTorch. That work is a must, however, because otherwise bfloat16 is completely unusable on the PyTorch imperative path.

The methods for optimizing PyTorch native kernels are very similar for BFloat16 and Int8:

Category           BFloat16                   Int8
dtype conversion   bf16_fp32 / fp32_bf16      dequantize / quantize
arithmetic         convert to fp32            convert to fp32
accumulation       accumulate in fp32         accumulate in int32
non-arithmetic     direct copy                direct copy
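The "non-arithmetic" row means that ops which only move data around need no dtype conversion at all: bf16 values can be copied as raw 16-bit patterns. Below is a sketch of a channel shuffle on channels-last (NHWC) data. It assumes a contiguous tensor and that C is divisible by `groups`. The function name and signature are hypothetical and do not describe the ATen kernel.

```cpp
#include <cassert>
#include <cstdint>

// Channel shuffle on channels-last data: `pixels` = N * H * W rows of C values each.
// Input channel g * (C / groups) + k moves to output channel k * groups + g.
// The bf16 payload is copied bit-for-bit as uint16_t, with no float conversion.
void channel_shuffle_cl(const uint16_t* in, uint16_t* out,
                        int64_t pixels, int64_t C, int64_t groups) {
  assert(C % groups == 0);
  const int64_t per_group = C / groups;
  for (int64_t p = 0; p < pixels; ++p) {
    const uint16_t* src = in + p * C;
    uint16_t* dst = out + p * C;
    for (int64_t g = 0; g < groups; ++g) {
      for (int64_t k = 0; k < per_group; ++k) {
        dst[k * groups + g] = src[g * per_group + k];
      }
    }
  }
}
```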

3. Use Vectorization

3.1 Dtype Conversion is slow

Even with vectorization, dtype conversion between BFloat16 and Float32 is still slow (especially fp32_bf16, which has to compute the rounding), and scalar dtype conversion is even less efficient than the vectorized version. Therefore, the fundamental ideas here are fairly simple:

  • reduce dtype conversion as much as possible (this also helps a lot in reducing rounding error)
  • use vectorized logic as much as possible
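Here is why fp32_bf16 is the expensive direction. Widening bf16 to fp32 is just a shift, but narrowing must round away the 16 discarded mantissa bits (typically to nearest-even) and treat NaN specially. Below is a scalar sketch of that rounding, assuming raw uint16_t storage. The helper name is hypothetical. It uses the common round-to-nearest-even bit trick and is not necessarily PyTorch's exact code.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Narrow float32 to bf16 with round-to-nearest-even.
inline uint16_t float_to_bf16(float f) {
  if (std::isnan(f)) {
    return 0x7FC0;  // canonical quiet NaN; the bias trick below could turn a NaN into Inf
  }
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  // Add 0x7FFF plus the lowest kept bit: values above the halfway point round up,
  // and values exactly at the halfway point round to the even neighbour.
  uint32_t rounding_bias = 0x7FFF + ((u >> 16) & 1);
  return static_cast<uint16_t>((u + rounding_bias) >> 16);
}
```

Compared with the single shift needed for bf16 -> fp32, this adds a NaN check, an add and a mask per element. In vectorized code these become extra compare/blend/add instructions per register.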

3.2 Memory Format Considerations

The BFloat16 optimization project actually goes hand in hand with the Channels Last optimization project. This is because some OPs, such as MaxPool2d, cannot be vectorized on the Channels First (e.g. NCHW) memory format. If we cannot vectorize, BFloat16 will be even slower than Float32, which makes no sense for end users.
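The reason shows up in the index arithmetic. In NCHW, neighbouring channels of one pixel are H*W elements apart, so a pooling window over (h, w) walks strided memory. In NHWC, the C values of a pixel are contiguous, so the same window can be processed as whole vectors along C. Below is a small sketch of the two offset formulas, assuming contiguous tensors. The helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Element offset of (n, c, h, w) in a contiguous channels-first (NCHW) tensor.
inline int64_t offset_nchw(int64_t n, int64_t c, int64_t h, int64_t w,
                           int64_t C, int64_t H, int64_t W) {
  return ((n * C + c) * H + h) * W + w;
}

// Element offset of (n, c, h, w) in a contiguous channels-last (NHWC) tensor.
inline int64_t offset_nhwc(int64_t n, int64_t c, int64_t h, int64_t w,
                           int64_t C, int64_t H, int64_t W) {
  return ((n * H + h) * W + w) * C + c;
}
```

Stepping c by one moves H*W elements in NCHW but exactly one element in NHWC. Only the latter lets a kernel load consecutive channels with a single vector load.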

4. Fundamentals

In PyTorch, BFloat16 is stored as uint16_t and has its operators overloaded at both the scalar level and the vector level. This means you can write your kernel freely with scalar logic (Example-1) or with Vectorized<BFloat16> (Example-2):

/* 
 * Example-1: Use scalar overload
 */
for (int64_t i = 0; i < 16; ++i) {
  float input_val = static_cast<float>(input_data[i]);  // bf16 -> fp32
  output_data[i] = BFloat16(input_val * 2.0f);          // compute in fp32, then fp32 -> bf16
}

/*
 * Example-2: Use vector overload
 */
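using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;

// One bVec holds bVec::size() bf16 elements (e.g. 32 with AVX512), i.e. two fVec's worth.
// `size` is the number of elements, assumed here to be a multiple of bVec::size().
for (int64_t d = 0; d < size; d += bVec::size()) {
  bVec data_bvec = bVec::loadu(input_data + d);
  fVec data_fvec0, data_fvec1;
  std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
  fVec out_fvec0 = data_fvec0 * fVec(2.0f);
  fVec out_fvec1 = data_fvec1 * fVec(2.0f);
  bVec out_bvec = convert_float_bfloat16(out_fvec0, out_fvec1);
  out_bvec.store(output_data + d);
}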

Example-2 is much faster than example-1!

fig-1

But having only the fundamentals is not enough: they guarantee that BFloat16 is supported, not that it is optimized. If we stop here, BFloat16 is almost certainly going to be extremely slow.
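To illustrate what "stored as uint16_t with overloaded operators" buys and costs, below is a deliberately tiny, hypothetical stand-in for a BFloat16 type. It is not PyTorch's c10::BFloat16. Every arithmetic expression silently widens to float and every assignment narrows back. That is why scalar code compiles unchanged, and also why it pays a conversion per operation.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical minimal bf16 type: 16-bit storage, float arithmetic.
struct bf16 {
  uint16_t x;  // raw bit pattern

  bf16() : x(0) {}
  // Implicit narrowing constructor: float -> bf16, round to nearest even.
  bf16(float f) {
    if (std::isnan(f)) { x = 0x7FC0; return; }
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    x = static_cast<uint16_t>((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
  }
  // Implicit widening conversion: bf16 -> float, exact.
  operator float() const {
    uint32_t u = static_cast<uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }
};

// Scalar kernel written as if bf16 were an ordinary arithmetic type (cf. Example-1).
void scale_by_two(const bf16* in, bf16* out, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    out[i] = in[i] * 2.0f;  // in[i] -> float, multiply in float, then float -> bf16
  }
}
```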

5. BFloat16 Kernel Optimization Techniques

(For each item, I list the pull request link if the feature is not merged yet, otherwise the source code file.)

5.1 Reduce redundant dtype conversion

If an operator performs several consecutive arithmetic operations, convert the input from bf16 to fp32 only once and the output from fp32 to bf16 only once, with no dtype conversion in between.

One example is nn.Sigmoid from aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

/*
 * Example-3: sigmoid
 *
 * sigmoid computes -, exp, +, /.
 * The code also compiles without the 1st branch (the 2nd branch is generic
 * over scalar_t), but it would then do the bf16/fp32 dtype conversion
 * 4 times instead of once.
 */

// BFloat16 vectorized path: convert once, compute in fp32, convert back once
Vectorized<float> a0, a1;
std::tie(a0, a1) = convert_bfloat16_float(a);
a0 = (Vectorized<float>(static_cast<float>(1)) + a0.neg().exp()).reciprocal();
a1 = (Vectorized<float>(static_cast<float>(1)) + a1.neg().exp()).reciprocal();
return convert_float_bfloat16(a0, a1);

// float32 (generic) vectorized path
a = Vectorized<scalar_t>(static_cast<scalar_t>(0)) - a;
a = a.exp();
a = Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a;
a = a.reciprocal();
return a;
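The cost described in the Example-3 comment can be made concrete with a scalar sketch that counts conversions. Both functions below compute sigmoid on a bf16 value stored as uint16_t. The first keeps every intermediate in float, so it needs one conversion in and one out. The second narrows to bf16 after each of the four steps, roughly what happens when every intermediate result is stored as bf16. The counter and all function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

int conversions = 0;  // counts bf16 <-> float conversions (illustration only)

float to_f32(uint16_t bits) {
  ++conversions;
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

uint16_t to_bf16(float f) {  // round to nearest even
  ++conversions;
  if (std::isnan(f)) return 0x7FC0;
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}

// One round trip: convert once, do -, exp, +, / in float, convert once.
uint16_t sigmoid_fused(uint16_t a) {
  float x = to_f32(a);
  return to_bf16(1.0f / (1.0f + std::exp(-x)));
}

// Four round trips: every intermediate result is stored back as bf16.
uint16_t sigmoid_per_op(uint16_t a) {
  uint16_t t = to_bf16(0.0f - to_f32(a));
  t = to_bf16(std::exp(to_f32(t)));
  t = to_bf16(1.0f + to_f32(t));
  return to_bf16(1.0f / to_f32(t));
}
```

The fused version performs 2 conversions per element and the per-op version 8. The per-op version also rounds three intermediate results to bf16's 8 significant bits.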

5.2 Use Float32 as accumulation type

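When accumulating BFloat16 values, use Float32 as the accumulation dtype. This is not only a performance requirement. It also prevents loss of precision: BFloat16 has the same range as Float32, but with only 8 significant bits a bf16 accumulator quickly stops absorbing small addends.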

One example is nn.Softmax in aten/src/ATen/native/cpu/SoftMaxKernel.cpp, where the wrapper vec::reduce_all is used to compute the max value along the given dimension:

/*
 * Example-4: reduction
 *
 * When scalar_t is BFloat16, the acc type is Float32,
 * i.e. the lambda operates on fp32 vectors.
 */
scalar_t max_input = vec::reduce_all<scalar_t>(
    [](Vec& x, Vec& y) { return vec::maximum(x, y); },
    input_data,
    dim_size);

In the utility above, the bf16 input data is converted to fp32 only once and the reduction is accumulated in fp32. This is not only a requirement for performance but also a MUST for numerical stability.
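A quick way to see why reductions such as sums need a float accumulator: bf16 keeps only 8 significant bits, so once a running sum reaches 256, adding 1.0 lands exactly halfway between 256 and 258 and rounds back to 256. The sketch below sums the same bf16 inputs with a bf16 accumulator and with a float accumulator. The helpers are hypothetical, and narrowing uses round-to-nearest-even.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

inline float bf16_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

inline uint16_t float_to_bf16(float f) {  // round to nearest even
  if (std::isnan(f)) return 0x7FC0;
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}

// Accumulator stored in bf16: every partial sum is rounded back to 8 significant bits.
float sum_bf16_acc(const std::vector<uint16_t>& x) {
  uint16_t acc = 0;  // bf16 +0.0
  for (uint16_t v : x) {
    acc = float_to_bf16(bf16_to_float(acc) + bf16_to_float(v));
  }
  return bf16_to_float(acc);
}

// Accumulator stored in float: each input is widened once, and partial sums keep 24 bits.
float sum_fp32_acc(const std::vector<uint16_t>& x) {
  float acc = 0.0f;
  for (uint16_t v : x) {
    acc += bf16_to_float(v);
  }
  return acc;
}
```

Summing 1000 copies of bf16 1.0 gives 256 with the bf16 accumulator and 1000 with the float accumulator.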

5.3 Use Float32 to store intermediate values

When we need to store intermediate values, make sure they are stored in Float32, and use as few dtype conversions as possible. One example is the nn.MaxPool2d channels last kernel in aten/src/ATen/native/cpu/MaxPoolKernel.cpp.

When iterating over the input feature map plane to find the max, each thread keeps a temporary buffer of size {channels} that stores the intermediate values in fp32.
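Below is a scalar sketch of this idea for one output pixel of a channels-last max pool. The running maxima for all C channels live in a float buffer. Each bf16 input is widened exactly once, and each result is narrowed exactly once at the end. Padding, dilation, NaN propagation and the real kernel's index output are left out, and all names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <vector>

inline float bf16_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

inline uint16_t float_to_bf16(float f) {  // round to nearest even
  if (std::isnan(f)) return 0x7FC0;
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}

// Max over a kh x kw window starting at (h0, w0) of one NHWC image of width W and C channels.
// `max_buf` plays the role of the per-thread fp32 scratch buffer of size C.
void maxpool_pixel_cl(const uint16_t* in, int64_t W, int64_t C,
                      int64_t h0, int64_t w0, int64_t kh, int64_t kw,
                      uint16_t* out, std::vector<float>& max_buf) {
  max_buf.assign(static_cast<size_t>(C), -std::numeric_limits<float>::infinity());
  for (int64_t h = h0; h < h0 + kh; ++h) {
    for (int64_t w = w0; w < w0 + kw; ++w) {
      const uint16_t* px = in + (h * W + w) * C;  // C contiguous channels
      for (int64_t c = 0; c < C; ++c) {
        max_buf[c] = std::max(max_buf[c], bf16_to_float(px[c]));  // intermediate stays fp32
      }
    }
  }
  for (int64_t c = 0; c < C; ++c) {
    out[c] = float_to_bf16(max_buf[c]);  // single narrowing per output element
  }
}
```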

5.4 Cache input data in Float32 if necessary

If we need to use the input data or a parameter multiple times, we can try to cache it in Float32. The buffer should be small enough to stay in L1 so that it does not add much memory traffic.

Example on nn.LayerNorm (nn.LayerNorm - #71376): we cache each input row of size C in Float32 since it is used twice (first to compute mean and rstd, then to compute the output). The parameters gamma and beta are cached in Float32 as well since they are constant.

/* Example-5: cache input and parameter in float32
 *
 * temp buffer holding input, gamma/beta (if defined) in float
 *
 * pre convert input slice to float has 2 benefits:
 *   a. Welford algorithm involves more arithmetic operations,
 *      this will reduce rounding error and improve performance.
 *   b. The input slice (float) can be reused when updating
 *      corresponding output slice.
 */
int64_t buffer_size = pre_convert_gamma_beta ? 3 * N : N;
std::unique_ptr<float []> buffer(new float[buffer_size]);
float* input_buffer_ptr = buffer.get();
float* gamma_buffer_ptr = nullptr;
float* beta_buffer_ptr = nullptr;
if (pre_convert_gamma_beta) {
  gamma_buffer_ptr = buffer.get() + N;
  beta_buffer_ptr = buffer.get() + 2 * N;
  vec::convert(gamma_data, gamma_buffer_ptr, N);
  vec::convert(beta_data, beta_buffer_ptr, N);
}
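Below is a self-contained scalar sketch of the same caching idea for a single LayerNorm row. The bf16 input row is widened into a float buffer once, and that buffer is read again for both the statistics and the output. gamma and beta are assumed to be already converted to float. For brevity the sketch uses a plain two-pass mean/variance instead of the Welford algorithm mentioned in Example-5, and all names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

inline float bf16_to_float(uint16_t bits) {
  uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

inline uint16_t float_to_bf16(float f) {  // round to nearest even
  if (std::isnan(f)) return 0x7FC0;
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}

// Normalize one row of N bf16 values; `buf` is the float cache with room for N values.
void layer_norm_row(const uint16_t* in, const float* gamma, const float* beta,
                    uint16_t* out, int64_t N, float eps, float* buf) {
  // 1. Convert the row once and cache it in float.
  for (int64_t i = 0; i < N; ++i) buf[i] = bf16_to_float(in[i]);
  // 2. Statistics from the cached floats (two-pass: mean, then variance).
  float mean = 0.0f;
  for (int64_t i = 0; i < N; ++i) mean += buf[i];
  mean /= static_cast<float>(N);
  float var = 0.0f;
  for (int64_t i = 0; i < N; ++i) var += (buf[i] - mean) * (buf[i] - mean);
  var /= static_cast<float>(N);
  const float rstd = 1.0f / std::sqrt(var + eps);
  // 3. Reuse the cached floats for the output; narrow to bf16 once per element.
  for (int64_t i = 0; i < N; ++i) {
    out[i] = float_to_bf16((buf[i] - mean) * rstd * gamma[i] + beta[i]);
  }
}
```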

Note that the same method does not apply to BatchNorm in training when N is large. BatchNorm computes mean and rstd across {N, H, W}, so the cache would have a large memory footprint.

[TBD] add Softmax optimization kernel.

References
