Created
March 2, 2024 07:13
-
-
Save fayalalebrun/564b5d9a38f57ee5d18d3c7f2e0fc53c to your computer and use it in GitHub Desktop.
Simulate posit quantization using JAX
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest
from functools import partial

import jax
import jax.numpy as jnp
def decompose(x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Split float32 values into their IEEE-754 bit fields.

    Args:
        x: float32 scalar or array.

    Returns:
        A ``(negative, exponent, significand)`` tuple:

        * ``negative`` -- bool, the sign bit (so ``-0.0`` and negative NaNs
          report ``True``);
        * ``exponent`` -- int32, the raw exponent field minus the bias 127
          (zeros/subnormals give -127, infinities/NaNs give 128);
        * ``significand`` -- uint32, the 23-bit fraction field without the
          implicit leading one.
    """
    # signbit (rather than ``x < 0``) keeps the sign of -0.0 and of negative
    # NaNs, so compose(*decompose(x)) reproduces x bit-for-bit.
    negative = jnp.signbit(x)
    # abs() clears the sign bit, leaving only the exponent and fraction fields.
    n = jnp.abs(x).view(jnp.int32)
    exponent = (n >> 23) - 127
    significand = n.view(jnp.uint32) & jnp.uint32(2**23 - 1)
    return (negative, exponent, significand)
def compose(negative: jnp.bool_, exponent: jnp.int32, significand: jnp.uint32) -> jnp.float32:
    """Assemble float32 values from sign, unbiased exponent and fraction.

    This is the inverse of ``decompose``: ``negative`` selects bit 31,
    ``exponent`` is re-biased by 127 and placed in bits 23-30, and
    ``significand`` supplies the low 23 fraction bits. The three fields are
    combined with integer addition, so callers are expected to keep
    ``significand < 2**23`` and ``exponent + 127`` within ``[0, 255]``;
    otherwise carries spill into the neighbouring field.
    """
    biased_exponent = exponent + 127
    exponent_field = (biased_exponent << 23).view(jnp.uint32)
    sign_field = jnp.where(negative, 1, 0).view(jnp.uint32) << 31
    bit_pattern = sign_field + exponent_field + significand
    return bit_pattern.view(jnp.float32)
def format_decomposed(x: jnp.float32):
    """Return a readable breakdown of a float32 value's bit fields.

    Shows the sign bit, the biased exponent as 8 binary digits followed by
    its unbiased value, and the 23-bit fraction field in binary. It is used
    on scalar values when building assertion messages in ``test_equiv``.
    """
    sign, unbiased, fraction = decompose(x)
    biased = unbiased + 127
    return f"sign: {sign:01b}, exp: {biased:08b} ({unbiased:03}), m: {fraction:023b}"
@partial(jax.jit, static_argnums=(0, 1))
def quantize_float_to_posit(size: int, exponent_size: int, arr: jnp.ndarray) -> jnp.ndarray:
    """Round float32 values to the nearest posit(size, exponent_size) value.

    The posit encoding is simulated directly on the float32 bit pattern and
    the result is returned as float32. Rounding follows the posit standard:
    round to nearest, with ties going to the encoding whose least significant
    bit is 0.

    Args:
        size: total posit width in bits (static under ``jit``; a Python int).
        exponent_size: number of posit exponent bits ``es`` (static under
            ``jit``; a Python int).
        arr: float32 scalar or array to quantize.

    Returns:
        A float32 array shaped like ``arr``. With ``M = (size - 2) * 2**es``:

        * zeros map to ``+0.0``;
        * NaNs are returned unchanged;
        * non-zero magnitudes below ``2**-M`` (including subnormals) become
          ``+-2**-M`` (minpos), keeping the input's sign;
        * magnitudes of ``2**(M + 1)`` or more (including infinities) become
          ``+-2**M`` (maxpos), keeping the input's sign.

    NOTE(review): the posit standard maps infinities to NaR, but here they
    saturate to maxpos. The exponent is carried in an 8-bit field, so
    configurations with ``M >= 127`` are presumably unsupported. Confirm this
    before using such sizes.
    """
    negative, exponent, significand = decompose(arr)
    one = jnp.uint32(1)

    ufactor = 1 << exponent_size
    max_regime = size - 2
    limit = max_regime * ufactor
    regime = jnp.clip(exponent // ufactor, -max_regime, max_regime)
    # Length of the regime run, excluding its terminating bit.
    k = jnp.abs(regime) + jnp.where(regime >= 0, 1, 0)
    # Bits left for exponent + fraction after the sign bit, the regime run and
    # its terminator. A value of -1 means the run fills the word and has no
    # terminator.
    remaining_bits = max_regime - k

    # Combine exponent and fraction, keeping the exponent signed (8-bit two's
    # complement in bits 23-30). This is closer to the posit layout and lets a
    # rounding carry propagate from the fraction into the exponent.
    combined = (~(one << 31) & (exponent.view(jnp.uint32) << 23)) + significand

    # Bit of `combined` that becomes the posit's least significant bit.
    lsb_pos = 23 + exponent_size - remaining_bits
    # lsb_pos <= 0 means the posit keeps at least 23 fraction bits (e.g.
    # posit32), so there is nothing to round. In that case the shift is
    # clamped to stay valid, and the rounded result is discarded below.
    shift = jnp.maximum(lsb_pos, 1).astype(jnp.uint32)

    # Value of the posit's LSB. If the regime used up every bit, the LSB lies
    # inside the regime itself, so simulate what a real posit would hold there.
    simulated_lsb = (regime < 0) | ((regime > 0) & (remaining_bits == -1))
    actual_lsb = ((combined >> shift) & one) == 1
    lsb_is_one = jnp.where(remaining_bits <= 0, simulated_lsb, actual_lsb)

    # Round to nearest, ties to even. Add half an ULP, minus one when the LSB
    # is 0 so that exact ties round down.
    # See https://posithub.org/docs/posit_standard-2.pdf
    bias = (one << (shift - one)) - jnp.where(lsb_is_one, jnp.uint32(0), one)
    # Then clear every bit below the LSB.
    rounded = (combined + bias) & ~((one << shift) - one)
    combined = jnp.where(lsb_pos > 0, rounded, combined)

    significand_mask = (one << 23) - one
    exponent_mask = ((one << 8) - one) << 23
    # Move the 8-bit exponent to the top of the word, so the arithmetic right
    # shift sign-extends it back into an int32.
    new_exponent = ((combined & exponent_mask) << 1).view(jnp.int32) >> 24
    val = compose(negative, new_exponent, combined & significand_mask)

    # Saturate magnitudes outside the posit's range to minpos / maxpos.
    saturated = compose(negative, jnp.sign(exponent) * limit, 0)
    val = jnp.where(jnp.abs(exponent) > limit, saturated, val)
    # Zero is exact in a posit. Without this, the saturation above would turn
    # it (exponent -127) into minpos.
    val = jnp.where(arr == 0, jnp.zeros_like(val), val)
    # Conserve NaNs.
    val = jnp.where(jnp.isnan(arr), arr, val)
    return val
class TestQuantMethods(unittest.TestCase):
    """Checks for the float32 field helpers and the posit quantizer."""

    def test_compose_decompose(self):
        """compose(*decompose(x)) must reproduce every float32 bit pattern."""
        total = 1 << 32  # number of distinct 32-bit patterns
        iters = 1 << 12
        # Exact division (2**20), so no bit pattern is skipped at the top end.
        slice_length = total // iters
        recreate = jax.jit(lambda x: compose(*decompose(x)))
        offsets = jnp.arange(slice_length, dtype=jnp.uint32)
        for i in range(iters):
            # start + offset <= 2**32 - 1, so the uint32 addition never wraps.
            cases = (offsets + jnp.uint32(slice_length * i)).view(jnp.float32)
            res = recreate(cases)
            matches = (cases == res) | (jnp.isnan(cases) & jnp.isnan(res))
            if not matches.all():
                # Only gather the mismatches (expensive) when there are some.
                no_match = jnp.argwhere(~matches)
                cases_bad = jnp.take(cases, no_match)
                res_bad = jnp.take(res, no_match)
                self.fail(f"{cases_bad} {res_bad}")

    def test_quantize(self):
        """Compare posit16 (es=1) quantization against softposit over a wide range."""
        import random
        try:
            import softposit  # noqa: F401  (the reference used by test_equiv)
        except ImportError:
            self.skipTest("softposit is not installed")
        rng = random.Random(0)  # fixed seed so failures are reproducible
        step = 2**14
        start = int(jnp.float32(2**-57).view(jnp.uint32))
        stop = int(jnp.float32(2**57).view(jnp.uint32))
        for i in range(start, stop, step):
            val = jnp.uint32(i + rng.randint(0, step - 1)).view(jnp.float32)
            test_equiv(val)
            test_equiv(-val)
def test_equiv(val: jnp.float32):
    """Assert that posit16/es=1 quantization of ``val`` matches softposit.

    The reference value is obtained by converting ``val`` to softposit's
    ``posit16`` and back to a double. On mismatch, the assertion message
    shows both values and their bit fields.
    """
    import softposit as sp

    reference_bits = sp.posit16(float(val)).v
    truth = jnp.float32(sp.convertP16ToDouble(reference_bits))
    our = quantize_float_to_posit(16, 1, val)
    assert truth == our, (
        f"For {val}, {our} != {truth}. "
        f"({val.view(jnp.uint32):032b}, {our.view(jnp.uint32):032b} "
        f"!= {truth.view(jnp.uint32):032b}) \n"
        f"val:\t {format_decomposed(val)}\n"
        f"our:\t {format_decomposed(our)}\n"
        f"truth:\t {format_decomposed(truth)}\n"
    )
if __name__ == "__main__":
    # Run the unittest suite when this file is executed as a script.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment