Last active
March 12, 2024 15:32
-
-
Save nibrunie/ea700581f2ed8945a8dbafaa75bcd294 to your computer and use it in GitHub Desktop.
Python generator that lists the values of various FP8 formats
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# This script tabulates the values encoded by several 8-bit floating-point (FP8) formats.
def bitMask(width):
    """Return an integer whose ``width`` least-significant bits are all set.

    For example ``bitMask(3) == 0b111`` and ``bitMask(0) == 0``.

    :param width: number of low-order one bits; must be non-negative
    :return: the mask ``2**width - 1``
    :raises ValueError: if ``width`` is negative (the bare formula would
        otherwise silently return a non-integer such as ``-0.5``)
    """
    if width < 0:
        raise ValueError(f"bit-mask width must be non-negative, got {width!r}")
    return 2**width - 1
class FP8Fomat:
    """Abstract description of an 8-bit floating-point encoding.

    The bit layout, from most- to least-significant, is
    ``sign | exponent | mantissa``. Concrete subclasses set the three
    field-layout attributes below, implement :meth:`value`, and provide a
    ``label`` (the subclasses in this file override it with a plain string).
    """

    expBits = None   # width of the exponent field, in bits
    mantBits = None  # width of the mantissa field, in bits
    bias = None      # exponent bias

    @classmethod
    def exp(cls, index):
        """Return the raw (still biased) exponent field of encoding ``index``."""
        aboveMantissa = index >> cls.mantBits
        return aboveMantissa & bitMask(cls.expBits)

    @classmethod
    def mant(cls, index):
        """Return the raw mantissa field (the low ``mantBits`` bits) of ``index``."""
        lowBits = bitMask(cls.mantBits)
        return index & lowBits

    @classmethod
    def sign(cls, index):
        """Return the sign bit (0 or 1) of encoding ``index``."""
        signPosition = cls.mantBits + cls.expBits
        return (index >> signPosition) & bitMask(1)

    @classmethod
    def value(cls, index):
        """Decode ``index`` into a number or special value (subclass hook)."""
        raise NotImplementedError

    @classmethod
    def label(cls):
        """Return the display name of the format (subclass hook)."""
        raise NotImplementedError
class SpecialValue:
    """Marker for an encoding that does not decode to a finite number.

    ``label`` holds a short human-readable name such as ``"NaN"`` or ``"+Inf"``.
    """

    def __init__(self, label):
        self.label = label

    def __repr__(self):
        # Without this, printing a decoded special value only shows an opaque
        # "<SpecialValue object at 0x...>", which makes debugging decoders hard.
        return f"{type(self).__name__}({self.label!r})"
# Shared marker objects returned by the value() decoders for non-finite
# encodings. The decoders in this file only return NaN, PlusInf and MinusInf;
# sNaN, qNaN and the unsigned Inf are defined for completeness.
sNaN, qNaN = SpecialValue("sNaN"), SpecialValue("qNaN")
Inf = SpecialValue("Inf")
PlusInf, MinusInf = SpecialValue("+Inf"), SpecialValue("-Inf")
NaN = SpecialValue("NaN")
class OCP_e5m2(FP8Fomat):
    """OCP OFP8 E5M2: 1 sign bit, 5 exponent bits, 2 mantissa bits, bias 15.

    Decoding follows IEEE-754 conventions. An all-ones exponent encodes
    +/-Inf (zero mantissa) or NaN (non-zero mantissa). A zero exponent encodes
    zero and subnormal values.
    """

    expBits = 5
    mantBits = 2
    bias = 15
    label = "OFP8_e5m2"

    @classmethod
    def value(cls, index):
        """Decode the 8-bit encoding ``index`` to a float or a SpecialValue."""
        biasedExp = cls.exp(index)
        fraction = cls.mant(index)
        negative = bool(cls.sign(index))

        if biasedExp == bitMask(cls.expBits):
            # Reserved exponent: infinities and NaNs.
            if fraction != 0:
                return NaN
            return MinusInf if negative else PlusInf

        isNormal = biasedExp != 0
        # Subnormals share the exponent of the smallest normal, 1 - bias.
        unbiased = biasedExp - cls.bias if isNormal else 1 - cls.bias
        significand = (1 if isNormal else 0) + fraction * 2**-cls.mantBits
        sign = -1.0 if negative else 1.0
        return sign * 2**unbiased * significand
class OCP_e4m3(FP8Fomat):
    """OCP OFP8 E4M3: 1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7.

    Unlike IEEE-754 formats, the all-ones exponent is not reserved. Only the
    pattern with all-ones exponent and all-ones mantissa (S.1111.111) encodes
    NaN, and no encoding represents infinity. The remaining all-ones-exponent
    patterns are ordinary normal numbers, up to 1.75 * 2**8 = 448.
    """

    expBits = 4
    mantBits = 3
    bias = 7
    label = "OFP8_e4m3"

    @classmethod
    def value(cls, index):
        """Decode the 8-bit encoding ``index`` to a float, or ``NaN``."""
        sign = -1.0 if cls.sign(index) else 1.0
        biasedExp = cls.exp(index)
        mant = cls.mant(index)
        # Derive the NaN pattern from the field widths instead of hard-coding
        # 0xf / 0x7, matching how the sibling formats detect reserved fields.
        if biasedExp == bitMask(cls.expBits) and mant == bitMask(cls.mantBits):
            return NaN
        # Subnormals (biasedExp == 0) use the exponent of the smallest normal.
        exp = biasedExp - cls.bias + (1 if biasedExp == 0 else 0)
        # Implicit leading 1 only for normal numbers.
        mant = (1 if biasedExp != 0 else 0) + (mant * 2**-cls.mantBits)
        return sign * 2**exp * mant
class DummyIEEEBinary8(FP8Fomat):
    """Generic IEEE-754-style 8-bit decoder whose layout is set by subclasses.

    Subclasses supply ``expBits``, ``mantBits``, ``bias`` and ``label``.
    An all-ones exponent encodes +/-Inf (zero mantissa) or NaN (otherwise).
    A zero exponent encodes zero and subnormal values.
    """

    @classmethod
    def value(cls, index):
        """Decode the 8-bit encoding ``index`` to a float or a SpecialValue."""
        isNegative = cls.sign(index) != 0
        expField = cls.exp(index)
        mantField = cls.mant(index)

        if expField == bitMask(cls.expBits):
            if mantField:
                return NaN
            return MinusInf if isNegative else PlusInf

        if expField == 0:
            # Zero/subnormal: no implicit leading one, exponent pinned at 1 - bias.
            scale = 1 - cls.bias
            significand = mantField * 2**-cls.mantBits
        else:
            scale = expField - cls.bias
            significand = 1 + mantField * 2**-cls.mantBits
        return (-1.0 if isNegative else 1.0) * 2**scale * significand
class DummyBinary8p2(DummyIEEEBinary8):
    """Hypothetical IEEE-style binary8 with precision 2 (1 stored mantissa bit)."""

    mantBits = 1
    expBits = 7 - mantBits            # 8 bits minus the sign bit and the mantissa
    bias = (1 << (expBits - 1)) - 1   # IEEE convention: 2**(expBits-1) - 1
    label = "Dummy_binary8p2"
class DummyBinary8p3(DummyIEEEBinary8):
    """Hypothetical IEEE-style binary8 with precision 3 (2 stored mantissa bits)."""

    mantBits = 2
    expBits = 7 - mantBits            # 8 bits minus the sign bit and the mantissa
    bias = (1 << (expBits - 1)) - 1   # IEEE convention: 2**(expBits-1) - 1
    label = "Dummy_binary8p3"
class DummyBinary8p4(DummyIEEEBinary8):
    """Hypothetical IEEE-style binary8 with precision 4 (3 stored mantissa bits)."""

    mantBits = 3
    expBits = 7 - mantBits            # 8 bits minus the sign bit and the mantissa
    bias = (1 << (expBits - 1)) - 1   # IEEE convention: 2**(expBits-1) - 1
    label = "Dummy_binary8p4"
class DummyBinary8p5(DummyIEEEBinary8):
    """Hypothetical IEEE-style binary8 with precision 5 (4 stored mantissa bits)."""

    mantBits = 4
    expBits = 8 - mantBits - 1       # remaining bits after sign and mantissa
    bias = 2**(expBits - 1) - 1      # IEEE convention
    # Was "Dummy_binary8p4" (copy-paste from the p4 class), which duplicated
    # the p4 column name in the printed table header.
    label = "Dummy_binary8p5"
class DummyBinary8p6(DummyIEEEBinary8):
    """Hypothetical IEEE-style binary8 with precision 6 (5 stored mantissa bits)."""

    mantBits = 5
    expBits = 7 - mantBits            # 8 bits minus the sign bit and the mantissa
    bias = (1 << (expBits - 1)) - 1   # IEEE convention: 2**(expBits-1) - 1
    label = "Dummy_binary8p6"
class P3109Binary8(FP8Fomat):
    """Common decoder for the IEEE P3109 binary8p<P> formats.

    The layout (``expBits``, ``mantBits``, ``bias``, ``label``) is supplied by
    subclasses. Special values use fixed bit patterns rather than a reserved
    exponent:
      * 0x00 -> zero (the only zero),
      * 0x80 -> NaN (the would-be "negative zero" pattern),
      * 0x7f / 0xff -> +Inf / -Inf (all seven non-sign bits set).
    Every other pattern decodes as ``sign * 2**expValue(index) * sig(index)``.
    """

    @classmethod
    def value(cls, index):
        """Decode the 8-bit encoding ``index`` to a float or a SpecialValue."""
        sign = -1.0 if cls.sign(index) else 1.0
        # Parenthesized for readability; in Python `&` already binds tighter
        # than `==`.
        if (index & bitMask(7)) == 0x7f:
            return PlusInf if sign > 0 else MinusInf
        if index == 0x80:
            return NaN
        if index == 0x00:
            # Needed for formats whose sig() never yields 0 (binary8p1 has no
            # subnormals). Returned as a float so every finite result of the
            # decoders in this file has the same type.
            return 0.0
        return sign * 2**cls.expValue(index) * cls.sig(index)

    @classmethod
    def expValue(cls, index):
        """Return the unbiased exponent; subnormals use 1 - bias."""
        biasedExp = cls.exp(index)
        exp = biasedExp - cls.bias + (1 if biasedExp == 0 else 0)
        return exp

    @classmethod
    def sig(cls, index):
        """Return the significand: implicit 1 for normals (0 for subnormals)
        plus the mantissa field scaled by ``2**-mantBits``."""
        biasedExp = cls.exp(index)
        mant = cls.mant(index)
        sig = (1 if biasedExp != 0 else 0) + (mant * 2**-cls.mantBits)
        return sig
def P3109Binary8pConstructor(p: int):
    """Return the P3109Binary8 subclass describing format binary8p<p>.

    ``p`` is the precision, i.e. the number of significand bits including the
    implicit one. The format therefore stores ``p - 1`` mantissa bits and
    ``8 - p`` exponent bits, and for 2 <= p <= 7 its bias is ``2**(7 - p)``.

    binary8p1 has no subnormals and needs overridden decoding helpers, so for
    ``p == 1`` the dedicated ``P3109Binary8p1`` class is returned. A
    generically built class would decode that format incorrectly.

    :param p: precision, an integer in [1, 7]
    :raises ValueError: if ``p`` is outside [1, 7]; for example, p = 8 would
        leave no exponent bits and produce a non-integer bias
    """
    if not (isinstance(p, int) and 1 <= p <= 7):
        raise ValueError(f"P3109 binary8 precision must be an int in [1, 7], got {p!r}")
    if p == 1:
        return P3109Binary8p1
    mantBits = p - 1
    expBits = 8 - p
    # P3109 exponent range for p >= 2: emax = 2**(expBits-1) - 1, emin = -emax,
    # hence bias = 1 - emin = 2**(expBits-1).
    emax = 2**(expBits - 1) - 1
    emin = -emax
    bias = 1 - emin
    return type(
        f"P3109Binary8p{p}",
        (P3109Binary8,),
        {
            "mantBits": mantBits,
            "expBits": expBits,
            "bias": bias,
            "label": f"P3109_binary8p{p}",
        },
    )
class P3109Binary8p1(P3109Binary8):
    """IEEE P3109 binary8p1: 7 exponent bits, no stored mantissa, bias 63.

    With precision 1 there is no mantissa field and hence no subnormals. Apart
    from the special patterns handled by the base class, every encoding is an
    exact power of two, ``2**(biasedExp - bias)``.
    """

    allSpecialExponent = True
    mantBits = 0
    expBits = 7 - mantBits                   # 8 bits minus the sign bit
    emax = (1 << (expBits - 1)) - 1
    emin = int(allSpecialExponent) - emax
    bias = 1 - emin
    label = "P3109_binary8p1"

    @classmethod
    def mant(cls, index):
        """No mantissa field exists, so it always reads as 0."""
        return 0

    @classmethod
    def sig(cls, index):
        """The significand is always exactly 1 (there are no subnormals)."""
        return 1

    @classmethod
    def expValue(cls, index):
        """Unbiased exponent, without the subnormal adjustment at biasedExp 0."""
        return cls.exp(index) - cls.bias
# Columns of the printed table, in display order.
formatlist = [
    # Open Compute Project OFP8 formats
    OCP_e4m3,
    OCP_e5m2,
    # Hypothetical IEEE-style binary8 formats, precision 2..6
    DummyBinary8p2,
    DummyBinary8p3,
    DummyBinary8p4,
    DummyBinary8p5,
    DummyBinary8p6,
    # IEEE P3109 binary8p<n> formats: p1 has a dedicated class, and p2..p7
    # are generated by P3109Binary8pConstructor.
    P3109Binary8p1,
    *(P3109Binary8pConstructor(p) for p in range(2, 8)),
]
def value2str(value, *, showSpecial=False):
    """Format one decoded value as a table cell.

    :param value: a number, or a SpecialValue marker (NaN, +/-Inf, ...)
    :param showSpecial: when True, a special value is rendered as its label;
        by default it is rendered as an empty cell
    :return: the cell text
    """
    # The keyword flag replaces the previously commented-out `value.label`
    # alternative; the default keeps the original empty-cell output.
    if isinstance(value, SpecialValue):
        return value.label if showSpecial else ""
    return str(value)
if __name__ == "__main__":
    # Emit a comma-separated table with one row per encoding of the positive
    # half (sign bit clear, indices 0..127) and one column per format. The
    # first column is the encoding index, now separated from the values by
    # ", " like every other column. Previously it was joined with a bare space,
    # which merged the index into the first value's field.
    print(", ".join(["index", *(fp8Fmt.label for fp8Fmt in formatlist)]))
    for i in range(128):
        cells = [value2str(fp8Fmt.value(i)) for fp8Fmt in formatlist]
        print(", ".join([str(i), *cells]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment