Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save zhangqiaorjc/693962566574b151d9f53b72f590c431 to your computer and use it in GitHub Desktop.
Save zhangqiaorjc/693962566574b151d9f53b72f590c431 to your computer and use it in GitHub Desktop.
Copy of mnist fp8 for sharing with NVIDIA.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zhangqiaorjc/693962566574b151d9f53b72f590c431/copy-of-mnist-fp8-for-sharing-with-nvidia.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
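"# JAX FP8 MNIST training on GPU (zhangqiaorjc@google.com)\n",
"\n",
"- Compares FP8 and FP32 training. FP8 is emulated: tensors are quantized and immediately dequantized, with `float16` / `bfloat16` standing in for E4M3 / E5M2.\n",
"- cf. the Keras example from reedwm@google.com: https://gist.github.com/zhangqiaorjc/25c9d753864cd95942ffd9ba48b63f42\n",
"\n"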
],
"metadata": {
"id": "3fkR91xFLiw2"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6N5gTB1cgksd"
},
"outputs": [],
"source": [
"from functools import partial\n",
"from typing import (Any, Callable, Iterable, List, Optional, Mapping, Sequence, Tuple, Union)\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import flax\n",
"from flax import linen as nn\n",
"import re\n",
"from flax.traverse_util import flatten_dict, unflatten_dict\n",
"\n",
"# Annotation-only aliases: none of these constrain anything at runtime.\n",
"Array = Any\n",
"Dtype = Any\n",
"PRNGKey = Any\n",
"ActivationFn = Any\n",
"Shape = Tuple[int, ...]\n",
"\n",
"# FP8 is emulated. Each format is a (stand-in dtype, clip bound) pair: the dtype is\n",
"# what quantize() casts to and what get_fp8_max() dispatches on; the bound is the\n",
"# largest magnitude kept before that cast.\n",
"FAKE_E4M3, E4M3_MAX = jnp.float16, 448\n",
"FAKE_E5M2, E5M2_MAX = jnp.bfloat16, 57344\n",
"\n",
"def tree_shape(x):\n",
"  \"\"\"Returns a pytree shaped like `x` in which every leaf is replaced by its `.shape`.\"\"\"\n",
"  # jax.tree_map is a deprecated alias that newer JAX releases no longer ship;\n",
"  # jax.tree_util.tree_map is the stable spelling of the same function.\n",
"  return jax.tree_util.tree_map(lambda leaf: leaf.shape, x)\n",
"\n",
"def get_fp8_max(fake_dtype):\n",
"  \"\"\"Returns the clip bound that goes with a fake FP8 dtype.\n",
"\n",
"  Raises:\n",
"    ValueError: if `fake_dtype` is neither FAKE_E4M3 nor FAKE_E5M2.\n",
"  \"\"\"\n",
"  if fake_dtype == FAKE_E4M3:\n",
"    return E4M3_MAX\n",
"  if fake_dtype == FAKE_E5M2:\n",
"    return E5M2_MAX\n",
"  raise ValueError('Only FAKE_E4M3 and FAKE_E5M2 supported')\n",
"\n",
"def quantize(x, quantized_dtype, scale):\n",
"  \"\"\"Divides `x` by `scale`, clips to the format's bound, and casts to `quantized_dtype`.\"\"\"\n",
"  bound = get_fp8_max(quantized_dtype)\n",
"  clipped = jnp.clip(x / scale, -bound, bound)\n",
"  return clipped.astype(quantized_dtype)\n",
"\n",
"def dequantize(x, wide_dtype, scale):\n",
"  \"\"\"Casts `x` to `wide_dtype`, then multiplies by `scale` (the inverse of quantize()'s division).\"\"\"\n",
"  widened = x.astype(wide_dtype)\n",
"  return widened * scale\n",
"\n",
"def quantize_dequantize(x, quantized_dtype, scale):\n",
"  \"\"\"Round-trips `x` through `quantized_dtype` and returns it in `x`'s original dtype.\"\"\"\n",
"  return dequantize(quantize(x, quantized_dtype, scale), x.dtype, scale)\n",
"\n",
"def compute_new_scale(x, quantized_dtype, scale):\n",
"  \"\"\"Derives a scale from the current contents of `x`: 1.1 * max|x| / format bound.\n",
"\n",
"  `scale` contributes only its dtype; its value is not read.\n",
"  \"\"\"\n",
"  bound = get_fp8_max(quantized_dtype)\n",
"  abs_max = jnp.max(jnp.abs(x)).astype(scale.dtype)\n",
"  # Floor the amax so the returned scale is never zero (quantize() divides by it).\n",
"  floored_abs_max = jnp.maximum(abs_max, 2**-10)\n",
"  return 1.1 * floored_abs_max / bound\n",
"\n",
"def qdq_and_new_scale(x, dtype, scale):\n",
"  \"\"\"Returns (`x` round-tripped through `dtype` using `scale`, scale recomputed from `x`).\"\"\"\n",
"  requantized = quantize_dequantize(x, dtype, scale)\n",
"  next_scale = compute_new_scale(x, dtype, scale)\n",
"  return requantized, next_scale"
]
},
{
"cell_type": "code",
"source": [
"@jax.custom_vjp\n",
"def kernel_qdq(kernel, kernel_scale):\n",
"  \"\"\"Quantize-dequantizes `kernel` as FAKE_E4M3; returns (qkernel, new_kernel_scale).\"\"\"\n",
"  return qdq_and_new_scale(kernel, FAKE_E4M3, kernel_scale)\n",
"\n",
"def kernel_qdq_fwd(kernel, kernel_scale):\n",
"  \"\"\"Forward rule: same outputs as kernel_qdq, with no residuals saved.\"\"\"\n",
"  primal_out = kernel_qdq(kernel, kernel_scale)\n",
"  return primal_out, None\n",
"\n",
"def kernel_qdq_bwd(_, g):\n",
"  \"\"\"Backward rule: hands each output cotangent to the matching input unchanged.\"\"\"\n",
"  # pass through gradients\n",
"  qkernel_g, new_kernel_scale_g = g\n",
"  return qkernel_g, new_kernel_scale_g\n",
"\n",
"kernel_qdq.defvjp(kernel_qdq_fwd, kernel_qdq_bwd)\n",
"\n",
"\n",
"@jax.custom_vjp\n",
"def out_qdq(out, out_scale, out_grad_scale, dummy):\n",
"  \"\"\"Quantize-dequantizes `out` as FAKE_E4M3.\n",
"\n",
"  `out_grad_scale` is returned untouched; it is only consumed by the backward\n",
"  rule. `dummy` is not read in the forward pass: the backward rule uses its\n",
"  cotangent slot to return the recomputed gradient scale.\n",
"  \"\"\"\n",
"  qout, new_out_scale = qdq_and_new_scale(out, FAKE_E4M3, out_scale)\n",
"  return qout, new_out_scale, out_grad_scale\n",
"\n",
"def out_qdq_fwd(out, out_scale, out_grad_scale, dummy):\n",
"  \"\"\"Forward rule: same outputs as out_qdq, saving `out_grad_scale` as the residual.\"\"\"\n",
"  primal_out = out_qdq(out, out_scale, out_grad_scale, dummy)\n",
"  residuals = (out_grad_scale,)\n",
"  return tuple(primal_out), residuals\n",
"\n",
"def out_qdq_bwd(res, g):\n",
"  \"\"\"Backward rule: quantize-dequantizes the incoming gradient as FAKE_E5M2.\"\"\"\n",
"  (out_grad_scale,) = res\n",
"  qout_g, new_out_scale_g, out_grad_scale_g = g\n",
"  out_grad, new_out_grad_scale = qdq_and_new_scale(qout_g, FAKE_E5M2, out_grad_scale)\n",
"  # One cotangent per primal input, in order: out, out_scale, out_grad_scale, dummy.\n",
"  return (\n",
"      out_grad,\n",
"      jnp.zeros_like(new_out_scale_g),\n",
"      jnp.zeros_like(out_grad_scale_g),\n",
"      new_out_grad_scale,\n",
"  )\n",
"\n",
"out_qdq.defvjp(out_qdq_fwd, out_qdq_bwd)\n",
"\n",
"def initializer_32():\n",
"  \"\"\"Returns the float32 scalar 32.0 used as the initial value of every scale variable.\"\"\"\n",
"  return jnp.array(32.0, dtype=jnp.float32)\n",
"\n",
"class DenseWithScaling(nn.Module):\n",
"  \"\"\"Dense layer (`inputs @ kernel + bias`) with optional fake-FP8 quantize-dequantize.\n",
"\n",
"  With `use_quant=True` the kernel goes through kernel_qdq and the (activated)\n",
"  output through out_qdq. The scales they return are written back into the\n",
"  'qscale' collection, which is why the training code applies the model with\n",
"  `mutable=['qscale']`.\n",
"\n",
"  Attributes:\n",
"    features: width of the output's last axis.\n",
"    param_dtype: dtype handed to the parameter initializers.\n",
"    kernel_init: initializer for the kernel.\n",
"    bias_init: initializer for the bias.\n",
"    activation: optional callable applied after the bias add.\n",
"    use_quant: enables the quantize-dequantize path.\n",
"  \"\"\"\n",
"  features: int\n",
"  param_dtype: Dtype = jnp.float32\n",
"  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.lecun_normal()\n",
"  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros\n",
"  activation: Optional[ActivationFn] = None\n",
"  use_quant: bool = False\n",
"\n",
"  @nn.compact\n",
"  def __call__(self, inputs):\n",
"    # Size the kernel from the trailing axis: that is the axis jnp.dot contracts,\n",
"    # so inputs of any rank >= 1 work (shape[1] only worked for rank-2 inputs).\n",
"    kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features), self.param_dtype)\n",
"    bias = self.param('bias', self.bias_init, (self.features,), self.param_dtype)\n",
"\n",
"    if self.use_quant:\n",
"      kernel_scale = self.variable('qscale', 'kernel_scale', initializer_32)\n",
"      kernel, new_kernel_scale = kernel_qdq(kernel, kernel_scale.value)\n",
"      kernel_scale.value = new_kernel_scale\n",
"\n",
"    # Actual dense layer math.\n",
"    out = jnp.dot(inputs, kernel) + bias\n",
"    if self.activation is not None:\n",
"      out = self.activation(out)\n",
"\n",
"    if self.use_quant:\n",
"      output_scale = self.variable('qscale', 'output_scale', initializer_32)\n",
"      output_grad_scale = self.variable('qscale', 'output_grad_scale', initializer_32)\n",
"      # output_grad_scale is updated in the training loop, not here: the new value\n",
"      # only exists in the backward pass, where out_qdq returns it as the\n",
"      # gradient of this placeholder.\n",
"      output_grad_scale_perturb = self.variable('grad_qscale_placeholder', 'output_grad_scale_placeholder', initializer_32)\n",
"      out, new_out_scale, _ = out_qdq(out, output_scale.value, output_grad_scale.value, output_grad_scale_perturb.value)\n",
"      output_scale.value = new_out_scale\n",
"    return out\n",
"\n",
"\n",
"class MnistModel(nn.Module):\n",
"  \"\"\"Three stacked DenseWithScaling layers: 64 (relu) -> 64 (relu) -> 10 (no activation).\"\"\"\n",
"  use_quant: bool = False\n",
"\n",
"  def setup(self):\n",
"    quant = self.use_quant\n",
"    self.dense1 = DenseWithScaling(64, activation=jax.nn.relu, use_quant=quant)\n",
"    self.dense2 = DenseWithScaling(64, activation=jax.nn.relu, use_quant=quant)\n",
"    self.dense3 = DenseWithScaling(10, use_quant=quant)\n",
"\n",
"  def __call__(self, inputs):\n",
"    hidden = self.dense2(self.dense1(inputs))\n",
"    return self.dense3(hidden)\n"
],
"metadata": {
"id": "pggTGgeN9GFI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"from flax import struct\n",
"import optax\n",
"from dataclasses import dataclass\n",
"\n",
"class TrainState(struct.PyTreeNode):\n",
"  \"\"\"Training state: step counter, params, optimizer state and the two quantization collections.\n",
"\n",
"  `grad_qscale_placeholder` and `qscale` are None when the model variables do\n",
"  not contain those collections.\n",
"  \"\"\"\n",
"  step: int\n",
"  params: Any\n",
"  grad_qscale_placeholder: Any\n",
"  qscale: Any\n",
"  opt_state: optax.OptState\n",
"  tx: optax.GradientTransformation = struct.field(pytree_node=False)\n",
"\n",
"  @staticmethod\n",
"  def create(vars, tx):\n",
"    \"\"\"Builds the step-0 state from model variables `vars` and optimizer `tx`.\"\"\"\n",
"    def unfrozen_or_none(collection):\n",
"      return flax.core.unfreeze(vars[collection]) if collection in vars else None\n",
"\n",
"    params = flax.core.unfreeze(vars['params'])\n",
"    return TrainState(\n",
"        step=0,\n",
"        params=params,\n",
"        grad_qscale_placeholder=unfrozen_or_none('grad_qscale_placeholder'),\n",
"        qscale=unfrozen_or_none('qscale'),\n",
"        opt_state=tx.init(params),\n",
"        tx=tx,\n",
"    )\n",
"\n",
"  def get_diff_vars(self):\n",
"    \"\"\"Collections to differentiate: params, plus the placeholders when present.\"\"\"\n",
"    diff_vars = {'params': self.params}\n",
"    if self.grad_qscale_placeholder:\n",
"      diff_vars[\"grad_qscale_placeholder\"] = self.grad_qscale_placeholder\n",
"    return diff_vars\n",
"\n",
"  def get_nondiff_vars(self):\n",
"    \"\"\"Collections passed through undifferentiated: 'qscale' when present, else nothing.\"\"\"\n",
"    return {'qscale': self.qscale} if self.qscale else {}\n",
"\n",
"def loss_fn(model, diff_vars, nondiff_vars, input_batch):\n",
"  \"\"\"Returns (mean softmax cross-entropy over the batch, updated 'qscale' variables).\n",
"\n",
"  Reads features from `input_batch['x']` and integer labels from `input_batch['y']`.\n",
"  \"\"\"\n",
"  variables = {**diff_vars, **nondiff_vars}\n",
"  logits, updated_nondiff_vars = model.apply(variables, input_batch['x'], mutable=['qscale'])\n",
"  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, input_batch['y'])\n",
"  mean_loss = jnp.mean(per_example_loss, axis=0)\n",
"  return mean_loss, updated_nondiff_vars\n",
"\n",
"def step_fn(model, train_state, input_batch):\n",
"  \"\"\"Runs one optimizer step and refreshes the quantization scales.\n",
"\n",
"  Args:\n",
"    model: module whose `apply` is called by loss_fn.\n",
"    train_state: current TrainState.\n",
"    input_batch: mapping with 'x' (inputs) and 'y' (integer labels).\n",
"\n",
"  Returns:\n",
"    (new_train_state, loss_val).\n",
"  \"\"\"\n",
"  grad_fn = jax.value_and_grad(partial(loss_fn, model), has_aux=True)\n",
"  (loss_val, updated_nondiff_vars), diff_vars_grads = grad_fn(\n",
"      train_state.get_diff_vars(), train_state.get_nondiff_vars(), input_batch)\n",
"\n",
"  params_updates, updated_opt_state = train_state.tx.update(\n",
"      diff_vars_grads['params'], train_state.opt_state, params=train_state.params)\n",
"  updated_params = optax.apply_updates(train_state.params, params_updates)\n",
"\n",
"  # Update qscale with grad_qscale_placeholder for gradient scale entries: the\n",
"  # \"gradient\" of each '<name>_placeholder' is the new value for qscale '<name>'.\n",
"  new_qscale_vars = None\n",
"  if 'qscale' in updated_nondiff_vars:\n",
"    flat_new_qscale_vars = flatten_dict(updated_nondiff_vars['qscale'])\n",
"    # .get: tolerate a 'qscale' collection that has no placeholders.\n",
"    placeholder_grads = diff_vars_grads.get('grad_qscale_placeholder', {})\n",
"    for path, new_grad_scale in flatten_dict(placeholder_grads).items():\n",
"      # Rename only the leaf component; joining and re-splitting the whole path\n",
"      # on '/' would corrupt any component that itself contains a '/'.\n",
"      target_path = path[:-1] + (re.sub(r'_placeholder$', '', path[-1]),)\n",
"      flat_new_qscale_vars[target_path] = new_grad_scale\n",
"    new_qscale_vars = unflatten_dict(flat_new_qscale_vars)\n",
"\n",
"  new_train_state = train_state.replace(\n",
"      step=train_state.step + 1,\n",
"      params=updated_params,\n",
"      qscale=new_qscale_vars,\n",
"      opt_state=updated_opt_state)\n",
"  return new_train_state, loss_val\n"
],
"metadata": {
"id": "F-sEukSa_9mC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
"\n",
"# Flatten every image to a 784-vector and rescale pixel values by 1/255.\n",
"# -1 lets the row count follow the data instead of hard-coding 60000 / 10000.\n",
"x_train = x_train.reshape(-1, 784).astype(\"float32\") / 255\n",
"x_test = x_test.reshape(-1, 784).astype(\"float32\") / 255\n",
"\n",
"# Hold out the tail of the training set for evaluation. The size is derived from\n",
"# validation_split so that changing the fraction actually changes the split.\n",
"validation_split = 0.2\n",
"\n",
"validation_size = int(x_train.shape[0] * validation_split)\n",
"x_eval = x_train[-validation_size:]\n",
"y_eval = y_train[-validation_size:]\n",
"x_train = x_train[0:-validation_size]\n",
"y_train = y_train[0:-validation_size]\n",
"train_size = x_train.shape[0]\n",
"\n",
"batch_size = 64\n",
"epochs = 50"
],
"metadata": {
"id": "jWps9adJCtzE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Load Tensorboard support\n",
"# NOTE(review): this extension path looks Google-internal (google3.*); outside that\n",
"# environment the public equivalent is presumably `%load_ext tensorboard` -- confirm.\n",
"%load_ext google3.learning.brain.tensorboard.notebook.extension"
],
"metadata": {
"id": "UKDnOttIyN7d",
"outputId": "22923483-9035-4106-948f-4a431929f70d"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The google3.learning.brain.tensorboard.notebook.extension extension is already loaded. To reload it, use:\n",
" %reload_ext google3.learning.brain.tensorboard.notebook.extension\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from tensorflow import summary\n",
"\n",
"# Directory passed to TensorBoard as its --logdir.\n",
"LOG_DIR='./model_3'\n",
"# Launch TensorBoard inline on LOG_DIR; --port=0 presumably lets it pick a free port -- confirm.\n",
"%tensorboard --logdir=\"{LOG_DIR}\" --port=0"
],
"metadata": {
"colab": {
"resources": {
"https://localhost:41150/?tensorboardColab=true": {
"data": "",
"ok": true,
"headers": [
[
"content-length",
"197287"
],
[
"content-type",
"text/html; charset=utf-8"
]
],
"status": 200,
"status_text": ""
},
"https://localhost:41150/index.js": {
"data": "
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment