Last active
August 29, 2021 23:28
-
-
Save shoyer/4e0328c277e46f58c47d79b85a51aa0a to your computer and use it in GitHub Desktop.
JAX new scan benchmarking.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "JAX new scan benchmarking.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shoyer/4e0328c277e46f58c47d79b85a51aa0a/jax-new-scan-benchmarking.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5eWe1pu6K9Em" | |
}, | |
"source": [ | |
"# Copyright 2021 Google LLC.\n", | |
"# SPDX-License-Identifier: Apache-2.0\n", | |
"\n", | |
"from jax import lax, jit\n", | |
"from functools import partial\n", | |
"\n", | |
"import jax.numpy as jnp\n", | |
"import jax\n", | |
"import numpy as np\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4cv9jO_rL4Uk", | |
"outputId": "27cc5f3e-c1f8-411e-8938-c0b70a942d1b" | |
}, | |
"source": [ | |
"@partial(jit, static_argnames=['unroll'], backend='cpu')\n", | |
"def polyval(p, x, unroll=64):\n", | |
" shape = lax.broadcast_shapes(p.shape[1:], x.shape)\n", | |
" dtype = jnp.result_type(p, x)\n", | |
" y = lax.full_like(x, 0, shape=shape, dtype=dtype)\n", | |
" y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)\n", | |
" return y\n", | |
"\n", | |
"\n", | |
"# Benchmark inputs: 100 evaluation points and 10000 coefficients, both float32.\n", | |
"x = np.random.rand(100).astype(np.float32)\n", | |
"p = np.random.randn(10000).astype(np.float32)\n", | |
"\n", | |
"print(\"CPU\")\n", | |
"# `unroll` is a static argument of `polyval`, so each new value presumably\n", | |
"# triggers a fresh trace/compile: `%time` captures that first call, while\n", | |
"# `%timeit` measures the repeated calls that follow.\n", | |
"# `block_until_ready()` waits for the result so the timings cover the\n", | |
"# computation itself, not only the dispatch.\n", | |
"for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:\n", | |
" print(f\"unroll={unroll}\")\n", | |
" %time polyval(p, x, unroll).block_until_ready()\n", | |
" %timeit polyval(p, x, unroll).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"CPU\n", | |
"unroll=1\n", | |
"CPU times: user 29.3 ms, sys: 0 ns, total: 29.3 ms\n", | |
"Wall time: 29.6 ms\n", | |
"10000 loops, best of 5: 78 µs per loop\n", | |
"unroll=2\n", | |
"CPU times: user 35.2 ms, sys: 0 ns, total: 35.2 ms\n", | |
"Wall time: 35 ms\n", | |
"10000 loops, best of 5: 45.4 µs per loop\n", | |
"unroll=4\n", | |
"CPU times: user 38.3 ms, sys: 0 ns, total: 38.3 ms\n", | |
"Wall time: 38.5 ms\n", | |
"10000 loops, best of 5: 35.8 µs per loop\n", | |
"unroll=8\n", | |
"CPU times: user 47 ms, sys: 0 ns, total: 47 ms\n", | |
"Wall time: 47.1 ms\n", | |
"10000 loops, best of 5: 37.2 µs per loop\n", | |
"unroll=16\n", | |
"CPU times: user 61.1 ms, sys: 0 ns, total: 61.1 ms\n", | |
"Wall time: 61.4 ms\n", | |
"10000 loops, best of 5: 46.9 µs per loop\n", | |
"unroll=32\n", | |
"CPU times: user 135 ms, sys: 0 ns, total: 135 ms\n", | |
"Wall time: 135 ms\n", | |
"1000 loops, best of 5: 358 µs per loop\n", | |
"unroll=64\n", | |
"CPU times: user 178 ms, sys: 0 ns, total: 178 ms\n", | |
"Wall time: 177 ms\n", | |
"10000 loops, best of 5: 98.7 µs per loop\n", | |
"unroll=128\n", | |
"CPU times: user 307 ms, sys: 0 ns, total: 307 ms\n", | |
"Wall time: 307 ms\n", | |
"10000 loops, best of 5: 130 µs per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4cqrgQ_VMJqM", | |
"outputId": "5d8704e2-52ce-4517-bc1f-baf2848e9fe3" | |
}, | |
"source": [ | |
"@partial(jit, static_argnames=['unroll'], backend='gpu')\n", | |
"def polyval(p, x, unroll=64):\n", | |
" shape = lax.broadcast_shapes(p.shape[1:], x.shape)\n", | |
" dtype = jnp.result_type(p, x)\n", | |
" y = lax.full_like(x, 0, shape=shape, dtype=dtype)\n", | |
" y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)\n", | |
" return y\n", | |
"\n", | |
"\n", | |
"# Benchmark inputs: 100 evaluation points and 10000 coefficients, transferred\n", | |
"# to the default device up front with `jax.device_put`.\n", | |
"# NOTE(review): NumPy generates these as float64; the on-device dtype depends\n", | |
"# on the `jax_enable_x64` setting (float32 unless enabled) -- confirm.\n", | |
"x = jax.device_put(np.random.rand(100))\n", | |
"p = jax.device_put(np.random.randn(10000))\n", | |
"\n", | |
"print(\"GPU\")\n", | |
"# `unroll` is a static argument of `polyval`, so each new value presumably\n", | |
"# triggers a fresh trace/compile: `%time` captures that first call, while\n", | |
"# `%timeit` measures the repeated calls that follow.\n", | |
"# `block_until_ready()` waits for the result so the timings cover the\n", | |
"# computation itself, not only the dispatch.\n", | |
"for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:\n", | |
" print(f\"unroll={unroll}\")\n", | |
" %time polyval(p, x, unroll).block_until_ready()\n", | |
" %timeit polyval(p, x, unroll).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"GPU\n", | |
"unroll=1\n", | |
"CPU times: user 112 ms, sys: 34.1 ms, total: 146 ms\n", | |
"Wall time: 730 ms\n", | |
"10 loops, best of 5: 70.6 ms per loop\n", | |
"unroll=2\n", | |
"CPU times: user 62.9 ms, sys: 11.2 ms, total: 74.1 ms\n", | |
"Wall time: 150 ms\n", | |
"10 loops, best of 5: 35.6 ms per loop\n", | |
"unroll=4\n", | |
"CPU times: user 47.7 ms, sys: 13.6 ms, total: 61.3 ms\n", | |
"Wall time: 122 ms\n", | |
"100 loops, best of 5: 17.5 ms per loop\n", | |
"unroll=8\n", | |
"CPU times: user 42.9 ms, sys: 27.7 ms, total: 70.6 ms\n", | |
"Wall time: 129 ms\n", | |
"100 loops, best of 5: 8.86 ms per loop\n", | |
"unroll=16\n", | |
"CPU times: user 56 ms, sys: 34.4 ms, total: 90.4 ms\n", | |
"Wall time: 144 ms\n", | |
"100 loops, best of 5: 6.54 ms per loop\n", | |
"unroll=32\n", | |
"CPU times: user 105 ms, sys: 38.1 ms, total: 143 ms\n", | |
"Wall time: 214 ms\n", | |
"100 loops, best of 5: 3.16 ms per loop\n", | |
"unroll=64\n", | |
"CPU times: user 162 ms, sys: 32.8 ms, total: 195 ms\n", | |
"Wall time: 258 ms\n", | |
"1000 loops, best of 5: 1.83 ms per loop\n", | |
"unroll=128\n", | |
"CPU times: user 393 ms, sys: 8.66 ms, total: 402 ms\n", | |
"Wall time: 501 ms\n", | |
"1000 loops, best of 5: 861 µs per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Ttmy9oH7LJ8D", | |
"outputId": "7c5fd8a3-516c-413d-f906-8c92713e6d55" | |
}, | |
"source": [ | |
"@partial(jit, static_argnames=['unroll'])\n", | |
"def polyval(p, x, unroll=64):\n", | |
" shape = lax.broadcast_shapes(p.shape[1:], x.shape)\n", | |
" dtype = jnp.result_type(p, x)\n", | |
" y = lax.full_like(x, 0, shape=shape, dtype=dtype)\n", | |
" y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)\n", | |
" return y\n", | |
"\n", | |
"\n", | |
"# Benchmark inputs: 100 evaluation points and 10000 coefficients, transferred\n", | |
"# to the default device up front with `jax.device_put`.\n", | |
"# NOTE(review): NumPy generates these as float64; the on-device dtype depends\n", | |
"# on the `jax_enable_x64` setting (float32 unless enabled) -- confirm.\n", | |
"x = jax.device_put(np.random.rand(100))\n", | |
"p = jax.device_put(np.random.randn(10000))\n", | |
"\n", | |
"# NOTE(review): `polyval` in this cell is jitted without a `backend` argument,\n", | |
"# so it runs on JAX's default backend -- presumably a TPU runtime when the\n", | |
"# recorded output was produced; confirm before comparing numbers.\n", | |
"print(\"TPU\")\n", | |
"# `unroll` is a static argument of `polyval`, so each new value presumably\n", | |
"# triggers a fresh trace/compile: `%time` captures that first call, while\n", | |
"# `%timeit` measures the repeated calls that follow.\n", | |
"# `block_until_ready()` waits for the result so the timings cover the\n", | |
"# computation itself, not only the dispatch.\n", | |
"for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:\n", | |
" print(f\"unroll={unroll}\")\n", | |
" %time polyval(p, x, unroll).block_until_ready()\n", | |
" %timeit polyval(p, x, unroll).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"TPU\n", | |
"unroll=1\n", | |
"CPU times: user 34.8 ms, sys: 0 ns, total: 34.8 ms\n", | |
"Wall time: 45.2 ms\n", | |
"100 loops, best of 5: 13.2 ms per loop\n", | |
"unroll=2\n", | |
"CPU times: user 134 ms, sys: 89 µs, total: 134 ms\n", | |
"Wall time: 107 ms\n", | |
"100 loops, best of 5: 13.3 ms per loop\n", | |
"unroll=4\n", | |
"CPU times: user 118 ms, sys: 0 ns, total: 118 ms\n", | |
"Wall time: 90.7 ms\n", | |
"100 loops, best of 5: 9.31 ms per loop\n", | |
"unroll=8\n", | |
"CPU times: user 96.1 ms, sys: 2.07 ms, total: 98.2 ms\n", | |
"Wall time: 94.7 ms\n", | |
"100 loops, best of 5: 5.48 ms per loop\n", | |
"unroll=16\n", | |
"CPU times: user 118 ms, sys: 0 ns, total: 118 ms\n", | |
"Wall time: 107 ms\n", | |
"100 loops, best of 5: 4.2 ms per loop\n", | |
"unroll=32\n", | |
"CPU times: user 218 ms, sys: 0 ns, total: 218 ms\n", | |
"Wall time: 181 ms\n", | |
"100 loops, best of 5: 3.79 ms per loop\n", | |
"unroll=64\n", | |
"CPU times: user 325 ms, sys: 2.08 ms, total: 327 ms\n", | |
"Wall time: 276 ms\n", | |
"100 loops, best of 5: 3.64 ms per loop\n", | |
"unroll=128\n", | |
"CPU times: user 726 ms, sys: 0 ns, total: 726 ms\n", | |
"Wall time: 631 ms\n", | |
"100 loops, best of 5: 3.49 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "l53AfRCbLTEd" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment