@aphearin
Created February 15, 2023 19:42
Quick overview of JAX
{
"cells": [
{
"cell_type": "markdown",
"id": "afa7c7be",
"metadata": {},
"source": [
"# Quick Intro to JAX\n",
"\n",
"JAX is a Python library for automatic differentiation and accelerated array computing. There is by now quite a lot of well-presented material about JAX available online; [the JAX docs](https://jax.readthedocs.io/en/latest/) are generally very accessible and helpful, and there are also some excellent tutorials floating around (I particularly like [this one](https://ericmjl.github.io/dl-workshop/index.html) on Differentiable Programming with JAX). This notebook just gives a quick overview of some of the core features in JAX: `jax.numpy`, `jax.jit`, `jax.grad`, and `jax.vmap`."
]
},
{
"cell_type": "markdown",
"id": "20aa5319",
"metadata": {},
"source": [
"## JAX as Numpy that runs on GPUs\n",
"\n",
"The JAX API will be familiar to people used to performing bulk-array calculations with Numpy. Even when not using autodiff, writing Numpy-like programs in JAX is a great way to write fast Python code that targets GPUs, since JAX's backend is [XLA](https://www.tensorflow.org/xla), a compiler for accelerated linear algebra. As they say in the JAX docs, you can think of JAX as a differentiable Numpy that runs on accelerators. JAX provides its own implementation of most of the functionality in Numpy, using the same syntax:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "607e555d",
"metadata": {},
"outputs": [],
"source": [
"from jax import numpy as jnp\n",
"\n",
"nx = 30\n",
"xarr = jnp.linspace(0, 2*jnp.pi, nx)\n",
"yarr = jnp.cos(xarr)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d1ed9555",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot as plt\n",
"mred = u'#d62728' \n",
"mgreen = u'#2ca02c'\n",
"mblue = u'#1f77b4' \n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"__=ax.plot(xarr, yarr)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "95c06934",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([0. , 0.21666157, 0.43332314, 0.6499847 , 0.8666463 ,\n",
" 1.0833079 , 1.2999694 , 1.516631 , 1.7332926 , 1.9499542 ,\n",
" 2.1666157 , 2.3832774 , 2.5999389 , 2.8166003 , 3.033262 ,\n",
" 3.2499237 , 3.4665852 , 3.6832466 , 3.8999083 , 4.11657 ,\n",
" 4.3332314 , 4.549893 , 4.766555 , 4.9832163 , 5.1998777 ,\n",
" 5.416539 , 5.6332006 , 5.8498626 , 6.066524 , 6.2831855 ], dtype=float32)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"xarr"
]
},
{
"cell_type": "markdown",
"id": "bcbb60d8",
"metadata": {},
"source": [
"Notice how `xarr` shows up as a \"DeviceArray\" (newer JAX versions display the same kind of object as `Array`). This is our first indicator that JAX targets GPUs. In GPU programming, one of the primary performance concerns is minimizing transfers of data between CPU and GPU memory. JAX arrays are created on the default device (the GPU, when one is available), and JAX does not copy their data back to host memory unless you ask for it, for example by displaying the array as in the cell above or by converting it with `np.asarray`. JAX also dispatches operations asynchronously, so Python can keep issuing work while the device computes. This is especially convenient because it means you don't need to rewrite your code to gain major performance benefits from running your programs on a GPU-accelerated machine."
]
},
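{
"cell_type": "markdown",
"id": "sketch-devices-md",
"metadata": {},
"source": [
"The data-movement behavior described above can be made explicit with a few calls. This is a minimal sketch assuming a recent JAX version: `jax.device_put` copies a host array to the default device, `.block_until_ready()` waits for asynchronously dispatched work to finish, and `jax.device_get` copies a result back to host memory as a NumPy array. On a CPU-only machine the same code runs, with the CPU acting as the device. The variable names are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-devices",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# List the devices JAX can see (a CPU-only machine reports a CPU device)\n",
"print(jax.devices())\n",
"\n",
"# Explicitly copy a NumPy (host) array onto the default device\n",
"x_host = np.linspace(0, 1, 5, dtype=np.float32)\n",
"x_dev = jax.device_put(x_host)\n",
"\n",
"# Work is dispatched asynchronously; block_until_ready() waits for it to finish\n",
"y_dev = jnp.cos(x_dev).block_until_ready()\n",
"\n",
"# Copy the result back to host memory as a NumPy array\n",
"y_host = jax.device_get(y_dev)\n",
"print(type(y_host), y_host.dtype)"
]
},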
{
"cell_type": "markdown",
"id": "6cf9b209",
"metadata": {},
"source": [
"## Fast code using JAX's just-in-time compiler\n",
"\n",
"Under the hood, JAX builds a computational trace of all the JAX operations that will be performed during the execution of a function. This trace is used by JAX in gradient calculations with autodiff, but also more generally to map Python code onto a set of array operations that XLA can compile into efficient code for the target device. This is true for the `jax.numpy` functions above, and also for custom functions you write yourself that use the `jax.jit` decorator to compile the function.\n",
"\n",
"First let's define a simple python function and time it:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b1159605",
"metadata": {},
"outputs": [],
"source": [
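"import numpy as np\n",
"\n",
"\n",
"def func(x):\n",
"    # NumPy version: each operation runs eagerly over the whole array\n",
"    for i in range(10):\n",
"        x = x - i*0.1*x + i*i\n",
"    return np.mean(x[:100])"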
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9219a8dc",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"xarr_timeit = np.random.uniform(0, 1, int(1e6))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b4ca18fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"31.9 ms ± 249 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit func(xarr_timeit)"
]
},
{
"cell_type": "markdown",
"id": "ca967022",
"metadata": {},
"source": [
"### The `@jax.jit` decorator\n",
"\n",
"Now let's use `jax.jit` to build a fast-evaluating compiled version of our function. `jax.jit` is a _decorator_: it accepts a function as input and returns another function. In our case, `jax.jit` accepts an ordinary Python function (the same arithmetic as `func`, written with `jnp`), and the function we get back is a fast-evaluating version that gets compiled with XLA. The phrase \"just-in-time\" means that this compilation does not happen until the first time we actually call the function (and it happens again whenever the function is called with arguments of a new shape or dtype).\n",
"\n",
"You can use `jax.jit` with functional syntax like this: `new_func = jax.jit(orig_func)`.\n",
"\n",
"Or you can adorn your function with the decorator syntax `@jax.jit` like in the cell below."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7670626a",
"metadata": {},
"outputs": [],
"source": [
"from jax import jit as jjit\n",
"\n",
"@jjit\n",
"def jax_func(x):\n",
" for i in range(10):\n",
" x = x - i*0.1*x + i*i\n",
" return jnp.mean(x[:100])"
]
},
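{
"cell_type": "markdown",
"id": "sketch-jit-functional-md",
"metadata": {},
"source": [
"For comparison, here is a sketch of the functional syntax `jax.jit(orig_func)` applied to an undecorated copy of the same arithmetic. The helper name `loop_func` is hypothetical, and the printed timings will vary by machine; the point is only that the first call pays for tracing and compilation, while later calls with the same input shape and dtype reuse the compiled version."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-jit-functional",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def loop_func(x):\n",
"    # Same arithmetic as jax_func, but without the decorator\n",
"    for i in range(10):\n",
"        x = x - i*0.1*x + i*i\n",
"    return jnp.mean(x[:100])\n",
"\n",
"\n",
"# Functional syntax: equivalent to decorating loop_func with @jax.jit\n",
"loop_func_jit = jax.jit(loop_func)\n",
"\n",
"x_small = jnp.linspace(0, 1, 1000)\n",
"\n",
"t0 = time.perf_counter()\n",
"first_out = loop_func_jit(x_small).block_until_ready()  # traces and compiles\n",
"t1 = time.perf_counter()\n",
"second_out = loop_func_jit(x_small).block_until_ready()  # reuses the compiled version\n",
"t2 = time.perf_counter()\n",
"print(f'first call: {t1 - t0:.4f} s, second call: {t2 - t1:.6f} s')"
]
},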
{
"cell_type": "code",
"execution_count": 8,
"id": "4d1cff43",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray(88.68427, dtype=float32)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jax_func(xarr_timeit)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2165d7b7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"655 µs ± 5.26 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
]
}
],
"source": [
"%timeit jax_func(xarr_timeit)"
]
},
{
"cell_type": "markdown",
"id": "c0dfa2b4",
"metadata": {},
"source": [
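"Our implementation of `jax_func` is almost 50x faster than its NumPy equivalent, `func` (about 655 µs versus 31.9 ms per call). Two caveats apply to timings like this. First, the first call to a jitted function includes compilation time, which is one reason it helps to call `jax_func` once before timing it, as we did above. Second, JAX dispatches work asynchronously, so careful benchmarks usually time `jax_func(xarr_timeit).block_until_ready()` to make sure the computation has actually finished."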
]
},
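{
"cell_type": "markdown",
"id": "sketch-jaxpr-md",
"metadata": {},
"source": [
"To peek at the trace that `jax.jit` compiles, `jax.make_jaxpr` returns JAX's intermediate representation (a jaxpr) for a function evaluated on example inputs. In this sketch with a hypothetical function `small_loop`, the Python `for` loop runs during tracing, so the printed jaxpr contains one multiply and one add per iteration rather than a loop. The same unrolling happens to the 10-iteration loop in `jax_func`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-jaxpr",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def small_loop(x):\n",
"    # The Python loop over a static range runs while JAX traces the function,\n",
"    # so each iteration contributes its own operations to the trace\n",
"    for i in range(3):\n",
"        x = x * 2.0 + i\n",
"    return jnp.sum(x)\n",
"\n",
"\n",
"jaxpr = jax.make_jaxpr(small_loop)(jnp.ones(4))\n",
"print(jaxpr)"
]
},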
{
"cell_type": "markdown",
"id": "f08e5397",
"metadata": {},
"source": [
"## But wait, there are some gotchas\n",
"\n",
"In many situations, speeding up your code and making it GPU portable is as simple as the above examples. But there are some gotchas to be aware of, in which you need to implement your code in a particular way in order for JAX to be able to compile it. This is written about extensively in [this section of the JAX docs](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html); here I'll only cover a couple of commonly encountered examples.\n",
"\n",
"First of all, it's necessary to write \"pure functions\" whose behavior is free of side-effects: the returned values of your function cannot depend on mutable state defined elsewhere in the program. This is because at compile time, JAX builds a computational trace of the flow of data through your function, which requires that your function behave deterministically with respect to each of its arguments; any global value the function reads is frozen into the trace. This is why libraries and programs based on JAX tend to be written in a functional style, since this makes it simpler to guarantee that we only ever ask JAX to apply its transformations to side-effect-free functions. This can take some getting used to if you're more familiar with object-oriented machine learning libraries like `scikit-learn`, `pytorch` and `tensorflow`. A closely related requirement is that, during tracing, `jax.jit` calls your function with abstract placeholder values (\"tracers\") rather than concrete arrays, so every operation applied to a traced argument must be implemented in JAX. The behavior of your function cannot depend on some other external library such as `numpy` or `scipy` operating on those arguments. That brings us to Gotcha #1."
]
},
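{
"cell_type": "markdown",
"id": "sketch-purity-md",
"metadata": {},
"source": [
"Before getting to Gotcha 1, here is a minimal sketch of why purity matters, using hypothetical names: a jitted function that reads a global variable captures that variable's value when the function is traced, so changing the global afterwards has no effect on later calls that hit the compilation cache. Passing the value as an explicit argument avoids the problem."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-purity",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"scale = 2.0  # global state read inside the function (the pattern to avoid)\n",
"\n",
"\n",
"@jax.jit\n",
"def impure_scale(x):\n",
"    # `scale` is not an argument, so its current value is baked into the trace\n",
"    return scale * x\n",
"\n",
"\n",
"before = impure_scale(jnp.ones(3))\n",
"scale = 10.0  # changing the global after the first (compiling) call...\n",
"after = impure_scale(jnp.ones(3))  # ...is ignored: the cached version still uses 2.0\n",
"\n",
"\n",
"@jax.jit\n",
"def pure_scale(x, s):\n",
"    # Passing the value explicitly makes the dependence visible to JAX\n",
"    return s * x\n",
"\n",
"\n",
"explicit = pure_scale(jnp.ones(3), scale)\n",
"print(before, after, explicit)"
]
},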
{
"cell_type": "markdown",
"id": "87ca5efd",
"metadata": {},
"source": [
"### Gotcha 1: Everything inside a jitted function must be implemented in JAX"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6a7109d1",
"metadata": {},
"outputs": [],
"source": [
"from scipy.special import erf as erf_scipy\n",
"\n",
"\n",
"@jjit\n",
"def jax_func_whoops(x):\n",
" return erf_scipy(x) + x**2"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "382d32d1",
"metadata": {},
"outputs": [],
"source": [
"xarr = np.linspace(0, 2*np.pi, nx)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "305111c8",
"metadata": {},
"outputs": [
{
"ename": "TracerArrayConversionError",
"evalue": "The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[30])>with<DynamicJaxprTrace(level=0/1)>\nThe error occurred while tracing the function jax_func_whoops at /var/folders/vp/g9zy6vt17h18h2dnpqxz6j1w0000gn/T/ipykernel_1727/1702647558.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTracerArrayConversionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m jax_func_whoops(xarr)\n",
" \u001b[0;31m[... skipping hidden 14 frame]\u001b[0m\n",
"Cell \u001b[0;32mIn [10], line 6\u001b[0m, in \u001b[0;36mjax_func_whoops\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@jjit\u001b[39m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mjax_func_whoops\u001b[39m(x):\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43merf_scipy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m x\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m2\u001b[39m\n",
"File \u001b[0;32m~/opt/miniconda3/envs/ht081/lib/python3.9/site-packages/jax/core.py:540\u001b[0m, in \u001b[0;36mTracer.__array__\u001b[0;34m(self, *args, **kw)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__array__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw):\n\u001b[0;32m--> 540\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m TracerArrayConversionError(\u001b[38;5;28mself\u001b[39m)\n",
"\u001b[0;31mTracerArrayConversionError\u001b[0m: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[30])>with<DynamicJaxprTrace(level=0/1)>\nThe error occurred while tracing the function jax_func_whoops at /var/folders/vp/g9zy6vt17h18h2dnpqxz6j1w0000gn/T/ipykernel_1727/1702647558.py:4 for jit. This concrete value was not available in Python because it depends on the value of the argument 'x'.\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError"
]
}
],
"source": [
"jax_func_whoops(xarr)"
]
},
{
"cell_type": "markdown",
"id": "33577f6f",
"metadata": {},
"source": [
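"The error message is telling us that `scipy.special.erf` tried to convert the argument `x` into a NumPy array. Inside a jitted function, `x` is a JAX tracer, an abstract placeholder whose concrete values are not known at trace time, so this conversion is impossible and the function cannot be traced. Fortunately, JAX provides its own implementations of a portion of the functions in the `scipy` library, in `jax.scipy`. The cell below instead uses the JAX version of this same function."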
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d97df4e3",
"metadata": {},
"outputs": [],
"source": [
"from jax.scipy.special import erf as erf_jax\n",
"\n",
"\n",
"@jjit\n",
"def jax_func_roger(x):\n",
" return erf_jax(x) + x**2"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "4474ed54",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([ 0. , 0.28764644, 0.6477679 , 1.0644981 ,\n",
" 1.530735 , 2.0480406 , 2.6239219 , 3.2682035 ,\n",
" 3.9900665 , 4.796499 , 5.69204 , 6.67926 ,\n",
" 7.7594447 , 8.933169 , 10.200659 , 11.561998 ,\n",
" 13.01721 , 14.566305 , 16.209284 , 17.946144 ,\n",
" 19.776896 , 21.701525 , 23.720041 , 25.83244 ,\n",
" 28.038723 , 30.338898 , 32.73295 , 35.220886 ,\n",
" 37.802708 , 40.47842 ], dtype=float32)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jax_func_roger(xarr)"
]
},
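{
"cell_type": "markdown",
"id": "sketch-numpy-constants-md",
"metadata": {},
"source": [
"To make Gotcha 1 a little more precise: the restriction applies to operations on traced arguments. NumPy can still be used inside a jitted function on concrete Python constants, because those calls simply run once while JAX traces the function. A small sketch with a hypothetical `mixed_func`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-numpy-constants",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"@jax.jit\n",
"def mixed_func(x):\n",
"    # NumPy on a concrete constant is fine: it runs once, while JAX traces the function\n",
"    c = float(np.sqrt(2.0))\n",
"    # Anything that touches the traced argument x must use jax.numpy\n",
"    return c * jnp.sin(x)\n",
"\n",
"\n",
"mixed_out = mixed_func(jnp.linspace(0, 1, 5))\n",
"print(mixed_out)"
]
},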
{
"cell_type": "markdown",
"id": "722c3f19",
"metadata": {},
"source": [
"### Gotcha 2: JAX arrays are immutable\n",
"\n",
"Another feature to be aware of is that JAX arrays cannot be modified in place. That means the indexed-assignment syntax you may be used to from Numpy, such as `x[msk] = 0.5`, is not supported by JAX, and a different syntax is required. The next few cells give a common example."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "7e588211",
"metadata": {},
"outputs": [],
"source": [
"def numpy_clip_allgood(x):\n",
" msk = x < 0.5\n",
" x[msk] = 0.5\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "c4ed3dd5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.5 , 0.5 , 0.5 , 0.64998469, 0.86664625,\n",
" 1.08330781, 1.29996937, 1.51663094, 1.7332925 , 1.94995406,\n",
" 2.16661562, 2.38327719, 2.59993875, 2.81660031, 3.03326187,\n",
" 3.24992343, 3.466585 , 3.68324656, 3.89990812, 4.11656968,\n",
" 4.33323125, 4.54989281, 4.76655437, 4.98321593, 5.1998775 ,\n",
" 5.41653906, 5.63320062, 5.84986218, 6.06652374, 6.28318531])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"numpy_clip_allgood(xarr)"
]
},
{
"cell_type": "markdown",
"id": "ce10b494",
"metadata": {},
"source": [
"The Numpy function above just clips its input array from below at 0.5. The line `x[msk] = 0.5` modifies the array `x` in place (note that this also overwrites the caller's `xarr`, exactly the kind of side effect that JAX disallows). JAX arrays are immutable, and so this operation is not supported."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1d96d4bf",
"metadata": {},
"outputs": [],
"source": [
"@jjit\n",
"def jax_clip_whoops(x):\n",
" msk = x < 0.5\n",
" x[msk] = 0.5\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "001b5bc5",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "'<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [18], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m jax_clip_whoops(xarr)\n",
" \u001b[0;31m[... skipping hidden 14 frame]\u001b[0m\n",
"Cell \u001b[0;32mIn [17], line 4\u001b[0m, in \u001b[0;36mjax_clip_whoops\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;129m@jjit\u001b[39m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mjax_clip_whoops\u001b[39m(x):\n\u001b[1;32m 3\u001b[0m msk \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0.5\u001b[39m\n\u001b[0;32m----> 4\u001b[0m \u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mmsk\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.5\u001b[39m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
" \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
"File \u001b[0;32m~/opt/miniconda3/envs/ht081/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4696\u001b[0m, in \u001b[0;36m_unimplemented_setitem\u001b[0;34m(self, i, x)\u001b[0m\n\u001b[1;32m 4691\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_unimplemented_setitem\u001b[39m(\u001b[38;5;28mself\u001b[39m, i, x):\n\u001b[1;32m 4692\u001b[0m msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object does not support item assignment. JAX arrays are \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4693\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mimmutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4694\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mor another .at[] method: \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4695\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 4696\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)))\n",
"\u001b[0;31mTypeError\u001b[0m: '<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html"
]
}
],
"source": [
"jax_clip_whoops(xarr)"
]
},
{
"cell_type": "markdown",
"id": "5649754d",
"metadata": {},
"source": [
"A common code pattern for this kind of masked update is the `jnp.where` function, which selects elementwise between two arrays and returns a new array, as below."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "ebe1c5cb",
"metadata": {},
"outputs": [],
"source": [
"@jjit\n",
"def jax_clip_roger(x):\n",
" msk = x < 0.5\n",
" x = jnp.where(msk, 0.5, x)  # values below 0.5 become 0.5\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "f7687e94",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([0.5 , 0.5 , 0.5 , 0.64998466, 0.86664623,\n",
" 1.0833079 , 1.2999693 , 1.5166309 , 1.7332925 , 1.949954 ,\n",
" 2.1666157 , 2.3832772 , 2.5999386 , 2.8166003 , 3.0332618 ,\n",
" 3.2499235 , 3.466585 , 3.6832466 , 3.899908 , 4.1165695 ,\n",
" 4.3332314 , 4.549893 , 4.7665544 , 4.983216 , 5.1998773 ,\n",
" 5.416539 , 5.6332006 , 5.849862 , 6.0665236 , 6.2831855 ], dtype=float32)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jax_clip_roger(xarr)"
]
},
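{
"cell_type": "markdown",
"id": "sketch-at-set-md",
"metadata": {},
"source": [
"The error message above also points to the `.at[]` syntax, which is the general way to express indexed updates on immutable JAX arrays. A short sketch with illustrative names: `x.at[idx].set(y)` returns a new array and leaves `x` unchanged, and for the specific task of clipping from below, `jnp.maximum` avoids indexing altogether."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-at-set",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"x_demo = jnp.linspace(0, 1, 5)  # [0, 0.25, 0.5, 0.75, 1]\n",
"\n",
"# .at[idx].set(value) returns a NEW array; x_demo itself is unchanged\n",
"y_demo = x_demo.at[0].set(0.5)\n",
"\n",
"\n",
"# For clipping from below, jnp.maximum needs no indexing at all\n",
"@jax.jit\n",
"def jax_clip_maximum(x):\n",
"    return jnp.maximum(x, 0.5)\n",
"\n",
"\n",
"z_demo = jax_clip_maximum(x_demo)\n",
"print(x_demo, y_demo, z_demo)"
]
},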
{
"cell_type": "markdown",
"id": "c8bc89a6",
"metadata": {},
"source": [
"## Computing gradients with `jax.grad`\n",
"\n",
"The cell below shows how to use the autodiff functionality of JAX to calculate the derivative of an input function. The `jax.grad` transformation accepts an input function $f$ and returns a function $f'$ that computes the derivative of $f$."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "ef3ccd89",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.9238795325112867, DeviceArray(0.9238795, dtype=float32, weak_type=True))"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from jax import grad\n",
"\n",
"sin_deriv = jjit(grad(jnp.sin))\n",
"\n",
"np.cos(np.pi/8), sin_deriv(np.pi/8)"
]
},
{
"cell_type": "markdown",
"id": "fa17c566",
"metadata": {},
"source": [
"Whether the cosine is computed with `np.cos` or as the autodiff-calculated gradient of `jnp.sin`, the results agree to `float32` precision.\n",
"\n",
"The code above evaluated `sin_deriv` with a scalar value $x=\\pi/8$. Now let's try calculating the gradients for an array of inputs:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "77a001f7",
"metadata": {},
"outputs": [],
"source": [
"cos_result_numpy = np.cos(xarr)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f15d9168",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "Gradient only defined for scalar-output functions. Output had shape: (30,).",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [23], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m cos_result_jax \u001b[38;5;241m=\u001b[39m sin_deriv(xarr)\n",
" \u001b[0;31m[... skipping hidden 18 frame]\u001b[0m\n",
"File \u001b[0;32m~/opt/miniconda3/envs/ht081/lib/python3.9/site-packages/jax/_src/api.py:1166\u001b[0m, in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(aval, ShapedArray):\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m aval\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m ():\n\u001b[0;32m-> 1166\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhad shape: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maval\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 1167\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1168\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhad abstract value \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maval\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n",
"\u001b[0;31mTypeError\u001b[0m: Gradient only defined for scalar-output functions. Output had shape: (30,)."
]
}
],
"source": [
"cos_result_jax = sin_deriv(xarr)"
]
},
{
"cell_type": "markdown",
"id": "18593460",
"metadata": {},
"source": [
"The function passed to `jax.grad` must return a scalar, not an array: `jax.grad` computes the gradient of a scalar-valued function. To efficiently evaluate the derivative at every element of an array, we can vectorize the scalar derivative function with `jax.vmap`."
]
},
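{
"cell_type": "markdown",
"id": "sketch-grad-compose-md",
"metadata": {},
"source": [
"`jax.grad` transformations also compose with one another. As a small sketch with a hypothetical `cube` function, applying `jax.grad` twice gives the second derivative; each application still expects a scalar-valued function, and the argument should be a float, since `jax.grad` does not accept integer inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-grad-compose",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"\n",
"def cube(x):\n",
"    return x**3\n",
"\n",
"\n",
"dcube = jax.grad(cube)  # first derivative: 3x^2\n",
"d2cube = jax.grad(jax.grad(cube))  # second derivative: 6x\n",
"\n",
"# Use a float argument: jax.grad rejects integer inputs\n",
"print(dcube(2.0), d2cube(2.0))  # both equal 12.0 at x = 2"
]
},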
{
"cell_type": "markdown",
"id": "b4e8eabc",
"metadata": {},
"source": [
"## Vectorizing functions with `jax.vmap`\n",
"\n",
"\n",
"The `jax.vmap` transformation takes a function and returns a new function that maps the original over an extra leading axis of its inputs, stacking the results into an output array with one more dimension. The gradient of the sin function we wrote above accepts and returns scalars, so we can use `jax.vmap` to transform this function to accept and return one-dimensional arrays."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "ebc65002",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from jax import vmap \n",
"\n",
"sin_deriv_vmap = jjit(vmap(grad(jnp.sin)))\n",
"\n",
"cos_result_jax = sin_deriv_vmap(xarr)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"__=ax.plot(xarr, cos_result_numpy)\n",
"__=ax.plot(xarr, cos_result_jax, '--')"
]
},
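{
"cell_type": "markdown",
"id": "sketch-vmap-check-md",
"metadata": {},
"source": [
"As a numerical check on the plot above, the sketch below compares `vmap(grad(jnp.sin))` with `np.cos`. It also shows a common alternative that only works for elementwise functions: since $\\partial_{x_i} \\sum_j \\sin(x_j) = \\cos(x_i)$, the gradient of the summed output gives the same per-element derivatives. For functions whose outputs mix elements of the input, the `vmap` approach is the one to use. The names are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-vmap-check",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"x_grid = jnp.linspace(0, 2*jnp.pi, 30)\n",
"\n",
"# Per-element derivative: vmap over the scalar derivative of sin\n",
"dsin_vmap = jax.jit(jax.vmap(jax.grad(jnp.sin)))(x_grid)\n",
"\n",
"# Elementwise-only shortcut: the gradient of sum(sin(x)) is cos(x) element by element\n",
"dsin_sum = jax.grad(lambda v: jnp.sum(jnp.sin(v)))(x_grid)\n",
"\n",
"print(np.max(np.abs(np.asarray(dsin_vmap) - np.cos(np.asarray(x_grid)))))\n",
"print(np.allclose(dsin_vmap, dsin_sum))"
]
},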
{
"cell_type": "markdown",
"id": "dca4a794",
"metadata": {},
"source": [
"### `jax.vmap` in multiple dimensions\n",
"\n",
"The example of `jax.vmap` above showed how to transform a scalar-valued function into a function that accepts and returns 1d vectors. This same vectorization transformation can be used to construct higher-dimensional operations by composing successive vmap calls.\n",
"\n",
"First we define $f(x, A, B)=A{\\rm sin}(Bx)$ (the code below uses lowercase `a` and `b`), and then use `jax.grad` to define a new function that calculates $\\partial f/\\partial x$.\n",
"\n",
"Note that we now pass the `argnums` keyword argument to `jax.grad` to specify that the gradient we are interested in is with respect to the first positional argument, $x$. Since `argnums=0` is the default, this just makes the choice explicit."
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "e03c7a3f",
"metadata": {},
"outputs": [],
"source": [
"@jjit\n",
"def some_func(x, a, b):\n",
" return a*jnp.sin(b*x)\n",
"\n",
"\n",
"some_func_deriv = jjit(grad(some_func, argnums=0))"
]
},
{
"cell_type": "markdown",
"id": "88d0569c",
"metadata": {},
"source": [
"The function `some_func_deriv` computes $\\frac{\\partial}{\\partial x} A\\cdot{\\rm sin}(B\\cdot x)=A\\cdot B\\cdot{\\rm cos}(B\\cdot x).$ Because we built `some_func_deriv` with `jax.grad`, it accepts scalars for all its arguments and returns a scalar. We can use the `in_axes` keyword argument of `jax.vmap` to get the corresponding function that accepts an array for its first argument, `x`, and scalars for the remaining two arguments, `a` and `b`. For `in_axes` you pass a tuple with one entry per positional argument: `0` for an argument that should be mapped over its leading axis, and `None` for an argument that should be passed unchanged to every call."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "a327ed05",
"metadata": {},
"outputs": [],
"source": [
"some_func_deriv_vmap0 = jjit(vmap(some_func_deriv, in_axes=(0, None, None)))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "dc29bef8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"\n",
"a, b = 1/3, 3\n",
"__=ax.plot(xarr, some_func_deriv_vmap0(xarr, a, b))\n",
"__=ax.plot(xarr, a*b*np.cos(b*xarr), '--')"
]
},
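{
"cell_type": "markdown",
"id": "sketch-in-axes-check-md",
"metadata": {},
"source": [
"The sketch below checks the vectorized derivative against the analytic result $A\\cdot B\\cdot{\\rm cos}(B\\cdot x)$ without plotting. It redefines the function under hypothetical `_demo` names so that it stands alone and does not overwrite the notebook's variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-in-axes-check",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def f_demo(x, a, b):\n",
"    return a*jnp.sin(b*x)\n",
"\n",
"\n",
"# Derivative with respect to x, vectorized over x only (a and b stay scalars)\n",
"df_dx_demo = jax.vmap(jax.grad(f_demo, argnums=0), in_axes=(0, None, None))\n",
"\n",
"x_demo = jnp.linspace(0, 2*jnp.pi, 30)\n",
"a_demo, b_demo = 1/3, 3.0\n",
"autodiff_vals = df_dx_demo(x_demo, a_demo, b_demo)\n",
"analytic_vals = a_demo*b_demo*jnp.cos(b_demo*x_demo)\n",
"print(autodiff_vals.shape, float(jnp.max(jnp.abs(autodiff_vals - analytic_vals))))"
]
},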
{
"cell_type": "markdown",
"id": "857111c4",
"metadata": {},
"source": [
"We can use `vmap` again if we want to construct a function that also accepts an array for the input argument `a`."
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "38168dab",
"metadata": {},
"outputs": [],
"source": [
"some_func_deriv_vmap1 = jjit(vmap(some_func_deriv_vmap0, in_axes=(None, 0, None)))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "16b64474",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5, 30)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_a = 5\n",
"a_arr = np.linspace(1, 10, n_a)\n",
"result = some_func_deriv_vmap1(xarr, a_arr, b)\n",
"result.shape"
]
},
{
"cell_type": "markdown",
"id": "f8631bfb",
"metadata": {},
"source": [
"Notice that the returned shape is `(n_a, nx)`. It would be flipped to `(nx, n_a)` if we had applied the two `vmap` calls in the reverse order: by default, each additional `vmap` layer places its mapped axis as the leading dimension of the output (the `out_axes` argument of `jax.vmap` can change this)."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "ef3bb85c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"for ia in range(result.shape[0]):\n",
" __=ax.plot(xarr, result[ia, :])"
]
},
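{
"cell_type": "markdown",
"id": "sketch-vmap-order-md",
"metadata": {},
"source": [
"The dependence of the output shape on the order of the `vmap` calls can be verified directly. In this sketch with hypothetical `_demo` names, mapping over `x` first and then `a` gives shape `(n_a, nx)`, the reverse order gives `(nx, n_a)`, and passing `out_axes=1` to the outer `vmap` moves its mapped axis to the second position."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-vmap-order",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def f_demo(x, a, b):\n",
"    return a*jnp.sin(b*x)\n",
"\n",
"\n",
"df_demo = jax.grad(f_demo, argnums=0)\n",
"x_demo = jnp.linspace(0, 2*jnp.pi, 30)\n",
"a_demo = jnp.linspace(1, 10, 5)\n",
"b_demo = 3.0\n",
"\n",
"# Map over x first, then over a: shape (n_a, nx), as in the notebook\n",
"over_x = jax.vmap(df_demo, in_axes=(0, None, None))\n",
"x_then_a = jax.vmap(over_x, in_axes=(None, 0, None))(x_demo, a_demo, b_demo)\n",
"\n",
"# Map over a first, then over x: shape (nx, n_a)\n",
"over_a = jax.vmap(df_demo, in_axes=(None, 0, None))\n",
"a_then_x = jax.vmap(over_a, in_axes=(0, None, None))(x_demo, a_demo, b_demo)\n",
"\n",
"# out_axes=1 places the outer (a) axis second instead of first\n",
"x_then_a_T = jax.vmap(over_x, in_axes=(None, 0, None), out_axes=1)(x_demo, a_demo, b_demo)\n",
"\n",
"print(x_then_a.shape, a_then_x.shape, x_then_a_T.shape)"
]
},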
{
"cell_type": "markdown",
"id": "99860319",
"metadata": {},
"source": [
"Using `vmap` once more, we can vectorize the `b` argument."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "b64218e3",
"metadata": {},
"outputs": [],
"source": [
"some_func_deriv_vmap2 = jjit(vmap(some_func_deriv_vmap1, in_axes=(None, None, 0)))"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "89aa9f15",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7, 5, 30)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_b = 7\n",
"b_arr = np.linspace(0.5, 2, n_b)\n",
"result = some_func_deriv_vmap2(xarr, a_arr, b_arr)\n",
"result.shape"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "5a994849",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"\n",
"ia_fixed = 0\n",
"for ib in range(result.shape[0]):\n",
" __=ax.plot(xarr, result[ib, ia_fixed, :])"
]
},
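{
"cell_type": "markdown",
"id": "sketch-broadcast-check-md",
"metadata": {},
"source": [
"Each entry of the three-dimensional result is just $A\\cdot B\\cdot{\\rm cos}(B\\cdot x)$ evaluated on a grid, so the nested `vmap` calls can be cross-checked against plain NumPy broadcasting with axes ordered `(n_b, n_a, nx)`. A sketch with hypothetical `_demo` names:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-broadcast-check",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def f_demo(x, a, b):\n",
"    return a*jnp.sin(b*x)\n",
"\n",
"\n",
"df_demo = jax.grad(f_demo, argnums=0)\n",
"v_x = jax.vmap(df_demo, in_axes=(0, None, None))  # over x\n",
"v_xa = jax.vmap(v_x, in_axes=(None, 0, None))  # then over a\n",
"v_xab = jax.vmap(v_xa, in_axes=(None, None, 0))  # then over b\n",
"\n",
"x_demo = np.linspace(0, 2*np.pi, 30)\n",
"a_demo = np.linspace(1, 10, 5)\n",
"b_demo = np.linspace(0.5, 2, 7)\n",
"grid = np.asarray(v_xab(x_demo, a_demo, b_demo))  # shape (n_b, n_a, nx)\n",
"\n",
"# Same quantity via NumPy broadcasting: grid[ib, ia, ix] = a*b*cos(b*x)\n",
"B = b_demo[:, None, None]\n",
"A = a_demo[None, :, None]\n",
"X = x_demo[None, None, :]\n",
"expected = A*B*np.cos(B*X)\n",
"print(grid.shape, np.max(np.abs(grid - expected)))"
]
},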
{
"cell_type": "markdown",
"id": "4372dc3d",
"metadata": {},
"source": [
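"The arguments `a` and `b` can also be vectorized together when they vary along the same axis, that is, when you want to evaluate matched pairs `(a[i], b[i])` rather than every combination. A single `vmap` with `in_axes=(None, 0, 0)` does this, and it requires `a` and `b` to have the same length."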
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "d7497b32",
"metadata": {},
"outputs": [],
"source": [
"some_func_deriv_vmap3 = jjit(vmap(some_func_deriv_vmap0, in_axes=(None, 0, 0)))"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "d29acf87",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5, 30)"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_b = n_a\n",
"b_arr = np.linspace(0.5, 2, n_b)\n",
"result = some_func_deriv_vmap3(xarr, a_arr, b_arr)\n",
"result.shape"
]
},
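{
"cell_type": "markdown",
"id": "sketch-paired-diagonal-md",
"metadata": {},
"source": [
"Vectorizing `a` and `b` together evaluates matched pairs, so the result should equal the diagonal of the full `(b, a)` grid built with two separate `vmap` calls. A sketch with hypothetical `_demo` names:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-paired-diagonal",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def f_demo(x, a, b):\n",
"    return a*jnp.sin(b*x)\n",
"\n",
"\n",
"v_x = jax.vmap(jax.grad(f_demo, argnums=0), in_axes=(0, None, None))  # over x\n",
"\n",
"x_demo = np.linspace(0, 2*np.pi, 30)\n",
"a_demo = np.linspace(1, 10, 5)\n",
"b_demo = np.linspace(0.5, 2, 5)  # same length as a_demo\n",
"\n",
"# Matched pairs (a_demo[i], b_demo[i]): shape (5, 30)\n",
"paired = np.asarray(jax.vmap(v_x, in_axes=(None, 0, 0))(x_demo, a_demo, b_demo))\n",
"\n",
"# Every (b, a) combination: shape (5, 5, 30)\n",
"v_xa = jax.vmap(v_x, in_axes=(None, 0, None))\n",
"full = np.asarray(jax.vmap(v_xa, in_axes=(None, None, 0))(x_demo, a_demo, b_demo))\n",
"\n",
"# The paired result is the diagonal of the full grid, where ib == ia\n",
"idx = np.arange(5)\n",
"print(paired.shape, np.allclose(paired, full[idx, idx, :]))"
]
},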
{
"cell_type": "code",
"execution_count": 36,
"id": "3e8bb18a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"\n",
"for iab in range(result.shape[0]):\n",
" __=ax.plot(xarr, result[iab, :])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}