# Gist by @zhangqiaorjc, created November 30, 2022.
"""Runs a simple mnist model with fake FP8. FP8 scaling is used.
The HLO can be dumped by setting the environment variable:
XLA_FLAGS='--xla_dump_disable_metadata=true --xla_dump_to=/tmp/hlo'
"""
import tensorflow as tf
USE_QUANT = True
tf.keras.utils.set_random_seed(1)
# Fake FP8 dtypes since we don't yet have real FP8
FAKE_E4M3 = tf.float16
FAKE_E5M2 = tf.bfloat16
E4M3_MAX = 448.  # Largest finite value representable in float8 E4M3
E5M2_MAX = 57344.  # Largest finite value representable in float8 E5M2
def get_fp8_max(fake_dtype):
  if fake_dtype == FAKE_E4M3:
    return E4M3_MAX
  else:
    assert fake_dtype == FAKE_E5M2
    return E5M2_MAX
def quantize(x, quantized_dtype, scale):
  dtype_max = get_fp8_max(quantized_dtype)
  scaled_x = tf.clip_by_value(x / scale, -dtype_max, dtype_max)
  return tf.cast(scaled_x, quantized_dtype)
def dequantize(x, wide_dtype, scale):
  return tf.cast(x, wide_dtype) * scale
def quantize_dequantize(x, quantized_dtype, scale):
  orig_dtype = x.dtype
  qx = quantize(x, quantized_dtype, scale)
  return dequantize(qx, orig_dtype, scale)
def update_scale(x, quantized_dtype, scale_var):
  dtype_max = get_fp8_max(quantized_dtype)
  amax = tf.cast(tf.math.reduce_max(tf.math.abs(x)), scale_var.dtype)
  amax = tf.maximum(amax, 2 ** -10)  # Avoid a zero or denormal scale
  scale_var.assign(1.1 * amax / dtype_max)  # Keep a 1.1x margin over amax
def qdq_and_update(x, dtype, scale_var):
  # Delayed scaling: quantize with the scale computed on the previous step,
  # then update the scale from the current tensor's amax for the next step.
  qx = quantize_dequantize(x, dtype, scale_var)
  update_scale(x, dtype, scale_var)
  return qx
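# Minimal sanity check of the round trip (illustrative values; safe to delete).
# A value above E4M3_MAX is clipped on the first pass, and the updated scale
# would bring it into range on the next pass.
_demo_scale = tf.Variable(1.0, trainable=False)
_demo_x = tf.constant([1000.0])
assert float(qdq_and_update(_demo_x, FAKE_E4M3, _demo_scale)[0]) == 448.0
assert abs(float(_demo_scale) - 1.1 * 1000.0 / 448.0) < 1e-3  # new scale ~2.455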
class DenseWithScaling(tf.keras.layers.Layer):
  def __init__(self, units, activation=None, use_quant=False):
    super().__init__()
    self.units = int(units)
    self.activation = tf.keras.activations.get(activation)
    self.use_quant = use_quant
  def build(self, input_shape):
    last_dim = input_shape[-1]
    self.kernel = self.add_weight("kernel", shape=[last_dim, self.units],
                                  initializer="glorot_uniform")
    self.bias = self.add_weight("bias", shape=[self.units],
                                initializer="zeros")
    if self.use_quant:
      init32 = tf.keras.initializers.Constant(32.)
      self.output_scale = self.add_weight("output_scale", shape=(),
                                          initializer=init32, trainable=False)
      self.kernel_scale = self.add_weight("kernel_scale", shape=(),
                                          initializer=init32, trainable=False)
      self.output_grad_scale = self.add_weight("output_grad_scale", shape=(),
                                               initializer=init32,
                                               trainable=False)
  @tf.custom_gradient
  def out_qdq(self, out):
    """Quantize-dequantize both the output and the output's gradient.

    E4M3 is used for the forward activations; E5M2, with its wider dynamic
    range, is used for the gradients in the backward pass.
    """
    qout = qdq_and_update(out, FAKE_E4M3, self.output_scale)
    def grad(out_grad):
      return qdq_and_update(out_grad, FAKE_E5M2, self.output_grad_scale)
    return qout, grad
  @tf.custom_gradient
  def kernel_qdq(self, kernel):
    """Quantize-dequantize the kernel but not its gradient."""
    qkernel = qdq_and_update(kernel, FAKE_E4M3, self.kernel_scale)
    def grad(kernel_grad):
      return kernel_grad
    return qkernel, grad
  def call(self, inputs):
    kernel = self.kernel.read_value()
    if self.use_quant:
      kernel = self.kernel_qdq(kernel)
    out = inputs @ kernel + self.bias
    out = self.activation(out)
    if self.use_quant:
      out = self.out_qdq(out)
    return out
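# Illustrative check of the custom gradients (shapes chosen arbitrarily; safe
# to delete): kernel_qdq fake-quantizes the forward value but passes gradients
# straight through, so the gradient of kernel_qdq w.r.t. its input is all ones.
_probe = DenseWithScaling(2, use_quant=True)
_probe.build([None, 2])  # creates kernel/bias and the scale variables
_k = tf.constant([[600.0, 1.0], [-2.0, 0.5]])
with tf.GradientTape() as tape:
  tape.watch(_k)
  _qk = _probe.kernel_qdq(_k)
assert bool(tf.reduce_all(tape.gradient(_qk, _k) == 1.0))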
class MnistModel(tf.keras.Model):
  def build(self, input_shape):
    self.dense1 = DenseWithScaling(64, activation="relu", use_quant=USE_QUANT)
    self.dense2 = DenseWithScaling(64, activation="relu", use_quant=USE_QUANT)
    self.dense3 = DenseWithScaling(10, use_quant=USE_QUANT)
  def call(self, inputs):
    x = self.dense1(inputs)
    x = self.dense2(x)
    output = self.dense3(x)
    return output
model = MnistModel()
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.RMSprop(),
    metrics=["accuracy"],
    # run_eagerly=True,
    jit_compile=True,
)
history = model.fit(x_train, y_train, batch_size=64, epochs=2,
                    validation_split=0.2, verbose=1)
test_scores = model.evaluate(x_test, y_test, verbose=2)
print("Test loss:", test_scores[0])
print("Test accuracy:", test_scores[1])