diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 5d1aa44..0d4dbc0 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -22,7 +22,7 @@ from float8_experimental.float8_tensor import (
    tensor_already_casted_to_fp8,
    to_fp8_no_autograd,
)
-from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
from torch._prims_common import suggest_memory_format
@@ -106,9 +106,7 @@ def cast_to_float8_e4m3fn(
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(
-        inpt_tensor, scale, e4m3_dtype, mm_config=mm_config
-    )
+    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)
def cast_to_float8_e5m2_bw(
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index a413925..35c03c0 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -21,7 +21,12 @@ from float8_experimental.float8_tensor import (
    to_fp8_no_autograd,
)
-from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import (
+    amax_history_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_amax,
+)
def _maybe_initialize_amaxes_scales_for_float8_cast(
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index ba2b597..9239200 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -14,7 +14,11 @@ import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_utils import amax_history_to_scale_stack, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import (
+    amax_history_to_scale_stack,
+    e4m3_dtype,
+    e5m2_dtype,
+)
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
log = logging.getLogger(__name__)
diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py
index 6cb406d..64f30bf 100644
--- a/float8_experimental/float8_python_api.py
+++ b/float8_experimental/float8_python_api.py
@@ -9,7 +9,6 @@ of class `Float8Tensor`. This is a thin wrapper on top of the aten API
to simplify the product code.
"""
-
from typing import Optional, Tuple
import float8_experimental.float8_aten_api  # noqa
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 43c576a..5c8e9a8 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -9,7 +9,11 @@ from typing import Dict, Optional
import torch
import torch.distributed._functional_collectives as funcol
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated, e4m3_dtype
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    tensor_to_amax,
+    to_fp8_saturated,
+)
from torch.distributed._tensor import DTensor
aten = torch.ops.aten
diff --git a/test/test_base.py b/test/test_base.py
index c015e1d..7e7539e 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -31,11 +31,11 @@ from float8_experimental.float8_tensor import (
from float8_experimental.float8_utils import (
    amax_to_scale,
    compute_error,
+    e4m3_dtype,
+    e5m2_dtype,
    fp8_tensor_statistics,
    FP8_TYPES,
    tensor_to_scale,
-    e4m3_dtype,
-    e5m2_dtype,
)
random.seed(0)
@@ -397,10 +397,15 @@ class TestScaledMM:
class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn,
-                                              torch.float8_e5m2,
-                                              torch.float8_e4m3fnuz,
-                                              torch.float8_e5m2fnuz])
+    @pytest.mark.parametrize(
+        "float8_dtype",
+        [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+        ],
+    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_small_amax_float16(self, float8_dtype):
        # If we calculate scale naively with FP8_MAX_POS / amax,
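For context on the test touched in the last hunk: the truncated comment refers to the case where a float16 input has a very small amax, so a scale computed naively as FP8_MAX_POS / amax exceeds the float16 range and becomes inf. Below is a minimal standalone sketch of that failure mode; it is illustrative only (not code from the repo), and the clamp to torch.finfo(torch.float16).max is an assumed mitigation, not necessarily the exact one used upstream.

import torch

# Hypothetical repro of the situation described in test_small_amax_float16:
# a tiny float16 amax makes the naive scale FP8_MAX_POS / amax overflow.
amax = torch.tensor(1e-7, dtype=torch.float16)        # very small amax
fp8_max = torch.finfo(torch.float8_e4m3fn).max        # 448.0 for e4m3fn
naive_scale = fp8_max / amax.to(torch.float32)        # ~3.8e9 in float32
print(naive_scale.to(torch.float16))                  # inf -- unusable as a scale
# Assumed mitigation: clamp the scale so it stays representable in float16.
safe_scale = torch.clamp(naive_scale, max=torch.finfo(torch.float16).max)
print(safe_scale.to(torch.float16))                   # 65504.0, finite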