@drisspg
Created June 17, 2024 16:31
diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 5d1aa44..0d4dbc0 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -22,7 +22,7 @@ from float8_experimental.float8_tensor import (
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
 from torch._prims_common import suggest_memory_format
@@ -106,9 +106,7 @@ def cast_to_float8_e4m3fn(
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(
-        inpt_tensor, scale, e4m3_dtype, mm_config=mm_config
-    )
+    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
 def cast_to_float8_e5m2_bw(
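
Note on the hunk above: cast_to_float8_e4m3fn computes a per-tensor scale via tensor_to_scale and then casts the input to float8 e4m3. The sketch below only illustrates the amax-based scaling idea, under the assumption that the scale maps the tensor's absolute maximum to the fp8 maximum; the helper name naive_cast_to_e4m3 is hypothetical and this is not the library's tensor_to_scale / Float8Tensor.to_float8 implementation.

import torch

def naive_cast_to_e4m3(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Pick a scale so the largest magnitude in `t` maps near the fp8 max value.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    amax = t.abs().max().to(torch.float64)
    scale = torch.clamp(fp8_max / amax, max=torch.finfo(torch.float32).max).to(torch.float32)
    # Saturate before converting so no value falls outside the representable fp8 range.
    fp8 = (t.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return fp8, scale
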
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index a413925..35c03c0 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -21,7 +21,12 @@ from float8_experimental.float8_tensor import (
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import (
+    amax_history_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_amax,
+)
 def _maybe_initialize_amaxes_scales_for_float8_cast(
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index ba2b597..9239200 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -14,7 +14,11 @@ import torch.nn as nn
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_utils import amax_history_to_scale_stack, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import (
+    amax_history_to_scale_stack,
+    e4m3_dtype,
+    e5m2_dtype,
+)
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
 log = logging.getLogger(__name__)
diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py
index 6cb406d..64f30bf 100644
--- a/float8_experimental/float8_python_api.py
+++ b/float8_experimental/float8_python_api.py
@@ -9,7 +9,6 @@ of class `Float8Tensor`. This is a thin wrapper on top of the aten API
 to simplify the product code.
 """
-
 from typing import Optional, Tuple
 import float8_experimental.float8_aten_api  # noqa
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 43c576a..5c8e9a8 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -9,7 +9,11 @@ from typing import Dict, Optional
 import torch
 import torch.distributed._functional_collectives as funcol
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated, e4m3_dtype
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    tensor_to_amax,
+    to_fp8_saturated,
+)
 from torch.distributed._tensor import DTensor
 aten = torch.ops.aten
diff --git a/test/test_base.py b/test/test_base.py
index c015e1d..7e7539e 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -31,11 +31,11 @@ from float8_experimental.float8_tensor import (
 from float8_experimental.float8_utils import (
     amax_to_scale,
     compute_error,
+    e4m3_dtype,
+    e5m2_dtype,
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
-    e4m3_dtype,
-    e5m2_dtype,
 )
 random.seed(0)
@@ -397,10 +397,15 @@ class TestScaledMM:
 class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn,
-                                              torch.float8_e5m2,
-                                              torch.float8_e4m3fnuz,
-                                              torch.float8_e5m2fnuz])
+    @pytest.mark.parametrize(
+        "float8_dtype",
+        [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+        ],
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,
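
Note on the truncated comment above: the concern is that FP8_MAX_POS / amax overflows to inf when amax is a very small float16 value, so the computed scale must be kept finite. The snippet below is only an illustrative sketch of that failure mode and the usual clamp fix; the helper name scale_from_amax_fp16 is hypothetical and is not the code under test.

import torch

def scale_from_amax_fp16(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Naive scale: FP8_MAX_POS / amax. In float16 this overflows to inf for tiny amax.
    fp8_max = torch.finfo(float8_dtype).max
    naive = fp8_max / amax.to(torch.float16)
    # Clamp to the largest finite float16 so downstream math stays finite.
    return torch.clamp(naive, max=torch.finfo(torch.float16).max)

amax = torch.tensor(1e-6, dtype=torch.float16)
print(scale_from_amax_fp16(amax, torch.float8_e4m3fn))  # finite (65504.0) instead of inf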