@drisspg
Created June 22, 2024 23:14
Scaled Conversion to FP8 Dtype

When converting from FP32 to the FP8 E4M3 format, understanding the intricacies of the FP8 data type is crucial. For an in-depth look at FP8, refer to our detailed presentation: insert link to slides. The FP8 E4M3 format offers a much narrower dynamic range than FP32. To minimize quantization error and fully utilize the available range, we introduce a scaling factor that maps values from a high-precision tensor to its lower-precision counterpart.

This process is akin to symmetric quantization, a technique well-explained in the context of INT8 in this insightful NVIDIA blog post: Achieving FP32 Accuracy for INT8 Inference. Imagine two distributions: one represents the original high-precision values, and the other, the target low-precision range. We determine the maximum value of the high-precision distribution (abs_max) and align it with the limits of the FP8 range. If abs_max is less than the maximum representable value in FP8 (torch.finfo(float8_dtype).max), the scale factor will be greater than 1.0, effectively "stretching" the tensor to span the entire FP8 range. Conversely, if abs_max exceeds the FP8 maximum, the scale factor will be less than 1.0, "compressing" the tensor to fit within the FP8 constraints.
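
To make this concrete, here is a minimal sketch of how such a scale might be computed, assuming PyTorch's torch.float8_e4m3fn dtype; the helper name and the small epsilon clamp are our own additions for illustration:

import torch

def compute_fp8_scale(hp_tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Maximum representable value in the target FP8 format (448.0 for e4m3fn)
    fp8_max = torch.finfo(float8_dtype).max
    # Largest magnitude in the high-precision tensor; the clamp guards against division by zero
    abs_max = hp_tensor.abs().max().clamp(min=1e-12)
    # scale > 1.0 "stretches" the tensor, scale < 1.0 "compresses" it
    return fp8_max / abs_max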

In this note, 'scale' will refer to the multiplier used when transitioning from high precision to FP8 data type. Conversely, 'inverse_scale' will be used when upcasting from FP8 to a high-precision data type.

To summarize:

  • Scale = Quantization (Mapping to FP8)
  • Inverse Scale = Dequantization (Mapping back to high precision)

By implementing these scaling techniques, we aim to harness the power of FP8 computation without compromising the numerics of the underlying model.

Pseudo Code for Quantization:

def cast_to_fp8(tensor, scale, float8_dtype):
    # Scale into the FP8 range, then saturate-cast to the target FP8 dtype
    tensor_scaled = tensor * scale
    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
    # Store the original high-precision dtype for eventual dequantization
    return Float8Tensor(bits_fp8, scale, tensor.dtype)

Pseudo Code for Dequantization:

def to_original_precision(tensor: Float8Tensor):
    # Undo the quantization scale and restore the original high-precision dtype
    inverse_scale = (tensor._scale).reciprocal()
    return tensor._data.to(tensor._orig_dtype) * inverse_scale

In these snippets, cast_to_fp8 handles the conversion of a high-precision tensor to an FP8 representation by scaling it and then converting it to the FP8 format. The Float8Tensor object encapsulates the quantized data along with its scale and the original data type for later restoration.

The to_original_precision function takes a Float8Tensor object, computes the inverse of its scale, and applies it to the FP8 data converted back to its original data type, effectively dequantizing the data.
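
For reference, the to_fp8_saturated helper used in cast_to_fp8 above can be sketched as a clamp-then-cast: values beyond the FP8 representable range saturate at the maximum instead of overflowing. This is an illustrative sketch, not the library's exact implementation:

import torch

def to_fp8_saturated(tensor_scaled: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # Clamp to the representable range of the target FP8 dtype, then cast
    fp8_max = torch.finfo(float8_dtype).max
    return tensor_scaled.clamp(min=-fp8_max, max=fp8_max).to(float8_dtype)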

Enhanced Matrix Multiplication with _scaled_mm

Luca's recent post delves into the _scaled_mm operator, a new addition to PyTorch core that we've begun to integrate. You can read the detailed exploration here: LINK.

This operator serves as an efficient wrapper around the cuBLAS GEMM API, which is designed to work with FP8 tensors. For more technical specifics, NVIDIA's documentation provides a comprehensive overview: FP8 Data Types in CuBLAS.

The PyTorch API for _scaled_mm is currently in a prototype stage, as indicated by the leading underscore, and is still subject to change. The API is as follows:

_scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False) -> (Tensor, Tensor)

Parameters:

  • mat1 + mat1_scale (Float8Tensor): The first matrix and its associated scale factor.
  • mat2 + mat2_scale (Float8Tensor): The second matrix and its associated scale factor.
  • bias: Optional bias tensor.
  • out_dtype + scale_result (optional): The output dtype and a scaling factor for the result.
  • use_fast_accum: A flag to enable fast accumulation.

With the introduction of _scaled_mm and the capability to downcast to FP8 tensors, we're now equipped to implement one of PyTorch's cornerstone operations: torch.nn.Linear!
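
Before diving into the Linear module, here is a rough sketch of a direct call to _scaled_mm based on the prototype signature above. It assumes a GPU with FP8 support (e.g., an H100); the scale-argument semantics and the two-tensor return value are taken from that signature and may differ in newer PyTorch versions:

import torch

device = "cuda"
a = torch.randn(16, 32, device=device)
b = torch.randn(16, 32, device=device)

fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max

# Per-tensor quantization scales, following this note's convention
scale_a = fp8_max / a.abs().max()
scale_b = fp8_max / b.abs().max()

a_fp8 = (a * scale_a).clamp(-fp8_max, fp8_max).to(fp8_dtype)
b_fp8 = (b * scale_b).clamp(-fp8_max, fp8_max).to(fp8_dtype)

# cuBLAS expects the second operand in column-major layout, hence the .t().
# We pass the inverse scales so the accumulated result is dequantized;
# confirm this against the _scaled_mm version you are using (assumption).
out, out_amax = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    out_dtype=torch.bfloat16,
    scale_a=scale_a.reciprocal(),
    scale_b=scale_b.reciprocal(),
)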

Dynamic Linear: Leveraging FP8 Precision

The complete implementation of our Dynamic Linear module, which embraces the FP8 data type, is available at: Float8 Experimental Dynamic Linear. The key changes can be highlighted in the forward method of this new module:

def forward(self, x):
	x_fp8: Float8Tensor = self.cast_to_float8(x)
	w_fp8: Float8Tensor = self.cast_to_float8(self.weight)

	y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

	# Cast gradY to float8_e5m2 during backward
	y = self.cast_to_float8e5m2_bw(y)

	return y

The process is straightforward: cast the input x and weight w to FP8, perform the linear operation, and then apply specific backward pass adjustments. You might wonder about the scaling factors and their integration. In our approach, we utilize tensor subclasses that allow us to redefine the behavior of F.linear. Internally, this translates the Float8Tensors into two components, _data and _scale, which are then handled by _scaled_mm.

NVIDIA suggests employing the float8_e4m3 format for the forward pass and float8_e5m2 for gradients in the backward pass, since e5m2's wider dynamic range better suits gradient values. An alternative to the method above could involve using tensor hooks for the FP8 conversion, which would bypass the need for a custom autograd.Function.
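
As a sanity check on that last point, here is a minimal sketch, not the library's actual implementation, of how the cast_to_float8e5m2_bw behavior from the forward snippet could be expressed as a custom autograd.Function: the forward pass is an identity, and the incoming gradient is saturate-cast to float8_e5m2 in the backward pass (and immediately upcast here for simplicity, whereas the real module keeps the FP8 bits and scale in a Float8Tensor):

import torch

class CastGradToE5M2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Identity in the forward direction; only the gradient is affected
        return x

    @staticmethod
    def backward(ctx, grad_output):
        e5m2_max = torch.finfo(torch.float8_e5m2).max
        # Per-tensor scale for the gradient, computed on the fly ("dynamic" scaling)
        scale = e5m2_max / grad_output.abs().max().clamp(min=1e-12)
        grad_fp8 = (grad_output * scale).clamp(-e5m2_max, e5m2_max).to(torch.float8_e5m2)
        # Upcast and rescale right away to keep this sketch self-contained
        return grad_fp8.to(grad_output.dtype) / scale

In practice this would be invoked as CastGradToE5M2.apply(y) inside the module's forward.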
