Skip to content

Instantly share code, notes, and snippets.

@zhangqiaorjc
Created November 29, 2022 21:33
Show Gist options
  • Save zhangqiaorjc/22cb60d3e12edd0b81143bd42442221e to your computer and use it in GitHub Desktop.
Save zhangqiaorjc/22cb60d3e12edd0b81143bd42442221e to your computer and use it in GitHub Desktop.
from absl.testing import absltest
from absl import logging
import jax
import jax.numpy as jnp
def amax(x):
    """Return the largest absolute value found anywhere in ``x``."""
    magnitudes = jnp.abs(x)
    return magnitudes.max()
def get_scale(amax_v, fp8_max=448.0, margin=1.1):
    """Return the dequantization scale for a tensor whose abs-max is ``amax_v``.

    ``quantize`` divides by this scale and ``dequantize`` multiplies by it, so a
    value equal to ``amax_v`` maps to ``fp8_max / margin`` in the fp8 domain,
    leaving ``margin`` worth of headroom below ``fp8_max``.

    Args:
      amax_v: absolute maximum of the tensor (Python scalar or array).
      fp8_max: largest magnitude of the target fp8 format. The default 448 is
        the largest finite FP8 E4M3 value and matches the original constant.
      margin: multiplicative headroom applied on top of ``amax_v``.

    Returns:
      ``margin * amax_v / fp8_max`` elementwise, or ``1.0`` wherever
      ``amax_v`` is exactly zero.
    """
    scale = margin * (amax_v / fp8_max)
    # A zero amax would give a zero scale and make quantize() compute 0 / 0
    # (NaN) for an all-zero tensor; fall back to an identity scale instead.
    return jnp.where(amax_v == 0, 1.0, scale)
# Stand-in dtypes for this experiment: float16 plays the role of fp8 and
# float32 plays the role of bf16.
fake_fp8, fake_bf16 = jnp.float16, jnp.float32
def quantize(x, amax_v):
    """Scale ``x`` into the fp8 range and cast it to the fake fp8 dtype.

    Args:
      x: value(s) to quantize; Python scalars are accepted as well as arrays.
      amax_v: absolute maximum used to derive the scale via ``get_scale``.

    Returns:
      ``x / get_scale(amax_v)`` as a ``fake_fp8`` array.
    """
    # jnp.asarray: with Python-scalar x and amax_v the quotient would be a
    # plain float, which has no .astype method.
    scaled = jnp.asarray(x) / get_scale(amax_v)
    return scaled.astype(fake_fp8)
def dequantize(x, amax_v):
    """Widen fake-fp8 data to the fake bf16 dtype and undo the quantization scale.

    Multiplies by the same ``get_scale(amax_v)`` that ``quantize`` divides by,
    so it inverts ``quantize`` up to fp8 rounding.
    """
    widened = x.astype(fake_bf16)
    scale = get_scale(amax_v)
    return widened * scale
# RFC Approach (https://github.com/openxla/xla/discussions/22)
def matmul_f8_rfc_fp8_inp(x_fp8, y_fp8, x_amax, y_amax, z_amax):
    """Multiply two fake-fp8 operands following the XLA fp8 RFC pattern.

    Both operands are dequantized to the fake bf16 dtype and multiplied with
    ``jnp.dot``. The product is then requantized to fake fp8.

    Args:
      x_fp8, y_fp8: quantized operands.
      x_amax, y_amax: abs-max values the operands were quantized with.
      z_amax: abs-max used to scale the output.

    Returns:
      ``(z_fp8, new_z_amax)``: the quantized product and the abs-max observed
      in the unquantized product. ``z_fp8`` is scaled with ``z_amax``, not
      with ``new_z_amax`` (presumably the caller feeds ``new_z_amax`` back in
      on a later call — confirm).
    """
    lhs = dequantize(x_fp8, x_amax)
    rhs = dequantize(y_fp8, y_amax)
    product = jnp.dot(lhs, rhs)
    observed_amax = amax(product)
    return quantize(product, z_amax), observed_amax
class Fp8Test(absltest.TestCase):
    """Tests for the fake-fp8 quantization helpers and the RFC-style matmul."""

    def test_quantize(self):
        """Round-trip a scalar through quantize/dequantize, logging each stage."""
        logging.info('amax = %s', amax(jnp.array([1e-15, 4, 27])))
        logging.info('scale = %s', get_scale(449))
        amax_v = amax(449)
        fp8_v = quantize(449, amax_v)
        logging.info('scaled to fp8 = %s', fp8_v)
        roundtrip = dequantize(fp8_v, amax_v)
        logging.info('unscaled to bf16 = %s', roundtrip)
        # Only the fake-fp8 (float16) mantissa rounding should be lost.
        self.assertAlmostEqual(float(roundtrip), 449.0, delta=449.0 * 1e-3)

    def test_matmul(self):
        """Compare the RFC-style fp8 matmul against a plain jnp.dot."""
        A = jnp.ones((2, 4), dtype=fake_bf16)
        B = jnp.ones((4, 2), dtype=fake_bf16)
        A_amax = amax(A)
        B_amax = amax(B)
        A_fp8 = quantize(A, A_amax)
        B_fp8 = quantize(B, B_amax)
        expected = jnp.dot(A, B)
        z_amax = amax(expected)
        C_fp8, C_amax = matmul_f8_rfc_fp8_inp(A_fp8, B_fp8, A_amax, B_amax, z_amax)
        # C_fp8 was scaled with z_amax, so it must be dequantized with z_amax;
        # C_amax is the amax observed in this product, not the one used here.
        C = dequantize(C_fp8, z_amax)
        logging.info('matmul_f8_rfc_fp8_inp: %s', C)
        logging.info('matmul: %s', expected)
        self.assertTrue(jnp.allclose(C, expected, rtol=1e-3))
        self.assertTrue(jnp.allclose(C_amax, z_amax, rtol=1e-3))

        def print_ir(f, *args):
            # Log the lowered compiler IR of f specialized to args.
            lowered = jax.jit(f).lower(*args)
            logging.info('%s', lowered.compiler_ir())

        print_ir(matmul_f8_rfc_fp8_inp, A_fp8, B_fp8, A_amax, B_amax, z_amax)
if __name__ == '__main__':
    # Script entry point: hand control to the absltest runner.
    absltest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment