Skip to content

Instantly share code, notes, and snippets.

@fayalalebrun
Created March 2, 2024 07:13
Show Gist options
  • Save fayalalebrun/564b5d9a38f57ee5d18d3c7f2e0fc53c to your computer and use it in GitHub Desktop.
Save fayalalebrun/564b5d9a38f57ee5d18d3c7f2e0fc53c to your computer and use it in GitHub Desktop.
Simulate posit quantization using JAX
import unittest
import jax
import jax.numpy as jnp
from functools import partial
def decompose(x: jnp.float32) -> tuple[jnp.bool_, jnp.int32, jnp.uint32]:
    """Split a float32 into its sign, unbiased exponent and raw significand fields.

    Args:
        x: float32 scalar or array.

    Returns:
        ``(negative, exponent, significand)`` where ``negative`` is the IEEE-754
        sign bit as a bool, ``exponent`` is the 8-bit exponent field minus 127
        (int32; -127 for zeros/subnormals, 128 for inf/NaN) and ``significand``
        is the 23 stored fraction bits without the implicit leading one (uint32).
    """
    # Read the sign bit directly instead of testing ``x < 0``: that comparison
    # is False for -0.0 and for NaNs with the sign bit set, so
    # ``compose(*decompose(x))`` used to drop their sign.
    negative = jnp.signbit(x)
    # With the sign bit cleared the int32 view is non-negative, so the right
    # shift below isolates the exponent field without sign-extension.
    n = jnp.abs(x).view(jnp.int32)
    exponent = (n >> 23) - 127
    significand = n.view(jnp.uint32) & jnp.uint32(2**23 - 1)
    return (negative, exponent, significand)
def compose(negative: jnp.bool_, exponent: jnp.int32, significand: jnp.uint32) -> jnp.float32:
    """Assemble a float32 from its sign, unbiased exponent and significand fields.

    Inverse of ``decompose``. Any argument may be a plain Python scalar; it is
    converted to the matching JAX dtype before the bit manipulation.

    Args:
        negative: truthy where the result's sign bit should be set.
        exponent: unbiased exponent. ``exponent + 127`` is written to the 8-bit
            exponent field, so it is expected to lie in ``[-127, 128]``.
        significand: stored fraction bits; expected to be ``< 2**23``.

    Returns:
        float32 array whose bit pattern is, for in-range inputs,
        ``sign << 31 | (exponent + 127) << 23 | significand``.
    """
    # Convert up front so Python ints get dtype/``.view`` semantics
    # (previously ``compose(False, 0, 0)`` raised AttributeError).
    exponent = jnp.asarray(exponent, dtype=jnp.int32)
    significand = jnp.asarray(significand, dtype=jnp.uint32)
    sign_bits = jnp.where(negative, 1, 0).view(jnp.uint32) << 31
    exponent_bits = ((exponent + 127) << 23).view(jnp.uint32)
    # For in-range inputs the three fields occupy disjoint bits, so ``+``
    # is the same as a bitwise OR.
    return (sign_bits + exponent_bits + significand).view(jnp.float32)
def format_decomposed(x: jnp.float32):
    """Return a human-readable dump of the sign, exponent and mantissa bits of a float32.

    The exponent is shown both as the raw 8-bit biased field and, in
    parentheses, as the unbiased value returned by ``decompose``.
    """
    sign_flag, unbiased_exp, mantissa = decompose(x)
    biased_exp = unbiased_exp + 127
    return (
        f"sign: {sign_flag:01b}, "
        f"exp: {biased_exp:08b} ({unbiased_exp:03}), "
        f"m: {mantissa:023b}"
    )
@partial(jax.jit, static_argnums=(0, 1))
def quantize_float_to_posit(size: int, exponent_size: int, arr: jnp.float32) -> jnp.float32:
    """Round float32 values to the nearest value representable by a posit<size, exponent_size>.

    The result stays a float32 array; only its value is quantized. Rounding is
    to nearest, with ties going to an even posit LSB.

    Args:
        size: total posit width in bits (static under ``jax.jit``).
        exponent_size: number of posit exponent bits (static under ``jax.jit``).
        arr: float32 scalar or array to quantize.

    Returns:
        float32 array shaped like ``arr``. Zeros map to +0.0, NaNs are passed
        through unchanged, and magnitudes outside the posit range saturate to
        +/-maxpos (``2**((size - 2) * 2**exponent_size)``) or +/-minpos.
        NOTE(review): infinities also saturate to +/-maxpos here, whereas the
        posit standard maps them to NaR.
    """
    negative, exponent, significand = decompose(arr)
    # useed = 2**ufactor: each regime step scales the value by this factor.
    ufactor = 1 << exponent_size
    # Largest regime magnitude encodable after the sign bit.
    max_regime = size - 2
    regime = jnp.clip(exponent // ufactor, -max_regime, max_regime)
    # Length of the regime run of identical bits (terminating bit excluded).
    k = jnp.abs(regime) + jnp.where(regime >= 0, 1, 0)
    # Combine exponent and mantissa, keeping the exponent as a signed
    # (two's complement) field in bits 23..30. This is closer to a posit
    # representation and lets the rounding increment below carry from the
    # fraction into the exponent correctly.
    combined = (~(jnp.uint32(1) << 31) & (exponent.view(jnp.uint32) << 23)) + significand
    # Bits left for exponent + fraction after the sign bit, the regime run and
    # its terminator (size - 1 - k - 1). This is 0 or -1 when the regime fills the posit.
    remaining_bits = max_regime - k
    # Position in ``combined`` where the posit's least significant bit would be.
    lsb_pos = 23 + exponent_size - remaining_bits
    # Value of the posit's LSB, needed for the ties-to-even decision.
    lsb_val_corr = jnp.where(
        remaining_bits <= 0,
        # Out of bits: the posit's last bit is a regime bit, so simulate
        # what the LSB of an actual posit would be.
        (regime < 0) | (regime > 0) & (remaining_bits == -1),
        # Otherwise read the actual exponent/significand bit.
        ((combined >> lsb_pos) & 1),
    )
    # Add half an ULP if the LSB is 1, or one less than half if it is 0, then
    # truncate: exact ties therefore round towards an even LSB.
    # See https://posithub.org/docs/posit_standard-2.pdf
    combined = combined + ((1 << lsb_pos - 1) - jnp.where(lsb_val_corr == 1, 0, 1))
    # Mask off everything below the LSB.
    combined = combined & ~((1 << lsb_pos) - 1)
    significand_mask = (jnp.uint32(1) << 23) - 1
    exponent_mask = ((jnp.uint32(1) << 8) - 1) << 23
    # Shift the 8-bit exponent field to the top of the word so the arithmetic
    # right shift sign-extends it back down.
    new_exponent = (((combined & exponent_mask) << 1).view(jnp.int32)) >> 24
    # Finally build the float32 value.
    val = compose(negative, new_exponent, combined & significand_mask)
    # Saturate magnitudes outside the posit range to +/-maxpos or +/-minpos.
    val = jnp.where(
        jnp.abs(exponent) > max_regime * ufactor,
        compose(negative, jnp.sign(exponent) * max_regime * ufactor, 0),
        val,
    )
    # Posits represent zero exactly. Zeros decompose to exponent -127 and would
    # otherwise have been saturated to minpos above.
    val = jnp.where(arr == 0, jnp.zeros_like(val), val)
    # Conserve NaNs.
    val = jnp.where(jnp.isnan(arr), arr, val)
    return val
class TestQuantMethods(unittest.TestCase):
    """Checks the float32 bit helpers and the posit quantizer against SoftPosit."""

    def test_compose_decompose(self):
        """``compose(*decompose(x))`` must reproduce every float32 bit pattern (NaN compared as NaN)."""
        # Number of distinct 32-bit patterns. The former ``max - min`` equals
        # 2**32 - 1, which made every slice one element short and left the
        # last 4096 patterns unchecked.
        to_check = 1 << 32
        iters = 1 << 12
        slice_length = to_check // iters
        recreate = jax.jit(lambda x: compose(*decompose(x)))
        # Building each slice as offsets + base keeps every value inside
        # uint32, including the last slice, which ends at 2**32 - 1.
        offsets = jnp.arange(slice_length, dtype=jnp.uint32)
        for i in range(iters):
            cases = (offsets + jnp.uint32(slice_length * i)).view(jnp.float32)
            res = recreate(cases)
            matches = (cases == res) | (jnp.isnan(cases) & jnp.isnan(res))
            if not bool(matches.all()):
                # Only gather the mismatches when there is something to report.
                no_match = jnp.argwhere(~matches)
                cases_bad = jnp.take(cases, no_match)
                res_bad = jnp.take(res, no_match)
                self.fail(f"{cases_bad} {res_bad}")

    def test_quantize(self):
        """Sample bit patterns between 2**-57 and 2**57 (and their negations) and compare with SoftPosit posit16."""
        import random
        # Fixed seed so that any failing sample can be reproduced.
        rng = random.Random(0)
        step = 2**14
        start = int(jnp.float32(2**-57).view(jnp.uint32))
        stop = int(jnp.float32(2**57).view(jnp.uint32))
        for i in range(start, stop, step):
            # One random bit pattern from each window of ``step`` patterns.
            val = jnp.uint32(i + rng.randint(0, step - 1)).view(jnp.float32)
            test_equiv(val)
            test_equiv(-val)
def test_equiv(val: jnp.float32):
    """Check ``quantize_float_to_posit(16, 1, val)`` against SoftPosit's posit16.

    The reference value comes from rounding ``val`` with ``softposit.posit16``
    and converting it back to a double.

    Raises:
        AssertionError: if the two results differ. The message shows both
            values, their raw bit patterns and their decomposed fields.

    NOTE(review): the ``test_`` prefix plus a positional parameter means
    pytest will try to collect this as a test and fail on a missing ``val``
    fixture; it is meant to be called only from ``TestQuantMethods``.
    """
    import softposit as sp
    truth = jnp.float32(sp.convertP16ToDouble(sp.posit16(float(val)).v))
    our = quantize_float_to_posit(16, 1, val)
    # Raise explicitly rather than using ``assert``, which ``python -O`` strips,
    # so the check cannot silently become a no-op.
    if truth != our:
        raise AssertionError(f"For {val}, {our} != {truth}. ({val.view(jnp.uint32):032b}, {our.view(jnp.uint32):032b} != {truth.view(jnp.uint32):032b}) \nval:\t {format_decomposed(val)}\nour:\t {format_decomposed(our)}\ntruth:\t {format_decomposed(truth)}\n")
if __name__ == "__main__":
    # Run the unittest test cases defined in this module when executed as a script.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment