Skip to content

Instantly share code, notes, and snippets.

@shoyer
Last active August 29, 2021 23:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/4e0328c277e46f58c47d79b85a51aa0a to your computer and use it in GitHub Desktop.
Save shoyer/4e0328c277e46f58c47d79b85a51aa0a to your computer and use it in GitHub Desktop.
JAX new scan benchmarking.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX new scan benchmarking.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/4e0328c277e46f58c47d79b85a51aa0a/jax-new-scan-benchmarking.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5eWe1pu6K9Em"
},
"source": [
"# Copyright 2021 Google LLC.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"\n",
"from jax import lax, jit\n",
"from functools import partial\n",
"\n",
"import jax.numpy as jnp\n",
"import jax\n",
"import numpy as np\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4cv9jO_rL4Uk",
"outputId": "27cc5f3e-c1f8-411e-8938-c0b70a942d1b"
},
"source": [
"@partial(jit, static_argnames=['unroll'], backend='cpu')\n",
"def polyval(p, x, unroll=64):\n",
"  \"\"\"Evaluate a polynomial at `x` with Horner's rule, compiled for the CPU backend.\n",
"\n",
"  Args:\n",
"    p: coefficients, scanned along their leading axis, highest power first.\n",
"    x: points at which the polynomial is evaluated.\n",
"    unroll: static argument forwarded to `lax.scan`; how many scan iterations\n",
"      are unrolled inside a single loop iteration.\n",
"\n",
"  Returns:\n",
"    An array whose shape is the broadcast of `p.shape[1:]` and `x.shape`, and\n",
"    whose dtype is `jnp.result_type(p, x)`.\n",
"  \"\"\"\n",
"  out_shape = lax.broadcast_shapes(p.shape[1:], x.shape)\n",
"  out_dtype = jnp.result_type(p, x)\n",
"  init = lax.full_like(x, 0, shape=out_shape, dtype=out_dtype)\n",
"\n",
"  def horner_step(acc, coeff):\n",
"    # One Horner update; nothing is stacked as a per-step output.\n",
"    return acc * x + coeff, None\n",
"\n",
"  result, _ = lax.scan(horner_step, init, p, unroll=unroll)\n",
"  return result\n",
"\n",
"\n",
"# Commit the inputs to the CPU device once, up front, so the timings below\n",
"# measure the compiled scan rather than a per-call NumPy -> JAX conversion.\n",
"cpu_device = jax.devices('cpu')[0]\n",
"x = jax.device_put(np.random.rand(100).astype(np.float32), cpu_device)\n",
"p = jax.device_put(np.random.randn(10000).astype(np.float32), cpu_device)\n",
"\n",
"print(\"CPU\")\n",
"for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:\n",
"  print(f\"unroll={unroll}\")\n",
"  # First call with a new static `unroll`: includes tracing and compilation.\n",
"  %time polyval(p, x, unroll).block_until_ready()\n",
"  # Repeated calls: steady-state run time of the already-compiled function.\n",
"  %timeit polyval(p, x, unroll).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"CPU\n",
"unroll=1\n",
"CPU times: user 29.3 ms, sys: 0 ns, total: 29.3 ms\n",
"Wall time: 29.6 ms\n",
"10000 loops, best of 5: 78 µs per loop\n",
"unroll=2\n",
"CPU times: user 35.2 ms, sys: 0 ns, total: 35.2 ms\n",
"Wall time: 35 ms\n",
"10000 loops, best of 5: 45.4 µs per loop\n",
"unroll=4\n",
"CPU times: user 38.3 ms, sys: 0 ns, total: 38.3 ms\n",
"Wall time: 38.5 ms\n",
"10000 loops, best of 5: 35.8 µs per loop\n",
"unroll=8\n",
"CPU times: user 47 ms, sys: 0 ns, total: 47 ms\n",
"Wall time: 47.1 ms\n",
"10000 loops, best of 5: 37.2 µs per loop\n",
"unroll=16\n",
"CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms\n",
"Wall time: 61.4 ms\n",
"10000 loops, best of 5: 46.9 µs per loop\n",
"unroll=32\n",
"CPU times: user 135 ms, sys: 0 ns, total: 135 ms\n",
"Wall time: 135 ms\n",
"1000 loops, best of 5: 358 µs per loop\n",
"unroll=64\n",
"CPU times: user 178 ms, sys: 0 ns, total: 178 ms\n",
"Wall time: 177 ms\n",
"10000 loops, best of 5: 98.7 µs per loop\n",
"unroll=128\n",
"CPU times: user 307 ms, sys: 0 ns, total: 307 ms\n",
"Wall time: 307 ms\n",
"10000 loops, best of 5: 130 µs per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4cqrgQ_VMJqM",
"outputId": "5d8704e2-52ce-4517-bc1f-baf2848e9fe3"
},
"source": [
"@partial(jit, static_argnames=['unroll'], backend='gpu')\n",
"def polyval(p, x, unroll=64):\n",
"  \"\"\"Evaluate a polynomial at `x` with Horner's rule, compiled for the GPU backend.\n",
"\n",
"  Args:\n",
"    p: coefficients, scanned along their leading axis, highest power first.\n",
"    x: points at which the polynomial is evaluated.\n",
"    unroll: static argument forwarded to `lax.scan`; how many scan iterations\n",
"      are unrolled inside a single loop iteration.\n",
"\n",
"  Returns:\n",
"    An array whose shape is the broadcast of `p.shape[1:]` and `x.shape`, and\n",
"    whose dtype is `jnp.result_type(p, x)`.\n",
"  \"\"\"\n",
"  out_shape = lax.broadcast_shapes(p.shape[1:], x.shape)\n",
"  out_dtype = jnp.result_type(p, x)\n",
"  init = lax.full_like(x, 0, shape=out_shape, dtype=out_dtype)\n",
"\n",
"  def horner_step(acc, coeff):\n",
"    # One Horner update; nothing is stacked as a per-step output.\n",
"    return acc * x + coeff, None\n",
"\n",
"  result, _ = lax.scan(horner_step, init, p, unroll=unroll)\n",
"  return result\n",
"\n",
"\n",
"# Draw the inputs on the host with NumPy, then hand them to JAX once via\n",
"# device_put (no explicit device is given).\n",
"x = jax.device_put(np.random.rand(100))\n",
"p = jax.device_put(np.random.randn(10000))\n",
"\n",
"print(\"GPU\")\n",
"for unroll in [2 ** k for k in range(8)]:  # 1, 2, 4, ..., 128\n",
"  print(f\"unroll={unroll}\")\n",
"  # Single timed call first, then the repeated %timeit measurement.\n",
"  %time polyval(p, x, unroll).block_until_ready()\n",
"  %timeit polyval(p, x, unroll).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"GPU\n",
"unroll=1\n",
"CPU times: user 112 ms, sys: 34.1 ms, total: 146 ms\n",
"Wall time: 730 ms\n",
"10 loops, best of 5: 70.6 ms per loop\n",
"unroll=2\n",
"CPU times: user 62.9 ms, sys: 11.2 ms, total: 74.1 ms\n",
"Wall time: 150 ms\n",
"10 loops, best of 5: 35.6 ms per loop\n",
"unroll=4\n",
"CPU times: user 47.7 ms, sys: 13.6 ms, total: 61.3 ms\n",
"Wall time: 122 ms\n",
"100 loops, best of 5: 17.5 ms per loop\n",
"unroll=8\n",
"CPU times: user 42.9 ms, sys: 27.7 ms, total: 70.6 ms\n",
"Wall time: 129 ms\n",
"100 loops, best of 5: 8.86 ms per loop\n",
"unroll=16\n",
"CPU times: user 56 ms, sys: 34.4 ms, total: 90.4 ms\n",
"Wall time: 144 ms\n",
"100 loops, best of 5: 6.54 ms per loop\n",
"unroll=32\n",
"CPU times: user 105 ms, sys: 38.1 ms, total: 143 ms\n",
"Wall time: 214 ms\n",
"100 loops, best of 5: 3.16 ms per loop\n",
"unroll=64\n",
"CPU times: user 162 ms, sys: 32.8 ms, total: 195 ms\n",
"Wall time: 258 ms\n",
"1000 loops, best of 5: 1.83 ms per loop\n",
"unroll=128\n",
"CPU times: user 393 ms, sys: 8.66 ms, total: 402 ms\n",
"Wall time: 501 ms\n",
"1000 loops, best of 5: 861 µs per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ttmy9oH7LJ8D",
"outputId": "7c5fd8a3-516c-413d-f906-8c92713e6d55"
},
"source": [
"@partial(jit, static_argnames=['unroll'])\n",
"def polyval(p, x, unroll=64):\n",
"  \"\"\"Evaluate a polynomial at `x` with Horner's rule on the default jit backend.\n",
"\n",
"  Args:\n",
"    p: coefficients, scanned along their leading axis, highest power first.\n",
"    x: points at which the polynomial is evaluated.\n",
"    unroll: static argument forwarded to `lax.scan`; how many scan iterations\n",
"      are unrolled inside a single loop iteration.\n",
"\n",
"  Returns:\n",
"    An array whose shape is the broadcast of `p.shape[1:]` and `x.shape`, and\n",
"    whose dtype is `jnp.result_type(p, x)`.\n",
"  \"\"\"\n",
"  out_shape = lax.broadcast_shapes(p.shape[1:], x.shape)\n",
"  out_dtype = jnp.result_type(p, x)\n",
"  init = lax.full_like(x, 0, shape=out_shape, dtype=out_dtype)\n",
"\n",
"  def horner_step(acc, coeff):\n",
"    # One Horner update; nothing is stacked as a per-step output.\n",
"    return acc * x + coeff, None\n",
"\n",
"  result, _ = lax.scan(horner_step, init, p, unroll=unroll)\n",
"  return result\n",
"\n",
"\n",
"# Draw the inputs on the host with NumPy, then hand them to JAX once via\n",
"# device_put (no explicit device is given).\n",
"x = jax.device_put(np.random.rand(100))\n",
"p = jax.device_put(np.random.randn(10000))\n",
"\n",
"print(\"TPU\")\n",
"for unroll in [2 ** k for k in range(8)]:  # 1, 2, 4, ..., 128\n",
"  print(f\"unroll={unroll}\")\n",
"  # Single timed call first, then the repeated %timeit measurement.\n",
"  %time polyval(p, x, unroll).block_until_ready()\n",
"  %timeit polyval(p, x, unroll).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"TPU\n",
"unroll=1\n",
"CPU times: user 34.8 ms, sys: 0 ns, total: 34.8 ms\n",
"Wall time: 45.2 ms\n",
"100 loops, best of 5: 13.2 ms per loop\n",
"unroll=2\n",
"CPU times: user 134 ms, sys: 89 µs, total: 134 ms\n",
"Wall time: 107 ms\n",
"100 loops, best of 5: 13.3 ms per loop\n",
"unroll=4\n",
"CPU times: user 118 ms, sys: 0 ns, total: 118 ms\n",
"Wall time: 90.7 ms\n",
"100 loops, best of 5: 9.31 ms per loop\n",
"unroll=8\n",
"CPU times: user 96.1 ms, sys: 2.07 ms, total: 98.2 ms\n",
"Wall time: 94.7 ms\n",
"100 loops, best of 5: 5.48 ms per loop\n",
"unroll=16\n",
"CPU times: user 118 ms, sys: 0 ns, total: 118 ms\n",
"Wall time: 107 ms\n",
"100 loops, best of 5: 4.2 ms per loop\n",
"unroll=32\n",
"CPU times: user 218 ms, sys: 0 ns, total: 218 ms\n",
"Wall time: 181 ms\n",
"100 loops, best of 5: 3.79 ms per loop\n",
"unroll=64\n",
"CPU times: user 325 ms, sys: 2.08 ms, total: 327 ms\n",
"Wall time: 276 ms\n",
"100 loops, best of 5: 3.64 ms per loop\n",
"unroll=128\n",
"CPU times: user 726 ms, sys: 0 ns, total: 726 ms\n",
"Wall time: 631 ms\n",
"100 loops, best of 5: 3.49 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "l53AfRCbLTEd"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment