Skip to content

Instantly share code, notes, and snippets.

@thecharlieblake
Created October 13, 2022 15:43
Show Gist options
  • Save thecharlieblake/62fe60553050d078caf04764a2228681 to your computer and use it in GitHub Desktop.
Save thecharlieblake/62fe60553050d078caf04764a2228681 to your computer and use it in GitHub Desktop.
A representation of the IEEE754 binary format, as well as variants for FP8.
"""A representation of the IEEE754 binary format, as well as variants for FP8."""
from dataclasses import dataclass
from functools import total_ordering
from typing import Any, ClassVar
@dataclass
class IEEE754BinaryFormat:
    """Describe a binary floating-point format in the style of IEEE 754.

    A format is fully determined by its exponent and mantissa bit widths. Every
    derived quantity (bias, exponent range, representable extremes) is computed
    from those two numbers. To interpret a concrete bit pattern in a format,
    see `FloatInstance`.

    The `*_encoding` class variables are short human-readable descriptions of
    which bit patterns are reserved for special values (S = sign bit,
    E = exponent field, M = mantissa field). `FloatInstance` dispatches on the
    exact `inf_encoding` / `nan_encoding` strings, so subclasses must reuse one
    of the recognised spellings.
    """

    e_width: int  # number of exponent bits
    m_width: int  # number of explicitly stored mantissa (fraction) bits

    inf_encoding: ClassVar[str] = "E=1s M=0s"
    nan_encoding: ClassVar[str] = "E=1s M≠0s"
    zero_encoding: ClassVar[str] = "S=0/1 E=0s M=0s"
    subnormal_encoding: ClassVar[str] = "E=0s"

    @property
    def bias(self) -> int:
        """Exponent bias, following the IEEE 754 rule 2**(e_width - 1) - 1."""
        half_range = int(2 ** (self.e_width - 1))
        return half_range - 1

    @property
    def min_e(self) -> int:
        """Smallest unbiased exponent of a *normal* number.

        A stored exponent of 0 is reserved for zero/subnormals, so the lowest
        normal stored exponent is 1. That gives `1 - bias` rather than `-bias`.
        """
        lowest_normal_field = 1
        return lowest_normal_field - self.bias

    @property
    def max_e(self) -> int:
        """Largest unbiased exponent of a finite number.

        The all-ones exponent field is reserved for Inf/NaN, so the highest
        usable stored exponent is one below it.
        """
        all_ones_field = int(2**self.e_width) - 1
        return all_ones_field - 1 - self.bias

    @property
    def abs_min_normal(self) -> float:
        """Smallest positive normal (non-subnormal) magnitude, 2**min_e."""
        return float(2**self.min_e)

    @property
    def abs_min(self) -> float:
        """Smallest positive representable magnitude.

        This is 2**min_e scaled down by 2**m_width, i.e. the subnormal with only
        the lowest mantissa bit set. It equals `abs_min_normal` when m_width is 0.
        """
        return float(2 ** (self.min_e - self.m_width))

    @property
    def abs_max(self) -> float:
        """Largest finite representable magnitude.

        The top finite exponent combined with every mantissa bit set, whose
        significand is 2 - 2**-m_width.
        """
        full_significand = 2 - 2**-self.m_width
        return float(2**self.max_e * full_significand)
@dataclass
class GAQProposedFormat(IEEE754BinaryFormat):
    """FP8 variant proposed jointly by Graphcore, AMD and Qualcomm (GAQ).

    Used for both of the GAQ FP8 formats. Compared with IEEE 754 it:

    * takes an explicit exponent bias (`custom_bias`) instead of the IEEE 754
      rule 2**(e_width - 1) - 1;
    * has no infinities;
    * encodes its single NaN in the pattern IEEE 754 uses for negative zero
      (S=1 E=0s M=0s), leaving only the positive zero pattern;
    * therefore uses the all-ones exponent field for ordinary finite values.
    """

    custom_bias: int  # exponent bias used in place of the IEEE 754 default

    inf_encoding: ClassVar[str] = "N/A"
    nan_encoding: ClassVar[str] = "S=1 E=0s M=0s"
    zero_encoding: ClassVar[str] = "S=0 E=0s M=0s"

    @property
    def bias(self) -> int:
        """Exponent bias, taken from `custom_bias` instead of the IEEE 754 rule."""
        return self.custom_bias

    @property
    def max_e(self) -> int:
        """Largest unbiased exponent of a finite number.

        Inf/NaN do not occupy the all-ones exponent field here, so that field is
        an ordinary exponent and the limit is one above the IEEE 754 one.
        """
        ieee_limit = super().max_e
        return ieee_limit + 1
@dataclass
class NAIProposedFormat(IEEE754BinaryFormat):
    """Used for Nvidia, ARM and Intel's (NAI) proposed E4M3 format.

    Compared with IEEE 754 it has no infinities. Only the all-ones
    exponent+mantissa pattern is NaN, for either sign. Every other pattern with
    an all-ones exponent is an ordinary finite value. The exponent bias is the
    inherited IEEE 754 one.
    """

    inf_encoding: ClassVar[str] = "N/A"
    nan_encoding: ClassVar[str] = "E=1s M=1s"

    @property
    def max_e(self) -> int:
        """Maximum exponent. The all-ones exponent no longer denotes NaN/Inf."""
        return super().max_e + 1

    @property
    def abs_max(self) -> float:
        """Absolute maximum representable (finite) value.

        This is the value just below the NaN pattern.

        For m_width >= 1 it is the inherited "all mantissa bits set at max_e"
        value minus one unit in the last place, 2**(max_e - m_width). That ULP
        step is the all-ones-mantissa slot taken by NaN.

        For m_width == 0 the all-ones exponent holds *only* NaN, so the largest
        finite value sits one binade lower, at 2**(max_e - 1). The ULP formula
        would give 2**max_e - 2**max_e == 0.0 there.
        """
        if self.m_width == 0:
            return float(2 ** (self.max_e - 1))
        one_ulp = 2 ** (self.max_e - self.m_width)
        return float(super().abs_max - one_ulp)
@dataclass
class FloatInstance:
    """An instance of a floating point number, defined with reference to an instance
    or subclass of `IEEE754BinaryFormat`.

    The number is held as its raw fields and decoded on demand through `value`:
    sign bit `s`, biased exponent field `e`, and mantissa field `m`. Instances
    compare by `value`, against each other or against plain Python numbers.
    Comparisons use float semantics: a NaN is unequal to, and unordered with,
    everything, including itself.
    """

    format: IEEE754BinaryFormat  # decoding rules: widths, bias, special encodings
    s: int  # sign bit, 0 or 1
    e: int  # raw exponent field, in [0, 2**e_width - 1]
    m: int  # raw mantissa field, in [0, 2**m_width - 1]

    def __post_init__(self) -> None:
        """Validate the raw fields and cache the all-ones value of each field.

        NOTE(review): validation uses `assert`, so it is skipped under `python -O`.
        """
        assert self.s in [0, 1], self.s
        # Largest value each raw field can hold (all bits set).
        self.e_limit: int = int(2**self.format.e_width) - 1
        self.m_limit: int = int(2**self.format.m_width) - 1
        assert (
            0 <= self.e <= self.e_limit
        ), f"Exponent {self.e} outside range: [0, {self.e_limit}]"
        assert (
            0 <= self.m <= self.m_limit
        ), f"Mantissa {self.m} outside range: [0, {self.m_limit}]"

    @property
    def value(self) -> float:
        """The numerical value of the bitstring, as defined by the supplied format.

        Special encodings are checked before the subnormal/normal split, because
        some formats place NaN at a zero exponent field.
        """
        if self._is_inf():
            return float("inf") * int((-1) ** self.s)
        if self._is_nan():
            return float("nan")
        if self._is_subnormal():
            return self._subnormal_val()
        return self._normal_val()

    def _normal_val(self) -> float:
        """Decode as (-1)**s * 2**(e - bias) * (1 + m / 2**m_width).

        The leading 1 of the significand is implicit.
        """
        e = self.e - self.format.bias
        m = 1 + (self.m / (self.m_limit + 1))
        return float(((-1) ** self.s) * (2**e) * m)

    def _subnormal_val(self) -> float:
        """Decode as (-1)**s * 2**(1 - bias) * (m / 2**m_width).

        There is no implicit leading 1.
        """
        # Subnormals share the exponent of the lowest normal binade (1 - bias,
        # i.e. the format's min_e) rather than using 0 - bias.
        e = 1 - self.format.bias
        m = self.m / (self.m_limit + 1)
        return float(((-1) ** self.s) * (2**e) * m)

    def _is_subnormal(self) -> bool:
        """True when the exponent field is all zeros (zero or subnormal)."""
        return self.e == 0

    def _is_nan(self) -> bool:
        """True when the fields match the format's NaN encoding.

        Raises:
            AssertionError: if `format.nan_encoding` is not a recognised spelling.
        """
        if self.format.nan_encoding == "E=1s M≠0s":
            return self.e == self.e_limit and self.m != 0
        if self.format.nan_encoding == "E=1s M=1s":
            return self.e == self.e_limit and self.m == self.m_limit
        assert (
            self.format.nan_encoding == "S=1 E=0s M=0s"
        ), f"NaN encoding `'{self.format.nan_encoding}' not recognised"
        return self.s == 1 and self.e == 0 and self.m == 0

    def _is_inf(self) -> bool:
        """True when the fields match the format's infinity encoding.

        Raises:
            AssertionError: if `format.inf_encoding` is not a recognised spelling.
        """
        if self.format.inf_encoding == "E=1s M=0s":
            return self.e == self.e_limit and self.m == 0
        assert (
            self.format.inf_encoding == "N/A"
        ), f"Inf encoding `'{self.format.inf_encoding}' not recognised"
        return False

    def __repr__(self) -> str:
        return str(self.value)

    # Comparisons delegate to the decoded float value. All four orderings are
    # written out instead of being derived by `functools.total_ordering`. The
    # derived versions assume a total order (`a > b` becomes
    # `not a < b and a != b`), which would report a NaN instance as > any number.

    @staticmethod
    def _comparable(other: Any) -> Any:
        """Return `other.value` for another FloatInstance, else `other` unchanged.

        `float.__eq__` / `float.__lt__` return NotImplemented for a FloatInstance
        argument. Without unwrapping, `==` between two instances would fall back
        to an identity check, and `<` would raise TypeError.
        """
        return other.value if isinstance(other, FloatInstance) else other

    def __eq__(self, other: Any) -> bool:
        return self.value.__eq__(self._comparable(other))

    def __lt__(self, other: Any) -> bool:
        return self.value.__lt__(self._comparable(other))

    def __le__(self, other: Any) -> bool:
        return self.value.__le__(self._comparable(other))

    def __gt__(self, other: Any) -> bool:
        return self.value.__gt__(self._comparable(other))

    def __ge__(self, other: Any) -> bool:
        return self.value.__ge__(self._comparable(other))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment