Last active
June 16, 2020 06:42
-
-
Save shoyer/00e79a120968701fc716f6a8584b4ba0 to your computer and use it in GitHub Desktop.
numba vs jax filter.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "numba vs jax filter.ipynb", | |
"provenance": [], | |
"collapsed_sections": [ | |
"vEGPZj-Rptta" | |
], | |
"authorship_tag": "ABX9TyNCS+m8d916yK9sD9g46eSa", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shoyer/00e79a120968701fc716f6a8584b4ba0/numba-vs-jax-filter.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QD_RulrVP5An", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"! pip install -U jax jaxlib" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JZN6ELqKJhns", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from jax.experimental import loops\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"import numpy as np\n", | |
"import numba\n", | |
"\n", | |
"@jax.jit\n", | |
"def smooth_image_jax(x, n):\n", | |
" with loops.Scope() as s:\n", | |
" s.x = x\n", | |
" s.y = jnp.zeros_like(x)\n", | |
" s.k = 0\n", | |
" for _ in s.while_range(lambda: s.k < n):\n", | |
" s.k += 1\n", | |
" k, m = x.shape\n", | |
" for i in s.range(k):\n", | |
" for j in s.range(m):\n", | |
" new_value = 0.25 * (s.x[i - 1, j]\n", | |
" + s.x[(i + 1) % k, j]\n", | |
" + s.x[i, j - 1]\n", | |
" + s.x[i, (j + 1) % m])\n", | |
" s.y = s.y.at[i, j].set(new_value)\n", | |
" s.x = s.y\n", | |
" return s.y\n", | |
"\n", | |
"@numba.jit\n", | |
"def _smooth_image_numba(x, n):\n", | |
" y = np.zeros_like(x)\n", | |
" for _ in range(n):\n", | |
" k, m = x.shape\n", | |
" for i in range(k):\n", | |
" for j in range(m):\n", | |
" y[i, j] = 0.25 * (x[i - 1, j]\n", | |
" + x[(i + 1) % k, j]\n", | |
" + x[i, j - 1]\n", | |
" + x[i, (j + 1) % m])\n", | |
" x[:] = y\n", | |
" return y\n", | |
"\n", | |
"def smooth_image_numba(x, n):\n", | |
" return _smooth_image_numba(x.copy(), n)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "vEGPZj-Rptta", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Unit test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "sx_bI0OfN9Oc", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"x = np.arange(25.0).reshape(5, 5)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gDIjwkX_NfGc", | |
"colab_type": "code", | |
"outputId": "567dde17-e44f-4c47-b909-71533ca68724", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 104 | |
} | |
}, | |
"source": [ | |
"x" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[ 0., 1., 2., 3., 4.],\n", | |
" [ 5., 6., 7., 8., 9.],\n", | |
" [10., 11., 12., 13., 14.],\n", | |
" [15., 16., 17., 18., 19.],\n", | |
" [20., 21., 22., 23., 24.]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 41 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "u1NpbLfRpijT", | |
"colab_type": "code", | |
"outputId": "2c0c2289-d3d2-4205-aeb7-1b68db5d79fe", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 104 | |
} | |
}, | |
"source": [ | |
"smooth_image_numba(x, 2)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[ 9.375 , 9.125 , 9.8125, 10.5 , 10.25 ],\n", | |
" [ 8.125 , 7.875 , 8.5625, 9.25 , 9. ],\n", | |
" [11.5625, 11.3125, 12. , 12.6875, 12.4375],\n", | |
" [15. , 14.75 , 15.4375, 16.125 , 15.875 ],\n", | |
" [13.75 , 13.5 , 14.1875, 14.875 , 14.625 ]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 45 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "oLFNQ5_yoQjo", | |
"colab_type": "code", | |
"outputId": "689df148-8181-4063-fbd7-6b177924f63a", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 104 | |
} | |
}, | |
"source": [ | |
"smooth_image_numba(x, 2)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[ 9.375 , 9.125 , 9.8125, 10.5 , 10.25 ],\n", | |
" [ 8.125 , 7.875 , 8.5625, 9.25 , 9. ],\n", | |
" [11.5625, 11.3125, 12. , 12.6875, 12.4375],\n", | |
" [15. , 14.75 , 15.4375, 16.125 , 15.875 ],\n", | |
" [13.75 , 13.5 , 14.1875, 14.875 , 14.625 ]])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 46 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "iWC6zOPApkEI", | |
"colab_type": "code", | |
"outputId": "b958db62-2496-4cc1-e5d2-975e58abfb1c", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 104 | |
} | |
}, | |
"source": [ | |
"smooth_image_jax(x, 1)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[ 7.5 , 7.25, 8.25, 9.25, 9. ],\n", | |
" [ 6.25, 6. , 7. , 8. , 7.75],\n", | |
" [11.25, 11. , 12. , 13. , 12.75],\n", | |
" [16.25, 16. , 17. , 18. , 17.75],\n", | |
" [15. , 14.75, 15.75, 16.75, 16.5 ]], dtype=float32)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 43 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "MziafifSoSGI", | |
"colab_type": "code", | |
"outputId": "613dffb6-dd27-4ea0-9dfd-f2a5ae35fb98", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 104 | |
} | |
}, | |
"source": [ | |
"smooth_image_jax(x, 1)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[ 7.5 , 7.25, 8.25, 9.25, 10. ],\n", | |
" [ 6.25, 6. , 7. , 8. , 8.75],\n", | |
" [11.25, 11. , 12. , 13. , 13.75],\n", | |
" [16.25, 16. , 17. , 18. , 18.75],\n", | |
" [20. , 19.75, 20.75, 21.75, 22.5 ]], dtype=float32)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 37 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "nkPuBoKSpvk5", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Performance test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WSTpnF8epywR", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"x = np.arange(256.0 ** 2).reshape(256, 256)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "U6gTcqOwMzt0", | |
"colab_type": "code", | |
"outputId": "9e717233-8124-4a66-c975-4fba70dba732", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 294 | |
} | |
}, | |
"source": [ | |
"%time print(smooth_image_numba(x, 64))\n", | |
"%timeit smooth_image_numba(x, 64)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[[30580.57911198 30563.8374853 30547.90229619 ... 30630.27476735\n", | |
" 30614.33957825 30597.59795157]\n", | |
" [26294.72268224 26277.98105556 26262.04586645 ... 26344.41833761\n", | |
" 26328.48314851 26311.74152183]\n", | |
" [22215.31427203 22198.57264535 22182.63745625 ... 22265.00992741\n", | |
" 22249.0747383 22232.33311163]\n", | |
" ...\n", | |
" [43302.66688837 43285.9252617 43269.99007259 ... 43352.36254375\n", | |
" 43336.42735465 43319.68572797]\n", | |
" [39223.25847817 39206.51685149 39190.58166239 ... 39272.95413355\n", | |
" 39257.01894444 39240.27731776]\n", | |
" [34937.40204843 34920.66042175 34904.72523265 ... 34987.09770381\n", | |
" 34971.1625147 34954.42088802]]\n", | |
"CPU times: user 442 ms, sys: 10.7 ms, total: 453 ms\n", | |
"Wall time: 454 ms\n", | |
"10 loops, best of 3: 31 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gCIgoVJkoHsq", | |
"colab_type": "code", | |
"outputId": "0f257a79-f40e-42a7-cc43-1cfb673afe1f", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 225 | |
} | |
}, | |
"source": [ | |
"%time print(smooth_image_jax(x, 64).block_until_ready())\n", | |
"%timeit smooth_image_jax(x, 64).block_until_ready()" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU.\n", | |
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[[30580.578 30563.84 30547.902 ... 30630.277 30614.34 30597.602]\n", | |
" [26294.727 26277.98 26262.047 ... 26344.418 26328.488 26311.74 ]\n", | |
" [22215.312 22198.574 22182.637 ... 22265.014 22249.072 22232.338]\n", | |
" ...\n", | |
" [43302.67 43285.926 43269.992 ... 43352.363 43336.43 43319.688]\n", | |
" [39223.258 39206.52 39190.582 ... 39272.957 39257.023 39240.28 ]\n", | |
" [34937.406 34920.66 34904.727 ... 34987.098 34971.164 34954.42 ]]\n", | |
"CPU times: user 814 ms, sys: 37.1 ms, total: 851 ms\n", | |
"Wall time: 844 ms\n", | |
"10 loops, best of 3: 59.3 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VNPX2iXPKmK_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment