
@avivajpeyi
Forked from dfm/intro-to-jax-part2.ipynb
Last active June 29, 2022 19:19
intro-to-jax-part2.ipynb
{
"cells": [
{
"cell_type": "markdown",
"id": "ada14b6a-d989-4aa4-8b71-0d870933eb13",
"metadata": {
"id": "ada14b6a-d989-4aa4-8b71-0d870933eb13"
},
"source": [
"# Introduction to JAX (Part 2)\n",
"\n",
"We'll start with a recap of the previous \"intro to JAX\" session, with (hopefully!) enough info to catch up anyone who wasn't there.\n",
"\n",
"This tutorial is a whirlwind introduction to JAX. It's going to be pretty incomplete, so if you want more info, check out the [JAX docs](https://jax.readthedocs.io).\n",
"\n",
"JAX uses single precision (`float32`) by default, so for many scientific applications we'll want to enable double precision with the (commented-out) line below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9dc1746e-58c1-4ae6-824c-f0e0865a784a",
"metadata": {
"id": "9dc1746e-58c1-4ae6-824c-f0e0865a784a"
},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# In many cases you may want to enable support for double precision\n",
"# jax.config.update(\"jax_enable_x64\", True)"
]
},
{
"cell_type": "markdown",
"id": "b2f1843d-fe6d-428d-9206-08599ce547de",
"metadata": {
"id": "b2f1843d-fe6d-428d-9206-08599ce547de"
},
"source": [
"## `jax.numpy`\n",
"\n",
"`jax.numpy` works just like `numpy` (almost always):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "467fe662-819a-4c5a-b179-8e5d1cb21076",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "467fe662-819a-4c5a-b179-8e5d1cb21076",
"outputId": "e2cb6afd-6d9c-46a9-f5b2-ae08a669b443"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(DeviceArray([0.1 , 1.325, 2.55 , 3.775, 5. ], dtype=float32),\n",
" DeviceArray([ 0.09983342, 0.9699439 , 0.55768377, -0.5918946 ,\n",
" -0.9589243 ], dtype=float32))"
]
},
"metadata": {},
"execution_count": 3
}
],
"source": [
"import jax.numpy as jnp\n",
"\n",
"x = jnp.linspace(0.1, 5.0, 5)\n",
"y = jnp.sin(x)\n",
"x, y"
]
},
{
"cell_type": "markdown",
"id": "e0b320f6-d4fd-4f59-bf0a-8cbd748215be",
"metadata": {
"id": "e0b320f6-d4fd-4f59-bf0a-8cbd748215be"
},
"source": [
"We can also mix regular `numpy` arrays with `jax.numpy` functions; the result is a JAX array:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e108d6f-5bf7-483f-8485-7bab4533f92f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8e108d6f-5bf7-483f-8485-7bab4533f92f",
"outputId": "cdfd2bd5-227a-4dcc-be5d-19139a0aa104"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(array([0.1 , 1.325, 2.55 , 3.775, 5. ]),\n",
" DeviceArray([ 0.09983342, 0.9699439 , 0.55768377, -0.5918946 ,\n",
" -0.9589243 ], dtype=float32))"
]
},
"metadata": {},
"execution_count": 3
}
],
"source": [
"import numpy as np\n",
"\n",
"x = np.linspace(0.1, 5.0, 5)\n",
"y = jnp.sin(x)\n",
"x, y"
]
},
{
"cell_type": "markdown",
"id": "9282cf37-e89e-4f3a-abfc-b247538f6e4b",
"metadata": {
"id": "9282cf37-e89e-4f3a-abfc-b247538f6e4b"
},
"source": [
"## `jax.jit`\n",
"\n",
"We use `jax.jit` to compile a function with XLA, which can fuse operations together and run them efficiently on an accelerator such as a GPU.\n",
"One of the key points to remember when using JAX is that it works best in a \"functional\" style.\n",
"Many of the key JAX transformations take a function as input and return a new function.\n",
"For example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc4dc0ac-b240-47d1-91d3-f1edbb0a0526",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cc4dc0ac-b240-47d1-91d3-f1edbb0a0526",
"outputId": "f078464d-67f4-43d8-8983-8d2c2acf5e91"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"hi from this function\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([2.6049867, 4.1377964, 3.246622 , 2.053278 , 1.883305 ], dtype=float32)"
]
},
"metadata": {},
"execution_count": 4
}
],
"source": [
"def jnp_function(x):\n",
" print(\"hi from this function\")\n",
" arg = jnp.sin(x)\n",
" return 1.5 + jnp.exp(arg)\n",
"\n",
"jitted_function = jax.jit(jnp_function)\n",
"\n",
"jitted_function(x)"
]
},
{
"cell_type": "markdown",
"id": "8c7d5f64-7797-46d8-ac54-4c46cb6d9525",
"metadata": {
"id": "8c7d5f64-7797-46d8-ac54-4c46cb6d9525"
},
"source": [
"What happens if we call that function again?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79bac88b-202b-42b2-889c-16afe755f667",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "79bac88b-202b-42b2-889c-16afe755f667",
"outputId": "4c21b7ef-c766-4415-ce9f-78757d165879"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([2.6049867, 4.1377964, 3.246622 , 2.053278 , 1.883305 ], dtype=float32)"
]
},
"metadata": {},
"execution_count": 5
}
],
"source": [
"jitted_function(x)"
]
},
{
"cell_type": "markdown",
"id": "5e7c75a4-23da-4c28-9ee0-1e5baa960327",
"metadata": {
"id": "5e7c75a4-23da-4c28-9ee0-1e5baa960327"
},
"source": [
"What if we call it with a different input (with the same shape)?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35573796-3a57-44ab-94f6-dcd34c481f97",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "35573796-3a57-44ab-94f6-dcd34c481f97",
"outputId": "6b5c0dea-ea02-493e-ac8c-1d7592af2fb9"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([2.6048036, 3.7815475, 3.1976116, 2.0723903, 1.9410601], dtype=float32)"
]
},
"metadata": {},
"execution_count": 6
}
],
"source": [
"jitted_function(np.sin(x))"
]
},
{
"cell_type": "markdown",
"id": "020331f8-2033-4937-8273-7565ab26fd6d",
"metadata": {
"id": "020331f8-2033-4937-8273-7565ab26fd6d"
},
"source": [
"What about an input with a different shape?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1353bc5c-6a1d-4d12-a9be-28e9204e2c95",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1353bc5c-6a1d-4d12-a9be-28e9204e2c95",
"outputId": "7be7f76c-4f73-405c-8a0b-945d23ba105f"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"hi from this function\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([2.6049867, 4.1377964, 3.246622 , 2.053278 ], dtype=float32)"
]
},
"metadata": {},
"execution_count": 7
}
],
"source": [
"jitted_function(x[:-1])"
]
},
{
"cell_type": "markdown",
"id": "b2990af6-a1b8-426a-aa5e-7560c97b64d6",
"metadata": {
"id": "b2990af6-a1b8-426a-aa5e-7560c97b64d6"
},
"source": [
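"The `print` only runs while JAX *traces* the function, which happens the first time it is called with inputs of a new shape or dtype. Calls with matching shapes and dtypes reuse the cached, compiled version, so nothing is printed.\n",
"\n",
"*Note:* It is common to use `jax.jit` as a \"decorator\":"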
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d997f1d0-8117-425b-bd8a-889be7cb821f",
"metadata": {
"id": "d997f1d0-8117-425b-bd8a-889be7cb821f"
},
"outputs": [],
"source": [
"@jax.jit\n",
"def jitted_function(x):\n",
" arg = jnp.sin(x)\n",
" return 1.5 + jnp.exp(arg)"
]
},
{
"cell_type": "markdown",
"source": [
"What about control flow that depends on the *values* of the input?"
],
"metadata": {
"id": "Se90GwOHOhP3"
},
"id": "Se90GwOHOhP3"
},
{
"cell_type": "code",
"source": [
"@jax.jit\n",
"def incorrect_conditional_func(x):\n",
" if jnp.all(x > 0):\n",
" return x\n",
" arg = jnp.sin(x)\n",
" return 1.5 + jnp.exp(arg)"
],
"metadata": {
"id": "XG8W5MNUOjkk"
},
"id": "XG8W5MNUOjkk",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
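"# What happens if we run this? Uncomment to find out: under `jax.jit`, `x` is an\n",
"# abstract tracer with no concrete values, so Python's `if` can't decide which\n",
"# branch to take and JAX raises a `jax.errors.ConcretizationTypeError` (or a subclass).\n",
"# incorrect_conditional_func(x)"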
],
"metadata": {
"id": "p5ZfoQGvOopU"
},
"id": "p5ZfoQGvOopU",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
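"@jax.jit\n",
"def correct_conditional_func(x):\n",
"    # jnp.where computes both branches and then selects between them;\n",
"    # the scalar condition broadcasts against the array-valued branches\n",
"    arg = jnp.sin(x)\n",
"    return jnp.where(jnp.all(x > 0), x, 1.5 + jnp.exp(arg))"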
],
"metadata": {
"id": "iIIaFa_1O88z"
},
"id": "iIIaFa_1O88z",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"correct_conditional_func(x)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5HNdGm6gPMZ_",
"outputId": "12092eb1-8b97-47a4-c827-0972934d2422"
},
"id": "5HNdGm6gPMZ_",
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([0.1 , 1.325, 2.55 , 3.775, 5. ], dtype=float32)"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "markdown",
"id": "cdc4a86a-cd09-4d4d-a5d0-55438cc9c02b",
"metadata": {
"id": "cdc4a86a-cd09-4d4d-a5d0-55438cc9c02b"
},
"source": [
"## `jax.vmap`\n",
"\n",
"`jax.vmap` (\"vectorizing map\") takes a function written for a single input and maps it over a batch of inputs, along the leading axis by default.\n",
"The same result can often be achieved by broadcasting manually (compare the two cells below), but `vmap` comes in handy more often than you might think."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "136772d8-a384-4dc3-87ba-e967e5bfbf7c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "136772d8-a384-4dc3-87ba-e967e5bfbf7c",
"outputId": "e5eb8a3d-e696-4880-94e9-5074c315e837"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[[ 1.19428426e-01, 2.83938229e-01, 1.14193819e-01],\n",
" [ 2.83938229e-01, 6.75056338e-01, 2.71493077e-01],\n",
" [ 1.14193819e-01, 2.71493077e-01, 1.09188654e-01]],\n",
"\n",
" [[ 1.69821870e+00, -1.17982101e+00, -5.81696212e-01],\n",
" [-1.17982101e+00, 8.19669247e-01, 4.04127836e-01],\n",
" [-5.81696212e-01, 4.04127836e-01, 1.99250251e-01]],\n",
"\n",
" [[ 2.88318753e-01, -3.12033236e-01, -1.95758328e-01],\n",
" [-3.12033236e-01, 3.37698251e-01, 2.11859629e-01],\n",
" [-1.95758328e-01, 2.11859629e-01, 1.32913038e-01]],\n",
"\n",
" [[ 8.65139291e-02, 8.35990533e-03, 1.60806060e-01],\n",
" [ 8.35990533e-03, 8.07823846e-04, 1.55388089e-02],\n",
" [ 1.60806060e-01, 1.55388089e-02, 2.98895091e-01]],\n",
"\n",
" [[ 5.42364597e-01, 1.19975701e-01, 3.55058730e-01],\n",
" [ 1.19975701e-01, 2.65396535e-02, 7.85420388e-02],\n",
" [ 3.55058730e-01, 7.85420388e-02, 2.32439041e-01]]], dtype=float32)"
]
},
"metadata": {},
"execution_count": 13
}
],
"source": [
"A = np.random.default_rng(1).normal(size=(5, 3))\n",
"\n",
"def scalar_function(x):\n",
" return jnp.outer(x, x)\n",
"\n",
"vector_function = jax.vmap(scalar_function)\n",
"vector_function(A)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a861a6d-f6b2-415b-a3d7-a9eba66b6df8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4a861a6d-f6b2-415b-a3d7-a9eba66b6df8",
"outputId": "cb2d525b-6174-4e9d-818c-61077554e40b"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[[ 1.19428434e-01, 2.83938242e-01, 1.14193830e-01],\n",
" [ 2.83938242e-01, 6.75056374e-01, 2.71493097e-01],\n",
" [ 1.14193830e-01, 2.71493097e-01, 1.09188661e-01]],\n",
"\n",
" [[ 1.69821877e+00, -1.17982104e+00, -5.81696252e-01],\n",
" [-1.17982104e+00, 8.19669245e-01, 4.04127838e-01],\n",
" [-5.81696252e-01, 4.04127838e-01, 1.99250259e-01]],\n",
"\n",
" [[ 2.88318777e-01, -3.12033246e-01, -1.95758328e-01],\n",
" [-3.12033246e-01, 3.37698251e-01, 2.11859620e-01],\n",
" [-1.95758328e-01, 2.11859620e-01, 1.32913032e-01]],\n",
"\n",
" [[ 8.65139256e-02, 8.35990480e-03, 1.60806056e-01],\n",
" [ 8.35990480e-03, 8.07823801e-04, 1.55388084e-02],\n",
" [ 1.60806056e-01, 1.55388084e-02, 2.98895090e-01]],\n",
"\n",
" [[ 5.42364622e-01, 1.19975697e-01, 3.55058738e-01],\n",
" [ 1.19975697e-01, 2.65396512e-02, 7.85420322e-02],\n",
" [ 3.55058738e-01, 7.85420322e-02, 2.32439032e-01]]])"
]
},
"metadata": {},
"execution_count": 14
}
],
"source": [
"A[:, None, :] * A[:, :, None]"
]
},
{
"cell_type": "markdown",
"id": "a37c40b8-8a99-46db-860b-a04b2918b976",
"metadata": {
"id": "a37c40b8-8a99-46db-860b-a04b2918b976"
},
"source": [
"## `jax.grad`\n",
"\n",
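"Functions built from JAX operations can also be differentiated: `jax.grad` takes a function and returns a new function that evaluates its gradient."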
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc0e710d-962b-44e7-8649-51b88e47c806",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fc0e710d-962b-44e7-8649-51b88e47c806",
"outputId": "501f0dd5-9aee-4df1-c3d1-ff14e7e8bf15"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray(1.4174242, dtype=float32, weak_type=True)"
]
},
"metadata": {},
"execution_count": 15
}
],
"source": [
"grad_function = jax.grad(jitted_function)\n",
"grad_function(0.5)"
]
},
{
"cell_type": "markdown",
"id": "fc51106f-1379-4b53-95f1-031a5d67264d",
"metadata": {
"id": "fc51106f-1379-4b53-95f1-031a5d67264d"
},
"source": [
"`jax.grad` is only defined for functions with scalar outputs, so applying it to our vector-valued function raises an error:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c343dc80-4d7e-461b-b520-34dfadff76f2",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c343dc80-4d7e-461b-b520-34dfadff76f2",
"outputId": "eb595fe6-4fda-4bc1-a331-b81f18737388"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" with pytest.raises(TypeError) as info:\n",
"> jax.grad(jitted_function)(x)\n",
"E TypeError: Gradient only defined for scalar-output functions. Output had shape: (5,).\n",
"\n",
"<ipython-input-16-b967da8725c7>:4: TypeError\n"
]
}
],
"source": [
"import pytest\n",
"\n",
"with pytest.raises(TypeError) as info:\n",
" jax.grad(jitted_function)(x)\n",
"print(\"\\n\\n\".join(str(info.getrepr()).split(\"\\n\\n\")[-2:]))"
]
},
{
"cell_type": "markdown",
"id": "6d2a14ce-d5bd-436f-801a-71242d1167f9",
"metadata": {
"id": "6d2a14ce-d5bd-436f-801a-71242d1167f9"
},
"source": [
"But we can combine `grad` with `vmap` to get the derivative at each input point:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59edef16-1a86-4ce4-8ac6-5e470090c242",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "59edef16-1a86-4ce4-8ac6-5e470090c242",
"outputId": "fc867951-c5f6-4e37-88d2-3b98353b8767"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([ 1.0994664 , 0.6418517 , -1.4497899 , -0.4459506 ,\n",
" 0.10872913], dtype=float32)"
]
},
"metadata": {},
"execution_count": 17
}
],
"source": [
"jax.vmap(jax.grad(jitted_function))(x)"
]
},
{
"cell_type": "markdown",
"id": "fb4d5bbf-4e23-4cd9-be05-d21bf89094fe",
"metadata": {
"id": "fb4d5bbf-4e23-4cd9-be05-d21bf89094fe"
},
"source": [
"Another useful function is `jax.value_and_grad`, which returns both the function's value and its gradient in a single call:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce76be99-0778-4859-90cf-6fce4c1cf3cd",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ce76be99-0778-4859-90cf-6fce4c1cf3cd",
"outputId": "dc0ae4a5-810d-42f3-da39-e5a4a0623bec"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(DeviceArray([2.6049867, 4.1377964, 3.246622 , 2.053278 , 1.883305 ], dtype=float32),\n",
" DeviceArray([ 1.0994664 , 0.6418517 , -1.4497899 , -0.4459506 ,\n",
" 0.10872913], dtype=float32))"
]
},
"metadata": {},
"execution_count": 18
}
],
"source": [
"jax.vmap(jax.value_and_grad(jitted_function))(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba508356-266d-49cb-b6a9-9fbc46c0b7bf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 265
},
"id": "ba508356-266d-49cb-b6a9-9fbc46c0b7bf",
"outputId": "f43e9d52-701c-4309-fb89-352670f40497"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"x_grid = jnp.linspace(-5, 5, 100)\n",
"value, grad = jax.vmap(jax.value_and_grad(jitted_function))(x_grid)\n",
"plt.plot(x_grid, value, label=\"value\")\n",
"plt.plot(x_grid, grad, label=\"grad\")\n",
"plt.legend();"
]
},
{
"cell_type": "code",
"source": [
"@jax.jit\n",
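"def f(x):\n",
"    # This is tanh(x), written out explicitly in terms of exp\n",
"    y = jnp.exp(-2.0 * x)\n",
"    return (1.0 - y) / (1.0 + y)\n",
"\n",
"# Nesting jax.grad gives successively higher derivatives\n",
"dfdx = jax.grad(f)\n",
"d2fdx = jax.grad(dfdx)\n",
"d3fdx = jax.grad(d2fdx)\n",
"d4fdx = jax.grad(d3fdx)\n",
"\n",
"x = jnp.linspace(-4, 4, 200)\n",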
"plt.plot(x, f(x), label=\"f\")\n",
"plt.plot(x, jax.vmap(dfdx)(x), label=\"f'\")\n",
"plt.plot(x, jax.vmap(d2fdx)(x), label=\"f''\")\n",
"plt.plot(x, jax.vmap(d3fdx)(x), label=\"f'''\")\n",
"plt.plot(x, jax.vmap(d4fdx)(x), label=\"f''''\")\n",
"plt.legend(frameon=False, loc='upper right')\n",
"plt.gca().axis('off')\n",
"plt.show()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 248
},
"id": "EqDMScSD4mhD",
"outputId": "e60899b5-6523-43b9-cbc8-e059fa391567"
},
"id": "EqDMScSD4mhD",
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"id": "f0704757-2bd8-4439-a0f7-8d17dd1a138a",
"metadata": {
"id": "f0704757-2bd8-4439-a0f7-8d17dd1a138a"
},
"source": [
"## PyTrees\n",
"\n",
"Another useful JAX concept is \"PyTrees\": nested containers (lists, tuples, dicts, ...) whose leaves are arrays or numbers.\n",
"PyTrees let us use structured inputs and outputs and still use `jit`, `vmap`, and `grad`.\n",
"For example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db182d2d-47c3-4f63-88cc-f0f316fb0ad8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "db182d2d-47c3-4f63-88cc-f0f316fb0ad8",
"outputId": "32cf977c-757a-42fa-d6c7-af71c435c4fa"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray(0.02227585, dtype=float32, weak_type=True)"
]
},
"metadata": {},
"execution_count": 20
}
],
"source": [
"def pytree_func(params):\n",
" return jnp.exp(params[\"log_amp\"]) * jnp.sin(params[\"log_scale\"])\n",
"\n",
"params = {\n",
" \"log_amp\": -1.5,\n",
" \"log_scale\": 0.1,\n",
"}\n",
"pytree_func(params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59f7514e-4c74-4a64-9f56-bfec4469c41c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "59f7514e-4c74-4a64-9f56-bfec4469c41c",
"outputId": "8b8e8341-ed82-487f-9387-da348e724b9a"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'log_amp': DeviceArray(0.02227585, dtype=float32, weak_type=True),\n",
" 'log_scale': DeviceArray(0.22201544, dtype=float32, weak_type=True)}"
]
},
"metadata": {},
"execution_count": 21
}
],
"source": [
"jax.grad(pytree_func)(params)"
]
},
{
"cell_type": "markdown",
"source": [
"## Random numbers\n",
"\n",
"Random number generation in JAX works a little differently from numpy: there is no hidden global random state.\n",
"Instead, every random function takes an explicit \"key\" as input, and the same key always produces the same numbers:"
],
"metadata": {
"id": "RFR4UqrmMoFQ"
},
"id": "RFR4UqrmMoFQ"
},
{
"cell_type": "code",
"source": [
"from jax import random\n",
"\n",
"key = random.PRNGKey(42)\n",
"random.normal(key)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cTAI87MPMnvI",
"outputId": "8c93bd9c-31ab-4cfe-abb1-f5dd7cd04135"
},
"id": "cTAI87MPMnvI",
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray(-0.18471177, dtype=float32)"
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b6d64a9-6361-4f05-b399-29312c15ef31",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0b6d64a9-6361-4f05-b399-29312c15ef31",
"outputId": "8644972a-d897-4877-c0be-e1dd64da92ef"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray(-0.18471177, dtype=float32)"
]
},
"metadata": {},
"execution_count": 23
}
],
"source": [
"random.normal(key)"
]
},
{
"cell_type": "markdown",
"source": [
"To generate multiple different random numbers, \"split\" the key into new subkeys and use each key only once."
],
"metadata": {
"id": "sOY38c39NQNN"
},
"id": "sOY38c39NQNN"
},
{
"cell_type": "code",
"source": [
"key1, key2 = random.split(key)\n",
"random.normal(key1), random.uniform(key2)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9v5Hob9_NEKM",
"outputId": "06b28e51-f6ed-40a5-8845-f95ad1a81292"
},
"id": "9v5Hob9_NEKM",
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(DeviceArray(0.13790321, dtype=float32),\n",
" DeviceArray(0.91457367, dtype=float32))"
]
},
"metadata": {},
"execution_count": 24
}
]
},
{
"cell_type": "markdown",
"source": [
"## Optimizers\n",
"\n",
"The JAX ecosystem is pretty modular, and there are various packages available for nonlinear optimization.\n",
"Some popular ones include [jaxopt](https://github.com/google/jaxopt) (\"scipy.optimize with support for PyTrees\") and [optax](https://github.com/deepmind/optax) (\"feature-rich framework with a lot more boilerplate\")."
],
"metadata": {
"id": "u01iPar9Wc8B"
},
"id": "u01iPar9Wc8B"
},
{
"cell_type": "code",
"source": [
"%pip install -q jaxopt optax"
],
"metadata": {
"id": "LPVdUpTUNLLU"
},
"id": "LPVdUpTUNLLU",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import jaxopt\n",
"import optax\n",
"\n",
"def loss(params):\n",
" return jnp.sum(jnp.square(params[\"x\"]))\n",
"\n",
"params = {\"x\": 12.5}\n",
"opt = jaxopt.ScipyMinimize(fun=loss)\n",
"soln = opt.run(params)\n",
"print(soln)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1oYP60kxXf1v",
"outputId": "e67eddb6-d155-4d25-f599-4607aef8d41e"
},
"id": "1oYP60kxXf1v",
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"OptStep(params={'x': DeviceArray(4.7211597e-07, dtype=float32)}, state=ScipyMinimizeInfo(fun_val=DeviceArray(2.228935e-13, dtype=float32, weak_type=True), success=True, status=0, iter_num=2))\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"params = {\"x\": 12.5}\n",
"opt = optax.sgd(0.1)\n",
"opt_state = opt.init(params)\n",
"\n",
"@jax.jit\n",
"def train(params, opt_state):\n",
" value, grads = jax.value_and_grad(loss)(params)\n",
" updates, opt_state = opt.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" return value, params, opt_state\n",
"\n",
"losses = []\n",
"for _ in range(100):\n",
" value, params, opt_state = train(params, opt_state)\n",
" losses.append(value)\n",
"params"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4FUgjJwXXuq8",
"outputId": "bf12e9a8-d442-414d-f23a-cb0693cca0b7"
},
"id": "4FUgjJwXXuq8",
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'x': DeviceArray(2.5462958e-09, dtype=float32)}"
]
},
"metadata": {},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"plt.plot(losses)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 282
},
"id": "rF30CR93YdEP",
"outputId": "9c201c4c-163a-4b95-b542-e57bdf4e9aa4"
},
"id": "rF30CR93YdEP",
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fbb7f52a710>]"
]
},
"metadata": {},
"execution_count": 28
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAY1UlEQVR4nO3da5Bc5X3n8e9vZnpukmZGl9EIJBEJGKCw115TAwjsuDDYMfgmXrhiSDaWs0qpnGUTJ/auA+uqZfeFq+xdV4i92aVWCxjhpcAEE0O5iB0sY1PZGOERdgCJiwZhYFSSZkDoLjS3/77o06g1mmFGfZlWn/P7VKmm+5zT3f/jg3/99NNPP48iAjMzS5eGWhdgZmaV53A3M0shh7uZWQo53M3MUsjhbmaWQk21LgBgyZIlsWrVqlqXYWZWV7Zu3fpGRHRPte+MCPdVq1bR399f6zLMzOqKpFen2zdjt4ykuyQNSXpu0vY/k/SCpG2S/lvR9lskDUh6UdLHyyvdzMxKMZuW+93A3wL3FDZI+giwFnh/RByXtDTZfjFwA/Ae4Gzgp5IuiIjxShduZmbTm7HlHhFPAPsmbf5T4BsRcTw5ZijZvha4PyKOR8QrwABwWQXrNTOzWSh1tMwFwO9K2iLpF5IuTbYvB14vOm4w2XYKSRsk9UvqHx4eLrEMMzObSqnh3gQsAtYA/xF4QJJO5wkiYmNE9EVEX3f3lF/2mplZiUoN90Hgoch7CpgAlgC7gJVFx61ItpmZ2RwqNdx/CHwEQNIFQDPwBvAIcIOkFkmrgV7gqUoUamZmszeboZD3Ab8ELpQ0KGk9cBdwbjI88n5gXdKK3wY8AGwHfgzcVM2RMi/uOcS3fvIibx4+Xq2XMDOrSzMOhYyIG6fZ9W+mOf7rwNfLKWq2dg4f5m8fH+CT7zuLxfNb5uIlzczqQl3PLdPW3AjA0ZGxGldiZnZmqetwb2/Of/A4OuLfSJmZFavzcC+03B3uZmbF6jrcC90yxxzuZmYnqetwd8vdzGxq9R3uuUKfu79QNTMrVtfh7m4ZM7Op1XW4Nzc10NQgjo463M3MitV1uEO+9e6Wu5nZyeo+3NubG93nbmY2SQrCvcmjZczMJqn7cG/LuVvGzGyyug/3fLeMw93MrFj9h3tLk0fLmJlNUv/hnmvkmL9QNTM7Sf2Hu7tlzMxOUffh7nHuZmanms0ye3dJGkqW1Ju87yuSQtKS5L4kfUfSgKRnJF1SjaKLueVuZnaq2bTc7waunbxR0krg94DXijZfR35R7F5gA3B7+SW+u7bmJo6NjjMxEdV+KTOzujFjuEfEE8C+KXbdBnwVKE7VtcA9yWLZTwJdks6qSKXTKEz7e8wjZszM3lFSn7uktcCuiPiXSbuWA68X3R9Mtk31HBsk9UvqHx4eLqUMwHO6m5lN5bTDXVI78J+A/1zOC0fExojoi4i+7u7ukp+nLedpf83MJmsq4THnAauBf5EEsAJ4WtJlwC5gZdGxK5JtVfPOItmjHutuZlZw2i33iHg2IpZGxKqIWEW+6+WSiNgDPAJ8Phk1swY4EBG7K1vyydwtY2Z2qtkMhbwP+CVwoaRBSevf5fBHgZ3AAPB/gH9XkSrfhVdjMjM71YzdMhFx4wz7VxXdDuCm8suaPbfczcxOVfe/UD0R7u5zNzMrqPtwb0u+UHW3jJnZCXUf7u05d8uYmU1W9+He5l+ompmdou7DvaWpgQa5z93MrFjdh7skL5JtZjZJ3Yc75EfM+AtVM7MTUhPubrmbmZ2QinBvc7eMmdlJUhHu+Za7v1A1MytIUbi75W5mVpCKcG/L+QtVM7NiqQj39uZGz+duZlYkFeHe1tzklruZWZFUhLv73M3MTpaacD82Ok5+OnkzM5vNSkx3SRqS9FzRtv8u6QVJz0j6e0ldRftukTQg6UVJH69W4cXamhuJgLdHJ+bi5czMznizabnfDVw7ad
tjwHsj4n3AS8AtAJIuBm4A3pM85n9JaqxYtdM4Me2vv1Q1M4NZhHtEPAHsm7TtHyOikKRPAiuS22uB+yPieES8Qn4t1csqWO+U2pMFO9zvbmaWV4k+938L/ENyeznwetG+wWTbKSRtkNQvqX94eLisAjynu5nZycoKd0lfA8aAe0/3sRGxMSL6IqKvu7u7nDK8SLaZ2SRNpT5Q0heATwHXxIlhKruAlUWHrUi2VVWbF8k2MztJSS13SdcCXwU+ExFHi3Y9AtwgqUXSaqAXeKr8Mt9duxfJNjM7yYwtd0n3AVcBSyQNAreSHx3TAjwmCeDJiPhiRGyT9ACwnXx3zU0RUfXEdbeMmdnJZgz3iLhxis13vsvxXwe+Xk5Rp6sQ7m65m5nlpeQXqoWhkO5zNzOD1IR70i3joZBmZkBKwr2lqQEJjh53uJuZQUrCXRLtOc8MaWZWkIpwh2ROdy/YYWYGpCjcPae7mdkJDnczsxRKTbi3NXuRbDOzgtSEe77l7j53MzNIUbi35ZrcLWNmlkhNuBfWUTUzs5SFu1vuZmZ5qQl3f6FqZnZCasK98IXqiXVDzMyyK0Xh3sREwPGxiVqXYmZWc6kJ97ac53Q3MyuYMdwl3SVpSNJzRdsWSXpM0o7k78JkuyR9R9KApGckXVLN4ot52l8zsxNm03K/G7h20rabgc0R0QtsTu4DXEd+3dReYANwe2XKnFnbO6sx+YdMZmYzhntEPAHsm7R5LbApub0JuL5o+z2R9yTQJemsShX7bua9sxqTW+5mZqX2ufdExO7k9h6gJ7m9HHi96LjBZNspJG2Q1C+pf3h4uMQyTmhvybfcDx93y93MrOwvVCM/9vC0xx9GxMaI6IuIvu7u7nLLoKM1B8DBYw53M7NSw31vobsl+TuUbN8FrCw6bkWyreq62vPhfuDYyFy8nJnZGa3UcH8EWJfcXgc8XLT988momTXAgaLum6rqbCuE++hcvJyZ2RmtaaYDJN0HXAUskTQI3Ap8A3hA0nrgVeD3k8MfBT4BDABHgT+uQs1Tmt/SRGODHO5mZswi3CPixml2XTPFsQHcVG5RpZBER2sT+4863M3MUvMLVYCu9ma33M3MSFm4d7TlHO5mZqQs3Dvbchx0uJuZpSvcu9py7He4m5mlK9w73S1jZgakMNwPHhtlYsILdphZtqUq3Lvac0wEHPbMkGaWcakK947Cr1Q91t3MMi5V4e4pCMzM8hzuZmYplKpwPzEzpMPdzLItVeFeaLl7fhkzy7pUhrtb7maWdakK97ZcI82NDQ53M8u8VIW7pGTyMK/GZGbZlqpwB+hsa3LL3cwyL4Xh7vllzMzKCndJfylpm6TnJN0nqVXSaklbJA1I+r6k5koVOxtesMPMrIxwl7Qc+HOgLyLeCzQCNwDfBG6LiPOBt4D1lSh0tjrbch4KaWaZV263TBPQJqkJaAd2A1cDDyb7NwHXl/kap8XdMmZmZYR7ROwCvgW8Rj7UDwBbgf0RUZiWcRBYPtXjJW2Q1C+pf3h4uNQyTtHZluPQ22OMe9pfM8uwcrplFgJrgdXA2cA84NrZPj4iNkZEX0T0dXd3l1rGKQo/ZDr0tlvvZpZd5XTLfBR4JSKGI2IUeAj4INCVdNMArAB2lVnjafEUBGZm5YX7a8AaSe2SBFwDbAceBz6bHLMOeLi8Ek+PJw8zMyuvz30L+S9OnwaeTZ5rI/BXwJclDQCLgTsrUOeseX4ZM7P8aJeSRcStwK2TNu8ELivnecvhcDczS+kvVAH2O9zNLMNSF+6FdVQPOtzNLMNSF+6tuUZac57218yyLXXhDoUpCDztr5llVyrDvavNk4eZWbalMtw9v4yZZV0qwz2/GtPYzAeamaVUKsO9sy3HAfe5m1mGpTLcu9rdLWNm2ZbKcO9sy3FkZJzR8Ylal2JmVhOpDXfwD5nMLLtSGe6FmSE9BYGZZVUqw73Dk4eZWcalMtzfmRnSC3aYWUalMtyXzGsBYPjw8R
pXYmZWG6kM96Ud+XDfe+DtGldiZlYbZYW7pC5JD0p6QdLzkq6QtEjSY5J2JH8XVqrY2WrNNbKwPceegw53M8umclvu3wZ+HBEXAe8HngduBjZHRC+wObk/53o6WtnrcDezjCo53CV1Ah8mWSM1IkYiYj+wFtiUHLYJuL7cIkuxrLPVLXczy6xyWu6rgWHgu5J+LekOSfOAnojYnRyzB+iZ6sGSNkjql9Q/PDxcRhlTW9bRyp4D/kLVzLKpnHBvAi4Bbo+IDwBHmNQFExEBxFQPjoiNEdEXEX3d3d1llDG1no5W3jxy3FMQmFkmlRPug8BgRGxJ7j9IPuz3SjoLIPk7VF6JpVnW2UoEDB1y693MsqfkcI+IPcDrki5MNl0DbAceAdYl29YBD5dVYYmWdbQCsMfDIc0sg5rKfPyfAfdKagZ2An9M/g3jAUnrgVeB3y/zNUrSk4S7R8yYWRaVFe4R8Rugb4pd15TzvJWwrNMtdzPLrlT+QhVgYXuO5qYGt9zNLJNSG+6S6Olo8Vh3M8uk1IY7FMa6O9zNLHtSHe6egsDMsirV4b6sIz8FQf63VGZm2ZHucO9s5e3RCQ4eG6t1KWZmcyrV4V4Y6+4vVc0sa1Id7u+MdXe4m1nGpDvcC79S9YgZM8uYVId7Ybk9t9zNLGtSHe4tTY0smtfscDezzEl1uEMy1t3dMmaWMakP92WegsDMMij94d7pX6maWfakPtx7Olp54/AII2Nebs/MsiP14V4YDjl0yK13M8uOssNdUqOkX0v6UXJ/taQtkgYkfT9Zpalmerxoh5llUCVa7l8Cni+6/03gtog4H3gLWF+B1yjZMk9BYGYZVFa4S1oBfBK4I7kv4GrgweSQTcD15bxGuVYsbAPgtX1Ha1mGmdmcKrfl/jfAV4HCt5WLgf0RUZiGcRBYPtUDJW2Q1C+pf3h4uMwypregNceyjlYG9h6u2muYmZ1pSg53SZ8ChiJiaymPj4iNEdEXEX3d3d2lljErvT3z2THkcDez7Cin5f5B4DOSfgvcT7475ttAl6Sm5JgVwK6yKqyA3qULGBg6zMSEF+0ws2woOdwj4paIWBERq4AbgJ9FxB8CjwOfTQ5bBzxcdpVl6u2Zz7HRcXbtP1brUszM5kQ1xrn/FfBlSQPk++DvrMJrnJbepfMB2DF0qMaVmJnNjaaZD5lZRPwc+HlyeydwWSWet1LOL4T73sNcfVFPjasxM6u+1P9CFaCrvZnuBS3+UtXMMiMT4Q5wgUfMmFmGZCbce5cuYGDvISI8YsbM0i8z4X7+0vkcGRlnt+eYMbMMyEy4nxgx464ZM0u/7IR7zwIAduz1cEgzS7/MhPuiec0smd/MDs8xY2YZkJlwh3y/u3/IZGZZkKlw7126gB1Dhz1ixsxSL1vh3jOfQ2+PMXToeK1LMTOrqkyFe/E0BGZmaZapcL8gGTHzwp6DNa7EzKy6MhXuS+a3sHJRG7/67b5al2JmVlWZCneAK85dzJM79zHuhTvMLMUyF+5XnreEA8dGeX63u2bMLL0yF+5XnLcYgH9++Y0aV2JmVj3lLJC9UtLjkrZL2ibpS8n2RZIek7Qj+buwcuWWr6ejlfO65/HPL79Z61LMzKqmnJb7GPCViLgYWAPcJOli4GZgc0T0ApuT+2eUK89bwlOv7GN0fKLWpZiZVUU5C2Tvjoink9uHgOeB5cBaYFNy2Cbg+nKLrLQrz1vM0ZFxnhncX+tSzMyqoiJ97pJWAR8AtgA9EbE72bUHOOMWLV1zbtLvPuCuGTNLp7LDXdJ84AfAX0TESUNQIj+Jy5RjDiVtkNQvqX94eLjcMk7LwnnNXHxWh/vdzSy1ygp3STnywX5vRDyUbN4r6axk/1nA0FSPjYiNEdEXEX3d3d3llFGSK89bzNbX3uLt0fE5f20zs2orZ7SMgDuB5yPir4t2PQKsS26vAx4uvbzqufL8xYyMTfD0a2/VuhQzs4orp+X+QeCPgKsl/Sb59wngG8
DHJO0APprcP+NcumoRjQ3iFy/NbZeQmdlcaCr1gRHxT4Cm2X1Nqc87Vxa05rjqgm4eenoX/+H3LiTXmLnfc5lZimU60f7g8nMYPnScn27fW+tSzMwqKtPhftWFSzm7s5V7t7xW61LMzCoq0+He2CA+d+k5/NPAG/z2jSO1LsfMrGIyHe4An7t0JY0N4r5fufVuZumR+XBf1tnK1Rct5cH+QUbGPNeMmaVD5sMd8l+svnlkhJ9s21PrUszMKsLhDny4t5vfWdzOdzbvcOvdzFLB4U7+i9X/8un3sGPoMP/7Fy/Xuhwzs7I53BMfuWgpn3zfWfyPxwfYOXy41uWYmZXF4V7k1k9fTEtTA1/7++fIT2hpZlafHO5Fli5o5ebrLuKXO9/k7/oHa12OmVnJHO6T3HjpOVy+ehFf++Gz/OwFT0tgZvXJ4T5JQ4PY+Ed9XLhsAV/83tOeNdLM6pLDfQqd7Tn+7/rLOW/pfDbc0++AN7O643CfRld7M/f+yeWsXjKPL3z3KW59+DkOHx+rdVlmZrPicH8Xi+Y184M/vZIvXLmKe558lY/f9gT/uG0PExMeSWNmZzaH+wzmtTRx66ffw4NfvILWXAMbvreVq771czY+8TJvHRmpdXlmZlNStcZzS7oW+DbQCNwREdMut9fX1xf9/f1VqaOSRsYm+Mm2PXzvyVd56pV9NAj+1YouPnT+Ytacu5iLlnWwZH4z+eVlzcyqS9LWiOibcl81wl1SI/AS8DFgEPgVcGNEbJ/q+HoJ92Iv7DnIo8/u4f8NvMFvXt/PeNJVs7A9x/lL53N2VxvLOlvpWdDKwnk5utqa6WjLMb+lifbmRtqaG2lpaqClqZFco/yGYGan7d3CveQ1VGdwGTAQETuTAu4H1gJThns9umhZBxct6+DLH7uAg2+P8szrB3hp7yFe2nuIncNHePq1t9h74Dgj47ObiCzXKJoaGmhqEA0NorFBNEg0NkCD8rcBpPz9wnuB4J03hnfeHoreJ6Z7yyjnzcRvQ2aV87lLV/Inv3tuxZ+3WuG+HHi96P4gcHnxAZI2ABsAzjnnnCqVMTc6WnN8qHcJH+pdctL2iYlg/7FR9h8d4cCxUfYfG+XYyDhHjo9xdGSckbEJRsYnOD42wdj4BGMTwej4BBMTwXgE4xMQEUwUbhMQMJF82gqg8MGr8Pmr+JPYtJ/JyviwFuU82MxOsWR+S1Wet1rhPqOI2AhshHy3TK3qqKaGBrFoXjOL5jXXuhQzy5hqjZbZBawsur8i2WZmZnOgWuH+K6BX0mpJzcANwCNVei0zM5ukKt0yETEm6d8DPyE/FPKuiNhWjdcyM7NTVa3PPSIeBR6t1vObmdn0/AtVM7MUcribmaWQw93MLIUc7mZmKVS1icNOqwhpGHi1xIcvAd6oYDn1IovnncVzhmyedxbPGU7/vH8nIrqn2nFGhHs5JPVPN3FOmmXxvLN4zpDN887iOUNlz9vdMmZmKeRwNzNLoTSE+8ZaF1AjWTzvLJ4zZPO8s3jOUMHzrvs+dzMzO1UaWu5mZjaJw93MLIXqOtwlXSvpRUkDkm6udT3VIGmlpMclbZe0TdKXku2LJD0maUfyd2Gta60GSY2Sfi3pR8n91ZK2JNf8+8mU0qkhqUvSg5JekPS8pCuycK0l/WXy3/dzku6T1JrGay3pLklDkp4r2jbl9VXed5Lzf0bSJafzWnUb7ski3P8TuA64GLhR0sW1raoqxoCvRMTFwBrgpuQ8bwY2R0QvsDm5n0ZfAp4vuv9N4LaIOB94C1hfk6qq59vAjyPiIuD95M891dda0nLgz4G+iHgv+WnCbyCd1/pu4NpJ26a7vtcBvcm/DcDtp/NCdRvuFC3CHREjQGER7lSJiN0R8XRy+xD5/7MvJ3+um5LDNgHX16bC6pG0AvgkcEdyX8DVwIPJIak6b0mdwIeBOwEiYiQi9pOBa01++vE2SU1AO7CbFF7riHgC2Ddp83TXdy1wT+Q9CXRJOm
u2r1XP4T7VItzLa1TLnJC0CvgAsAXoiYjdya49QE+NyqqmvwG+Ckwk9xcD+yNiLLmftmu+GhgGvpt0Rd0haR4pv9YRsQv4FvAa+VA/AGwl3de62HTXt6yMq+dwzxRJ84EfAH8REQeL90V+PGuqxrRK+hQwFBFba13LHGoCLgFuj4gPAEeY1AWT0mu9kHwrdTVwNjCPU7suMqGS17eewz0zi3BLypEP9nsj4qFk897CR7Tk71Ct6quSDwKfkfRb8l1uV5Pvj+5KPrpD+q75IDAYEVuS+w+SD/u0X+uPAq9ExHBEjAIPkb/+ab7Wxaa7vmVlXD2HeyYW4U76me8Eno+Ivy7a9QiwLrm9Dnh4rmurpoi4JSJWRMQq8tf2ZxHxh8DjwGeTw1J13hGxB3hd0oXJpmuA7aT8WpPvjlkjqT35771w3qm91pNMd30fAT6fjJpZAxwo6r6ZWUTU7T/gE8BLwMvA12pdT5XO8UPkP6Y9A/wm+fcJ8v3Pm4EdwE+BRbWutYr/G1wF/Ci5fS7wFDAA/B3QUuv6Knyu/xroT673D4GFWbjWwH8FXgCeA74HtKTxWgP3kf9eYZT8J7X1011fQORHBL4MPEt+NNGsX8vTD5iZpVA9d8uYmdk0HO5mZinkcDczSyGHu5lZCjnczcxSyOFuZpZCDnczsxT6/6F/oXCwfCMWAAAAAElFTkSuQmCC\n"
},
"metadata": {
"needs_background": "light"
}
}
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
},
"colab": {
"name": "intro-to-jax-part2.ipynb",
"provenance": [],
"collapsed_sections": []
}
},
"nbformat": 4,
"nbformat_minor": 5
}
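A quick check of the precision note at the top of the notebook, as a minimal sketch. It assumes the `JAX_ENABLE_X64` environment variable is not already set, and keep in mind that `jax.config.update` flips a global setting, which is why the notebook suggests running it once, right after `import jax`.

```python
import jax
import jax.numpy as jnp

# Without any configuration, JAX creates single-precision arrays
print(jnp.linspace(0.1, 5.0, 5).dtype)  # float32 (unless JAX_ENABLE_X64 is set)

# Global switch: ideally run once, before creating any arrays
jax.config.update("jax_enable_x64", True)
print(jnp.linspace(0.1, 5.0, 5).dtype)  # float64
```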
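One place where `jax.numpy` does *not* work like `numpy` (the "almost always" caveat in the `jax.numpy` section) is mutation: JAX arrays are immutable. This small sketch contrasts the two and uses the `.at[...]` indexed-update syntax, which returns a modified copy; the variable names are just for illustration.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.zeros(3)
x_np[0] = 1.0  # NumPy arrays can be modified in place

x_jnp = jnp.zeros(3)
try:
    x_jnp[0] = 1.0  # JAX arrays are immutable, so this raises
except TypeError as err:
    print("TypeError:", err)

# .at[...].set(...) returns an updated copy and leaves x_jnp untouched
y_jnp = x_jnp.at[0].set(1.0)
print(x_jnp, y_jnp)  # [0. 0. 0.] [1. 0. 0.]
```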
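The `print` inside `jnp_function` only fires while JAX traces the function. Here is a minimal sketch that makes the caching visible by recording each trace in a Python list; `trace_log` and `traced_function` are hypothetical names, and the full retracing rules (which also consider dtypes and other input properties) are described in the JAX docs.

```python
import jax
import jax.numpy as jnp

trace_log = []  # plain Python list, appended to only while JAX traces

@jax.jit
def traced_function(x):
    # Python side effects like this run at trace time, not on every call
    trace_log.append(x.shape)
    return 1.5 + jnp.exp(jnp.sin(x))

traced_function(jnp.ones(5))   # first call: traces and compiles
traced_function(jnp.zeros(5))  # same shape and dtype: reuses the cached version
traced_function(jnp.ones(4))   # new shape: traces again
print(trace_log)  # [(5,), (4,)]
```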
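For value-dependent branching under `jax.jit`, `jnp.where` (as in `correct_conditional_func`) evaluates both branches. An alternative is `jax.lax.cond`, which takes a scalar boolean predicate, two branch functions with matching output shapes and dtypes, and the operand(s) to pass in. This sketch reuses the notebook's two branches under the hypothetical name `cond_func`; note that when `lax.cond` is batched with `vmap`, JAX may turn it into a select that evaluates both branches anyway.

```python
import jax
import jax.numpy as jnp

@jax.jit
def cond_func(x):
    return jax.lax.cond(
        jnp.all(x > 0),                       # scalar boolean predicate
        lambda x: x,                          # used when the predicate is True
        lambda x: 1.5 + jnp.exp(jnp.sin(x)),  # used when it is False
        x,                                    # operand passed to the chosen branch
    )

x_pos = jnp.linspace(0.1, 5.0, 5)
print(cond_func(x_pos))   # all positive: returns x_pos unchanged
print(cond_func(-x_pos))  # otherwise: 1.5 + exp(sin(x))
```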
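A sketch checking that `vmap` and manual broadcasting agree for the outer-product example from the `jax.vmap` section, plus the `in_axes` argument, which controls which arguments get batched. `dot_with` and `w` are hypothetical, for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp

A = np.random.default_rng(1).normal(size=(5, 3))

def scalar_function(x):
    return jnp.outer(x, x)

# vmap maps over the leading axis of A by default (in_axes=0)
batched = jax.vmap(scalar_function)(A)
# Manual broadcasting: (5, 3, 1) * (5, 1, 3) -> (5, 3, 3)
broadcast = A[:, :, None] * A[:, None, :]
print(batched.shape, np.allclose(batched, broadcast, atol=1e-5))  # (5, 3, 3) True

def dot_with(x, w):
    return jnp.dot(x, w)

# in_axes=(0, None): map over the rows of A, but share the same w for every row
w = jnp.array([1.0, 0.0, -1.0])
projected = jax.vmap(dot_with, in_axes=(0, None))(A, w)
print(projected.shape)  # (5,)
```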
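When we do want derivatives of a vector-valued function, one option is the full Jacobian via `jax.jacfwd` (or `jax.jacrev`). A sketch, using a copy of the notebook's `jitted_function` body under the hypothetical name `elementwise`: because it acts elementwise, its Jacobian is diagonal, and the diagonal matches `jax.vmap(jax.grad(...))` and the analytic derivative `cos(x) * exp(sin(x))`.

```python
import jax
import jax.numpy as jnp

def elementwise(x):
    return 1.5 + jnp.exp(jnp.sin(x))

x = jnp.linspace(0.1, 5.0, 5)

# Full Jacobian, J[i, j] = d f_i / d x_j, with shape (5, 5)
J = jax.jacfwd(elementwise)(x)
# Per-point derivatives, as in the notebook
per_point = jax.vmap(jax.grad(elementwise))(x)

print(J.shape)                               # (5, 5)
print(jnp.allclose(jnp.diag(J), per_point))  # True
```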
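The function `f` in the higher-derivatives cell is `tanh(x)` written out in terms of `exp`, so the nested `jax.grad` calls can be sanity-checked against the closed forms `tanh' = 1 - tanh^2` and `tanh'' = -2 tanh (1 - tanh^2)`. A minimal sketch; the tolerances in the check are our own choice for `float32`.

```python
import jax
import jax.numpy as jnp

def f(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)  # equal to tanh(x)

dfdx = jax.grad(f)
d2fdx2 = jax.grad(dfdx)  # nesting grad gives the second derivative

x = jnp.linspace(-4.0, 4.0, 200)
t = jnp.tanh(x)

# Largest deviations from the closed-form derivatives of tanh
err1 = jnp.max(jnp.abs(jax.vmap(dfdx)(x) - (1.0 - t**2)))
err2 = jnp.max(jnp.abs(jax.vmap(d2fdx2)(x) - (-2.0 * t * (1.0 - t**2))))
print(err1, err2)
```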
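Since `jax.grad(pytree_func)(params)` returns a dict with the same structure as `params`, `jax.tree_util.tree_map` can combine the two leaf by leaf. A sketch of a single gradient-descent step on the PyTree example; `step_size` is an arbitrary illustrative value.

```python
import jax
import jax.numpy as jnp

def pytree_func(params):
    return jnp.exp(params["log_amp"]) * jnp.sin(params["log_scale"])

params = {"log_amp": -1.5, "log_scale": 0.1}
grads = jax.grad(pytree_func)(params)  # dict with keys "log_amp" and "log_scale"

# Apply p - step_size * g to each matching pair of leaves in params and grads
step_size = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

print(jax.tree_util.tree_structure(new_params) == jax.tree_util.tree_structure(params))  # True
print(new_params)
```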
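A sketch of a common key-handling pattern for the random-numbers section: carry one key forward, consume the subkey immediately, and use `random.split(key, n)` to make a batch of keys for `vmap`. The names `subkey`, `sample`, and `draws` are illustrative.

```python
import jax
from jax import random

key = random.PRNGKey(42)
print(random.normal(key), random.normal(key))  # identical: same key, same draw

# Split into a key to carry forward and a subkey to consume right away
key, subkey = random.split(key)
sample = random.normal(subkey, shape=(3,))

# split can also make many keys at once, e.g. one per vmapped call
keys = random.split(key, 4)
draws = jax.vmap(random.normal)(keys)  # four independent scalar draws
print(sample.shape, draws.shape)  # (3,) (4,)
```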
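To see what the `optax.sgd(0.1)` loop is doing, here is a hand-written sketch of plain gradient descent (no momentum) on the same `loss`, using `jax.tree_util.tree_map` for the update; `sgd_step` is a hypothetical name. For `loss = x**2` the gradient is `2x`, so each step multiplies `x` by `1 - 0.1 * 2 = 0.8`, and after 100 steps `x = 12.5 * 0.8**100`, roughly `2.55e-09`, which is consistent with the optax result in the notebook.

```python
import jax
import jax.numpy as jnp

def loss(params):
    return jnp.sum(jnp.square(params["x"]))

learning_rate = 0.1

@jax.jit
def sgd_step(params):
    grads = jax.grad(loss)(params)
    # Plain gradient descent: p <- p - learning_rate * g, leaf by leaf
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

params = {"x": 12.5}
for _ in range(100):
    params = sgd_step(params)

print(params["x"], 12.5 * 0.8**100)  # both roughly 2.55e-09
```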