| """A representation of the IEEE754 binary format, as well as variants for FP8.""" | |
| from dataclasses import dataclass | |
| from functools import total_ordering | |
| from typing import Any, ClassVar | |
| @dataclass | |
| class IEEE754BinaryFormat: | |
| """Can be used to express any format defined in the IEEE754 standard. | |
| To represent a value using this format, see `FloatInstance`. | |
| """ | |
| e_width: int | |
| m_width: int | |
| inf_encoding: ClassVar[str] = "E=1s M=0s" | |
| nan_encoding: ClassVar[str] = "E=1s M≠0s" | |
| zero_encoding: ClassVar[str] = "S=0/1 E=0s M=0s" | |
| subnormal_encoding: ClassVar[str] = "E=0s" | |
| @property | |
| def bias(self) -> int: | |
| """Exponent bias.""" | |
| return int(2 ** (self.e_width - 1)) - 1 | |
| @property | |
| def min_e(self) -> int: | |
| """Minimum exponent. | |
| Note: this is not `0 - self.bias` as the 0 exponent is reserved for subnormals. | |
| """ | |
| return 1 - self.bias | |
| @property | |
| def max_e(self) -> int: | |
| """Maximum exponent. | |
| Note: The `-2` accounts for the fact that the all 1s exponent denotes NaN/Inf. | |
| """ | |
| return int(2**self.e_width) - 2 - self.bias | |
| @property | |
| def abs_min_normal(self) -> float: | |
| """Absolute minimum normal (i.e. not subnormal) representable value.""" | |
| return float(2**self.min_e) | |
| @property | |
| def abs_min(self) -> float: | |
| """Absolute minimum representable value (this is in the subnormal range).""" | |
| return float(2 ** (self.min_e - self.m_width)) | |
| @property | |
| def abs_max(self) -> float: | |
| """Absolute maximum representable value.""" | |
| return float((2**self.max_e) * (2 - 2**-self.m_width)) | |
| @dataclass | |
| class GAQProposedFormat(IEEE754BinaryFormat): | |
| """Used for both of Graphcore, AMD and Qualcomm's proposed FP8 formats.""" | |
| custom_bias: int | |
| inf_encoding: ClassVar[str] = "N/A" | |
| nan_encoding: ClassVar[str] = "S=1 E=0s M=0s" | |
| zero_encoding: ClassVar[str] = "S=0 E=0s M=0s" | |
| @property | |
| def bias(self) -> int: | |
| """Exponent bias. GAQ use a bias that doesn't match the standard IEE754 one.""" | |
| return self.custom_bias | |
| @property | |
| def max_e(self) -> int: | |
| """Maximum exponent. The all 1s exponent no longer denotes NaN/Inf.""" | |
| return super().max_e + 1 | |
| @dataclass | |
| class NAIProposedFormat(IEEE754BinaryFormat): | |
| """Used for Nvidia, ARM and Intel's proposed E4M3 format.""" | |
| inf_encoding: ClassVar[str] = "N/A" | |
| nan_encoding: ClassVar[str] = "E=1s M=1s" | |
| @property | |
| def max_e(self) -> int: | |
| """Maximum exponent. The all 1s exponent no longer denotes NaN/Inf.""" | |
| return super().max_e + 1 | |
| @property | |
| def abs_max(self) -> float: | |
| """Absolute maximum representable value. | |
| Accounts for the fact that the value with all 1s exponent+mantissa denotes NaN. | |
| This was previously handled by taking one value away from max_e, but now the all | |
| 1s exponent is generally valid, this special-case must be added. | |
| """ | |
| return float(super().abs_max - 2 ** (self.max_e - self.m_width)) | |
| @dataclass | |
| @total_ordering | |
| class FloatInstance: | |
| """An instance of a floating point number, defined with reference to an instance | |
| or subclass of `IEEE754BinaryFormat`. | |
| """ | |
| format: IEEE754BinaryFormat | |
| s: int | |
| e: int | |
| m: int | |
| def __post_init__(self) -> None: | |
| assert self.s in [0, 1], self.s | |
| self.e_limit = int(2**self.format.e_width) - 1 | |
| self.m_limit = int(2**self.format.m_width) - 1 | |
| assert ( | |
| 0 <= self.e <= self.e_limit | |
| ), f"Exponent {self.e} outside range: [0, {self.e_limit}]" | |
| assert ( | |
| 0 <= self.m <= self.m_limit | |
| ), f"Mantissa {self.m} outside range: [0, {self.m_limit}]" | |
| @property | |
| def value(self) -> float: | |
| """The numerical value of the bitstring, as defined by the supplied format.""" | |
| if self._is_inf(): | |
| return float("inf") * int((-1) ** self.s) | |
| if self._is_nan(): | |
| return float("nan") | |
| if self._is_subnormal(): | |
| return self._subnormal_val() | |
| return self._normal_val() | |
| def _normal_val(self) -> float: | |
| e = self.e - self.format.bias | |
| m = 1 + (self.m / (self.m_limit + 1)) | |
| return float(((-1) ** self.s) * (2**e) * m) | |
| def _subnormal_val(self) -> float: | |
| e = 1 - self.format.bias | |
| m = self.m / (self.m_limit + 1) | |
| return float(((-1) ** self.s) * (2**e) * m) | |
| def _is_subnormal(self) -> bool: | |
| return self.e == 0 | |
| def _is_nan(self) -> bool: | |
| if self.format.nan_encoding == "E=1s M≠0s": | |
| return self.e == self.e_limit and self.m != 0 | |
| if self.format.nan_encoding == "E=1s M=1s": | |
| return self.e == self.e_limit and self.m == self.m_limit | |
| assert ( | |
| self.format.nan_encoding == "S=1 E=0s M=0s" | |
| ), f"NaN encoding `'{self.format.nan_encoding}' not recognised" | |
| return self.s == 1 and self.e == 0 and self.m == 0 | |
| def _is_inf(self) -> bool: | |
| if self.format.inf_encoding == "E=1s M=0s": | |
| return self.e == self.e_limit and self.m == 0 | |
| assert ( | |
| self.format.inf_encoding == "N/A" | |
| ), f"Inf encoding `'{self.format.inf_encoding}' not recognised" | |
| return False | |
| def __repr__(self) -> str: | |
| return str(self.value) | |
| def __eq__(self, other: Any) -> bool: | |
| return self.value.__eq__(other) | |
| def __lt__(self, other: Any) -> bool: | |
| return self.value.__lt__(other) |