@danmunson
Last active April 14, 2023 01:38
{
"cells": [
{
"cell_type": "markdown",
"id": "4698a883",
"metadata": {},
"source": [
"# *Solving MNIST with JAX*\n",
"\n",
"JAX is a cool library. Among other things, it:\n",
"- can JIT compile code for a CPU/GPU/TPU/etc...\n",
"- makes parallel execution easy, even on separate devices\n",
"- can transform a scalar-valued function into one that computes its gradient\n",
"\n",
"What better way to explore this library than with a neural net?\n",
"\n",
"Special thanks to [You Don't Know JAX](https://colinraffel.com/blog/you-don-t-know-jax.html) and [Neural Networks and Deep Learning](http://neuralnetworksanddeeplearning.com)."
]
},
{
"cell_type": "markdown",
"id": "f592cabc",
"metadata": {},
"source": [
"### First, let's make sure we're able to use the GPUs"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5c90a2c1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),\n",
" StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax\n",
"jax.devices()"
]
},
{
"cell_type": "markdown",
"id": "f75f65f4",
"metadata": {},
"source": [
"## Downloading the MNIST dataset\n",
"\n",
"The MNIST dataset can be downloaded into a pandas dataframe via sklearn's OpenML interface. In order to save time in future runs, the pandas dataframe is cached."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6b50c4a5",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_openml\n",
"import pickle\n",
"import os\n",
"\n",
"def load_mnist():\n",
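"    # cache the dataset locally so later runs skip the download\n",
"    pickle_path = '../data/mnist/data.pkl'\n",
"    if os.path.exists(pickle_path):\n",
"        with open(pickle_path, 'rb') as f:\n",
"            return pickle.load(f)\n",
"    mnist = fetch_openml(name='mnist_784', version=1, parser='auto')\n",
"    # make sure the cache directory exists before writing to it\n",
"    os.makedirs(os.path.dirname(pickle_path), exist_ok=True)\n",
"    with open(pickle_path, 'wb') as f:\n",
"        pickle.dump(mnist, f)\n",
"    return mnist\n",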
"\n",
"mnist = load_mnist()\n",
"all_features, all_targets = mnist['data'], mnist['target']"
]
},
{
"cell_type": "markdown",
"id": "c31aff85",
"metadata": {},
"source": [
"## Note: Randomness in JAX\n",
"\n",
"JAX offers a randomness utility very similar to that of numpy. The major difference is that you need to explicitly pass a PRNG key (created from a seed) to every random function. This is more tedious, but the upside is that you have more control over the process. For example, if you use the same key, you should end up with the same results.\n",
"\n",
"All randomness in this notebook is derived from the same seed, so if you choose to run it with the provided seed you should get the same results."
]
},
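{
"cell_type": "markdown",
"id": "a3f1c2e0",
"metadata": {},
"source": [
"As a minimal sketch of what explicit keys look like in practice (the seed `0` and the variable names are arbitrary choices for illustration, not part of this notebook's pipeline): the same key always yields the same numbers, and fresh randomness comes from splitting a key and consuming the subkey.\n",
"\n",
"```python\n",
"from jax import numpy as jnp\n",
"from jax import random as jrand\n",
"\n",
"key = jrand.PRNGKey(0)  # arbitrary seed, for illustration only\n",
"\n",
"# the same key always produces the same sample\n",
"a = jrand.normal(key, shape=(3,))\n",
"b = jrand.normal(key, shape=(3,))\n",
"same_key_matches = bool(jnp.array_equal(a, b))\n",
"\n",
"# split the key: keep one half for later and consume the other\n",
"key, subkey = jrand.split(key)\n",
"c = jrand.normal(subkey, shape=(3,))\n",
"new_key_differs = not bool(jnp.array_equal(a, c))\n",
"```\n",
"\n",
"Both flags should come out `True`: reusing a key repeats the draw, while a subkey gives a fresh one."
]
},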
{
"cell_type": "code",
"execution_count": 3,
"id": "a9f113cb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using seed: 298504\n"
]
}
],
"source": [
"from random import randint\n",
"from jax import numpy as jnp\n",
"from jax import random as jrand\n",
"\n",
"class DistSampler:\n",
"    def __init__(self, seed=None):\n",
"        # fall back to a random seed only when none is given (0 is a valid seed)\n",
"        seed = randint(0, 10**6) if seed is None else seed\n",
"        self.key = jrand.PRNGKey(seed)\n",
"        print('Using seed:', seed)\n",
"    \n",
"    def normal(self, *shape):\n",
"        self.key, _ = jrand.split(self.key)\n",
"        return jrand.normal(self.key, shape=tuple(shape))\n",
"    \n",
"    def choice(self, n, k):\n",
"        self.key, _ = jrand.split(self.key)\n",
"        return list(map(int, jrand.choice(self.key, n, (k,), replace=False)))\n",
"    \n",
"    def random_rows(self, n, *arrs):\n",
"        idxs = self.choice(arrs[0].shape[0], n) # assume arrays have same outer dim\n",
"        return [jnp.take(x, jnp.asarray(idxs), axis=0) for x in arrs]\n",
"    \n",
"    def shuffle(self, arr):\n",
"        self.key, _ = jrand.split(self.key)\n",
"        return jrand.permutation(self.key, arr)\n",
" \n",
"seed = 298504\n",
"dist_sampler = DistSampler(seed)"
]
},
{
"cell_type": "markdown",
"id": "23a8dbcd",
"metadata": {},
"source": [
"## Preparing the data\n",
"\n",
"Now that the dataset is available, it needs to be split into train and test sets and the feature vectors need to be normalized. For my own sanity, I also performed a quick spot check on the data."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "861baa30",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample image:\n",
"\n",
" \n",
" \n",
" \n",
" \n",
" \n",
" @@@ \n",
" @@@@@@@ \n",
" @@@@@@@@@@@@@ \n",
" @@@@@@ @@@@@@@ \n",
" @@@@ @@@@@ \n",
" @@@ @@@@ \n",
" @@@ @@@@ \n",
" @@@ @@@@@@ \n",
" @@@ @@@@@@ \n",
" @@ @@@@@@ \n",
" @@@@@@@@@ \n",
" @@@@@@@@ \n",
" @@@@@@ \n",
" @@@@@@@ \n",
" @@@@@@@@ \n",
" @@@@@ @@@ \n",
" @@@@@@@@@@ \n",
" @@@@@@@@@@ \n",
" @@@@@@@@ \n",
" @@@@@@ \n",
" \n",
" \n",
" \n",
"\n",
"Sample answer: 8\n"
]
}
],
"source": [
"from jax import numpy as jnp\n",
"import pandas as pd\n",
"\n",
"\"\"\"\n",
"Split the data into train and test segments, then format it as JAX matrices\n",
"\"\"\"\n",
"bool_vec = [i < 60_000 for i in range(len(all_targets))]\n",
"bool_vec = [bool(x) for x in dist_sampler.shuffle(jnp.asarray(bool_vec))]\n",
"split_df = lambda df : (\n",
"    df[pd.Series(bool_vec).values],\n",
"    df[pd.Series([not b for b in bool_vec]).values]\n",
")\n",
"\n",
"train_features, test_features = split_df(all_features)\n",
"train_targets, test_targets = split_df(all_targets)\n",
"\n",
"format_jnp = lambda *dfs : tuple([jnp.asarray(df.to_numpy(), dtype='float32') for df in dfs])\n",
"\n",
"ftr_train, ftr_test, tgt_train, tgt_test = format_jnp(\n",
"    train_features,\n",
"    test_features,\n",
"    pd.get_dummies(train_targets),\n",
"    pd.get_dummies(test_targets)\n",
")\n",
"\n",
"\"\"\"\n",
"Spot check: print a rough sketch of a sample number, along with the expected answer.\n",
"\"\"\"\n",
"print('Sample image:\\n')\n",
"\n",
"idx = dist_sampler.choice(ftr_train.shape[0], 1)[0]\n",
"x, y = ftr_train[idx], tgt_train[idx]\n",
"\n",
"img = jnp.reshape(x, (28,28))\n",
"for row in img:\n",
"    print(''.join(['@' if pix > 100 else ' ' for pix in row]))\n",
" \n",
"print(f'\\nSample answer: {jnp.argmax(y)}')\n",
"\n",
"\"\"\"\n",
"Normalize the features\n",
"\"\"\"\n",
"normalize = lambda ftr_df : ftr_df / jnp.linalg.norm(ftr_df, axis=1, keepdims=True)\n",
"ftr_train, ftr_test = normalize(ftr_train), normalize(ftr_test)\n",
"\n",
"Obs, Resp = ftr_train, tgt_train\n",
"TestObs, TestResp = ftr_test, tgt_test"
]
},
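{
"cell_type": "markdown",
"id": "b4d2e6f1",
"metadata": {},
"source": [
"The normalization step above divides every row by its Euclidean (L2) norm, so each image becomes a unit vector. A small sketch of the same operation on a made-up 2x2 matrix (the values are illustrative, not MNIST data):\n",
"\n",
"```python\n",
"from jax import numpy as jnp\n",
"\n",
"toy = jnp.asarray([[3.0, 4.0], [0.0, 2.0]])  # made-up rows\n",
"# keepdims=True keeps the norms as a column, so the division broadcasts per row\n",
"unit_rows = toy / jnp.linalg.norm(toy, axis=1, keepdims=True)\n",
"row_norms = jnp.linalg.norm(unit_rows, axis=1)\n",
"```\n",
"\n",
"Here the first row becomes `[0.6, 0.8]` and every row norm is 1, up to floating-point error. An all-zero row would divide by zero, which this sketch does not guard against."
]
},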
{
"cell_type": "markdown",
"id": "1cae116e",
"metadata": {},
"source": [
"## Constructing the Neural Network\n",
"\n",
"Here we'll define a simple neural network that maps a 784-dimensional input to a 10-dimensional output (which is then softmaxed). In addition, we'll define the loss function (cross-entropy), as well as a function to initialize the neural net's parameters.\n",
"\n",
"\n",
"Note: in order to leverage the full power of JAX's jit-compiler (more on that later), it's important to make sure our functions are \"pure\", i.e. stateless and free of side effects. Note that the parameters are an explicit argument of the neural network, rather than being \"baked in\". Personally, this is one of the things I like about working with JAX. The structure of the network (its dimensions and activation functions) is clearly separated from the parameter values that are being learned, making the process of training more intuitive."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ff0da1a1",
"metadata": {},
"outputs": [],
"source": [
"from jax import nn, lax\n",
"\n",
"def neural_net(x, params):\n",
"    w0, b0, w1, b1, w2, b2 = params\n",
"    x1 = jnp.tanh(jnp.dot(w0, x) + b0)\n",
"    x2 = jnp.tanh(jnp.dot(w1, x1) + b1)\n",
"    x3 = jnp.tanh(jnp.dot(w2, x2) + b2)\n",
"    return nn.softmax(x3)\n",
"\n",
"def initialize_params():\n",
"    d0, d1, d2 = 50, 100, 10\n",
"    return [\n",
"        dist_sampler.normal(d0, Obs.shape[1]),\n",
"        dist_sampler.normal(d0),\n",
"\n",
"        dist_sampler.normal(d1, d0),\n",
"        dist_sampler.normal(d1),\n",
"\n",
"        dist_sampler.normal(d2, d1),\n",
"        dist_sampler.normal(d2),\n",
"    ]\n",
"\n",
"def cross_entropy_loss(params, x, y):\n",
"    prediction = neural_net(x, params)\n",
"    return jnp.sum(\n",
"        -y * jnp.log(prediction) - (1. - y) * jnp.log(1. - prediction)\n",
"    )"
]
},
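{
"cell_type": "markdown",
"id": "c5e3f7a2",
"metadata": {},
"source": [
"Written out, the loss that `cross_entropy_loss` computes for a single example is a sum of binary cross-entropy terms over the 10 output classes. The symbols are introduced here for illustration: $y_k$ is the one-hot target, $p_k$ is the $k$-th softmax output of the network, and $\\theta$ stands for the parameters.\n",
"\n",
"$$L(\\theta; x, y) = -\\sum_{k=1}^{10} \\left[ y_k \\log p_k + (1 - y_k) \\log (1 - p_k) \\right]$$"
]
},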
{
"cell_type": "markdown",
"id": "273f99ae",
"metadata": {},
"source": [
"## JAX MAGIC (GRAD, VMAP, JIT)\n",
"\n",
"Seemingly magical function transformations are where JAX really comes into its own, and none exemplify this better than `jax.grad`, `jax.vmap` and `jax.jit`.\n",
"\n",
"To briefly summarize, a neural network is just an \"over-parameterized\" differentiable function that maps inputs to outputs, which we can then interpret however we like. Training a neural net basically comes down to (1) passing data into this function, (2) computing a loss value based on the output, (3) computing the gradient of the loss function with respect to the neural net parameters, and then (4) updating the parameters in the opposite direction of the gradient.\n",
"\n",
"### `jax.grad`\n",
"Step 3 of the training process is easily the trickiest part, which is where `jax.grad` comes in. Once we've defined a loss function around our neural net, we can then transform this loss function into one that takes the same arguments but instead returns the gradient with respect to the parameters of the neural net (note that, by default, `jax.grad` differentiates with respect to the first argument of the function being transformed, which is why the parameters come first in `cross_entropy_loss`). Once we have those gradients, updating the parameters becomes dead simple.\n",
"\n",
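"As a small sketch of this transformation (the toy loss and its values are made up for illustration): `grad` returns a function with the same signature that yields the derivative with respect to the first argument.\n",
"\n",
"```python\n",
"from jax import grad, numpy as jnp\n",
"\n",
"def toy_loss(w, x, y):\n",
"    # squared error of a linear model; w is the parameter vector\n",
"    return (jnp.dot(w, x) - y) ** 2\n",
"\n",
"toy_grad = grad(toy_loss)  # differentiates w.r.t. w (argnums=0 by default)\n",
"\n",
"w = jnp.asarray([1.0, 2.0])\n",
"x = jnp.asarray([3.0, 4.0])\n",
"g = toy_grad(w, x, 10.0)  # analytically 2 * (w.x - y) * x\n",
"```\n",
"\n",
"Here `w.x - y` is 1, so `g` is `[6., 8.]`, matching the hand-derived gradient.\n",
"\n",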
"### `jax.vmap`\n",
"Neural nets are known to be \"embarrassingly parallel\", which just means that a lot of the overall work of training can be done in parallel. However, actually taking advantage of this structural parallelism can be a delicate business. `jax.vmap` takes care of that for us by taking a function that handles a single example and returning a vectorized version that handles a whole batch at once.\n",
"\n",
"When using `jax.vmap`, pay close attention to the `in_axes` and `out_axes` arguments. In the example below, note that:\n",
"\n",
"- in_axes=(None, 0, 0)\n",
"- out_axes=0\n",
"\n",
"These axes tell JAX that if our initial function `F` has a signature:\n",
"\n",
"`F : (a, b, c) -> d`\n",
" \n",
"then `vmap(F)` has a signature of:\n",
"\n",
"`vmap(F) : (a, [b0, ..., bn], [c0, ..., cn]) -> [F(a, b0, c0), ..., F(a, bn, cn)]`\n",
"\n",
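"A minimal sketch of those axes in action, using a made-up function in place of `F` (names and values are illustrative):\n",
"\n",
"```python\n",
"from jax import vmap, numpy as jnp\n",
"\n",
"def F(a, b, c):\n",
"    # a is shared across the batch; b and c are single examples\n",
"    return a * jnp.dot(b, c)\n",
"\n",
"# None: do not map over a; 0: map over the leading axis of b and c\n",
"batched_F = vmap(F, in_axes=(None, 0, 0), out_axes=0)\n",
"\n",
"a = 2.0\n",
"B = jnp.asarray([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # three b's\n",
"C = jnp.asarray([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # three c's\n",
"out = batched_F(a, B, C)  # one result per (b, c) pair\n",
"```\n",
"\n",
"`out` has shape `(3,)` and equals `[2., 8., 22.]`, i.e. `F` applied to each row pair with the same `a`.\n",
"\n",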
"### `jax.jit`\n",
"Neural nets are also known to be very computationally intensive, so we'd like any speedup we can get. `jax.jit` marks a function for just-in-time compilation, which can significantly speed up that function's execution.\n",
"\n",
"For a function to be \"jit-compilable\", it must be pure. You'll likely encounter a number of errors related to this when you give JAX a go; just stay patient and use the docs!"
]
},
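{
"cell_type": "markdown",
"id": "d6f4a8b3",
"metadata": {},
"source": [
"A rough sketch of why purity matters for `jit` (the function here is made up for illustration): a jitted function is traced once per combination of input shapes and dtypes, so a Python side effect runs only during tracing, not on later calls.\n",
"\n",
"```python\n",
"from jax import jit, numpy as jnp\n",
"\n",
"trace_count = []  # records how many times Python actually runs the body\n",
"\n",
"@jit\n",
"def scaled_sum(x, scale):\n",
"    trace_count.append(1)  # side effect: happens only while tracing\n",
"    return jnp.sum(x) * scale\n",
"\n",
"first = scaled_sum(jnp.ones(4), 2.0)\n",
"second = scaled_sum(jnp.ones(4), 3.0)  # same shapes: reuses the compiled code\n",
"```\n",
"\n",
"After both calls `trace_count` holds a single entry even though the function ran twice, which is why impure code tends to misbehave under `jit`."
]
},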
{
"cell_type": "code",
"execution_count": 6,
"id": "2294f192",
"metadata": {},
"outputs": [],
"source": [
"from jax import jit, vmap, grad\n",
"\n",
"# create the gradient function\n",
"loss_gradient = grad(cross_entropy_loss)\n",
"\n",
"# parallelize the gradient function\n",
"vectorized_loss_gradient = vmap(loss_gradient, in_axes=(None, 0, 0), out_axes=0)\n",
"\n",
"# define the \"batch update\" function and mark it for jit\n",
"@jit\n",
"def batch_update(params, X, Y, learning_rate):\n",
"    # (1) compute the gradients of loss function relative to the training data\n",
"    param_gradients = vectorized_loss_gradient(params, X, Y)\n",
"    # (2) average the gradients\n",
"    average_grad_per_param = [\n",
"        jnp.average(gradients, 0)\n",
"        for gradients in param_gradients\n",
"    ]\n",
"    # (3) subtract the gradients from the original params and return them\n",
"    return [\n",
"        param - (learning_rate * average_grad)\n",
"        for param, average_grad in zip(params, average_grad_per_param)\n",
"    ]"
]
},
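{
"cell_type": "markdown",
"id": "e7a5b9c4",
"metadata": {},
"source": [
"In symbols, `batch_update` performs one step of mini-batch gradient descent. The notation is introduced here for illustration: $\\theta$ stands for the parameters, $\\eta$ for the learning rate, $B$ for the batch size, and $L$ for the per-example loss.\n",
"\n",
"$$\\theta \\leftarrow \\theta - \\eta \\cdot \\frac{1}{B} \\sum_{i=1}^{B} \\nabla_\\theta L(\\theta; x_i, y_i)$$"
]
},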
{
"cell_type": "markdown",
"id": "288f37e9",
"metadata": {},
"source": [
"## Evaluation\n",
"\n",
"In order to evaluate how well the neural network is actually doing during and after training, we need to measure its accuracy against the test set. These functions will let us do that. We can also take advantage of `jax.vmap` and `jax.jit` here, this time using the decorator syntax to be more concise."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "00ba6d36",
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"@partial(vmap, in_axes=(None, 0, 0), out_axes=0)\n",
"def score_prediction(params, x, y):\n",
"    return jnp.argmax(neural_net(x, params)) == jnp.argmax(y)\n",
"\n",
"@jit\n",
"def avg_prediction_accuracy(params, X, Y):\n",
"    scores = score_prediction(params, X, Y)\n",
"    return jnp.average(scores, 0)"
]
},
{
"cell_type": "markdown",
"id": "b670d82c",
"metadata": {},
"source": [
"## Training the neural net\n",
"\n",
"The meat of the training process is defined in the `batch_update` function. Here we just run that function in a loop, using the updated parameters each time. Every once in a while we'll log the current accuracy, and occasionally we'll shrink the `learning_rate`."
]
},
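{
"cell_type": "markdown",
"id": "f8b6cad5",
"metadata": {},
"source": [
"The loop below halves the learning rate after every 2000 iterations. As a sketch, the same schedule can be written as a standalone function (the helper name `step_decay` is hypothetical, not something this notebook defines or uses):\n",
"\n",
"```python\n",
"def step_decay(iter_num, initial_rate=0.5, factor=0.5, every=2000):\n",
"    # learning rate in effect at iteration iter_num (1-indexed), given that\n",
"    # the rate is multiplied by `factor` after every `every` completed iterations\n",
"    completed_decays = (iter_num - 1) // every\n",
"    return initial_rate * factor ** completed_decays\n",
"\n",
"rates = [step_decay(i) for i in (1, 2000, 2001, 4001, 5000)]\n",
"```\n",
"\n",
"With the defaults this gives `[0.5, 0.5, 0.25, 0.125, 0.125]`, so the last thousand iterations run at a quarter of the starting rate."
]
},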
{
"cell_type": "code",
"execution_count": 8,
"id": "e290054c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initial Accuracy: 8.70%\n",
"\n",
"(500) Accuracy: 79.70%\n",
"\n",
"(1000) Accuracy: 85.50%\n",
"\n",
"(1500) Accuracy: 88.80%\n",
"\n",
"(2000) Accuracy: 89.00%\n",
"\n",
"(2500) Accuracy: 89.50%\n",
"\n",
"(3000) Accuracy: 91.80%\n",
"\n",
"(3500) Accuracy: 91.80%\n",
"\n",
"(4000) Accuracy: 91.20%\n",
"\n",
"(4500) Accuracy: 91.40%\n",
"\n",
"(5000) Accuracy: 91.80%\n",
"\n",
"54.20 seconds elapsed during training.\n",
"\n",
"Final accuracy against test set: 91.74%\n",
"\n"
]
}
],
"source": [
"import time\n",
"\n",
"# initialize the parameters\n",
"neural_net_params = initialize_params()\n",
"\n",
"learning_rate = 0.5\n",
"iterations = 5000\n",
"batch_size = 100\n",
"\n",
"fmt_acc = lambda acc : f'{(acc * 100):.2f}%\\n'\n",
"\n",
"test_xs, test_ys = dist_sampler.random_rows(1000, TestObs, TestResp)\n",
"accuracy = avg_prediction_accuracy(neural_net_params, test_xs, test_ys)\n",
"print(f'{\"Initial Accuracy:\":<20}{fmt_acc(accuracy):>10}')\n",
"\n",
"starting_time = time.time()\n",
"\n",
"for iter_num in range(1, iterations + 1):\n",
"    # perform gradient descent to get new params\n",
"    batch_x, batch_y = dist_sampler.random_rows(batch_size, Obs, Resp)\n",
"    neural_net_params = batch_update(neural_net_params, batch_x, batch_y, learning_rate)\n",
"\n",
"    # evaluate accuracy every 500 iterations\n",
"    if (iter_num) % 500 == 0:\n",
"        test_xs, test_ys = dist_sampler.random_rows(1000, TestObs, TestResp)\n",
"        accuracy = avg_prediction_accuracy(neural_net_params, test_xs, test_ys)\n",
"        print(f'{f\"({iter_num}) Accuracy:\":<20}{fmt_acc(accuracy):>10}')\n",
"\n",
"    # decrease the learning rate over time\n",
"    if (iter_num) % 2000 == 0:\n",
"        learning_rate *= 0.5\n",
" \n",
"elapsed_seconds = time.time() - starting_time\n",
"\n",
"print(f'{elapsed_seconds:.2f} seconds elapsed during training.\\n')\n",
"\n",
"accuracy = avg_prediction_accuracy(neural_net_params, TestObs, TestResp)\n",
"print(f'Final accuracy against test set: {fmt_acc(accuracy):>10}')"
]
},
{
"cell_type": "markdown",
"id": "50297f49",
"metadata": {},
"source": [
"### So there you have it - a simple neural network in JAX!\n",
"\n",
"Despite this being a toy example, it should be clear that JAX's transformations can be used to make the whole process of training a neural network intuitive, without sacrificing performance.\n",
"\n",
"I should note that JAX ships with a small example library called `stax` (`jax.example_libraries.stax`) which is designed for neural nets in particular. If you're interested in JAX for this use case, you should check that out!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "nns",
"language": "python",
"name": "nns"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}