Skip to content

Instantly share code, notes, and snippets.

@shoyer
Last active June 16, 2020 06:42
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/00e79a120968701fc716f6a8584b4ba0 to your computer and use it in GitHub Desktop.
Save shoyer/00e79a120968701fc716f6a8584b4ba0 to your computer and use it in GitHub Desktop.
numba vs jax filter.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "numba vs jax filter.ipynb",
"provenance": [],
"collapsed_sections": [
"vEGPZj-Rptta"
],
"authorship_tag": "ABX9TyNCS+m8d916yK9sD9g46eSa",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/00e79a120968701fc716f6a8584b4ba0/numba-vs-jax-filter.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "QD_RulrVP5An",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install -U jax jaxlib"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JZN6ELqKJhns",
"colab_type": "code",
"colab": {}
},
"source": [
"from jax.experimental import loops\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import numba\n",
"\n",
"@jax.jit\n",
"def smooth_image_jax(x, n):\n",
" with loops.Scope() as s:\n",
" s.x = x\n",
" s.y = jnp.zeros_like(x)\n",
" s.k = 0\n",
" for _ in s.while_range(lambda: s.k < n):\n",
" s.k += 1\n",
" k, m = x.shape\n",
" for i in s.range(k):\n",
" for j in s.range(m):\n",
" new_value = 0.25 * (s.x[i - 1, j]\n",
" + s.x[(i + 1) % k, j]\n",
" + s.x[i, j - 1]\n",
" + s.x[i, (j + 1) % m])\n",
" s.y = s.y.at[i, j].set(new_value)\n",
" s.x = s.y\n",
" return s.y\n",
"\n",
"@numba.jit\n",
"def _smooth_image_numba(x, n):\n",
" y = np.zeros_like(x)\n",
" for _ in range(n):\n",
" k, m = x.shape\n",
" for i in range(k):\n",
" for j in range(m):\n",
" y[i, j] = 0.25 * (x[i - 1, j]\n",
" + x[(i + 1) % k, j]\n",
" + x[i, j - 1]\n",
" + x[i, (j + 1) % m])\n",
" x[:] = y\n",
" return y\n",
"\n",
"def smooth_image_numba(x, n):\n",
" return _smooth_image_numba(x.copy(), n)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "vEGPZj-Rptta",
"colab_type": "text"
},
"source": [
"## Unit test"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sx_bI0OfN9Oc",
"colab_type": "code",
"colab": {}
},
"source": [
"x = np.arange(25.0).reshape(5, 5)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gDIjwkX_NfGc",
"colab_type": "code",
"outputId": "567dde17-e44f-4c47-b909-71533ca68724",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104
}
},
"source": [
"x"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 0., 1., 2., 3., 4.],\n",
" [ 5., 6., 7., 8., 9.],\n",
" [10., 11., 12., 13., 14.],\n",
" [15., 16., 17., 18., 19.],\n",
" [20., 21., 22., 23., 24.]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 41
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "u1NpbLfRpijT",
"colab_type": "code",
"outputId": "2c0c2289-d3d2-4205-aeb7-1b68db5d79fe",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104
}
},
"source": [
"smooth_image_numba(x, 2)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 9.375 , 9.125 , 9.8125, 10.5 , 10.25 ],\n",
" [ 8.125 , 7.875 , 8.5625, 9.25 , 9. ],\n",
" [11.5625, 11.3125, 12. , 12.6875, 12.4375],\n",
" [15. , 14.75 , 15.4375, 16.125 , 15.875 ],\n",
" [13.75 , 13.5 , 14.1875, 14.875 , 14.625 ]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 45
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "oLFNQ5_yoQjo",
"colab_type": "code",
"outputId": "689df148-8181-4063-fbd7-6b177924f63a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104
}
},
"source": [
"smooth_image_numba(x, 2)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 9.375 , 9.125 , 9.8125, 10.5 , 10.25 ],\n",
" [ 8.125 , 7.875 , 8.5625, 9.25 , 9. ],\n",
" [11.5625, 11.3125, 12. , 12.6875, 12.4375],\n",
" [15. , 14.75 , 15.4375, 16.125 , 15.875 ],\n",
" [13.75 , 13.5 , 14.1875, 14.875 , 14.625 ]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 46
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "iWC6zOPApkEI",
"colab_type": "code",
"outputId": "b958db62-2496-4cc1-e5d2-975e58abfb1c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104
}
},
"source": [
"smooth_image_jax(x, 1)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[ 7.5 , 7.25, 8.25, 9.25, 9. ],\n",
" [ 6.25, 6. , 7. , 8. , 7.75],\n",
" [11.25, 11. , 12. , 13. , 12.75],\n",
" [16.25, 16. , 17. , 18. , 17.75],\n",
" [15. , 14.75, 15.75, 16.75, 16.5 ]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 43
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "MziafifSoSGI",
"colab_type": "code",
"outputId": "613dffb6-dd27-4ea0-9dfd-f2a5ae35fb98",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104
}
},
"source": [
"smooth_image_jax(x, 1)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[ 7.5 , 7.25, 8.25, 9.25, 10. ],\n",
" [ 6.25, 6. , 7. , 8. , 8.75],\n",
" [11.25, 11. , 12. , 13. , 13.75],\n",
" [16.25, 16. , 17. , 18. , 18.75],\n",
" [20. , 19.75, 20.75, 21.75, 22.5 ]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 37
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nkPuBoKSpvk5",
"colab_type": "text"
},
"source": [
"## Performance test"
]
},
{
"cell_type": "code",
"metadata": {
"id": "WSTpnF8epywR",
"colab_type": "code",
"colab": {}
},
"source": [
"x = np.arange(256.0 ** 2).reshape(256, 256)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "U6gTcqOwMzt0",
"colab_type": "code",
"outputId": "9e717233-8124-4a66-c975-4fba70dba732",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 294
}
},
"source": [
"%time print(smooth_image_numba(x, 64))\n",
"%timeit smooth_image_numba(x, 64)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"[[30580.57911198 30563.8374853 30547.90229619 ... 30630.27476735\n",
" 30614.33957825 30597.59795157]\n",
" [26294.72268224 26277.98105556 26262.04586645 ... 26344.41833761\n",
" 26328.48314851 26311.74152183]\n",
" [22215.31427203 22198.57264535 22182.63745625 ... 22265.00992741\n",
" 22249.0747383 22232.33311163]\n",
" ...\n",
" [43302.66688837 43285.9252617 43269.99007259 ... 43352.36254375\n",
" 43336.42735465 43319.68572797]\n",
" [39223.25847817 39206.51685149 39190.58166239 ... 39272.95413355\n",
" 39257.01894444 39240.27731776]\n",
" [34937.40204843 34920.66042175 34904.72523265 ... 34987.09770381\n",
" 34971.1625147 34954.42088802]]\n",
"CPU times: user 442 ms, sys: 10.7 ms, total: 453 ms\n",
"Wall time: 454 ms\n",
"10 loops, best of 3: 31 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gCIgoVJkoHsq",
"colab_type": "code",
"outputId": "0f257a79-f40e-42a7-cc43-1cfb673afe1f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 225
}
},
"source": [
"%time print(smooth_image_jax(x, 64).block_until_ready())\n",
"%timeit smooth_image_jax(x, 64).block_until_ready()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"[[30580.578 30563.84 30547.902 ... 30630.277 30614.34 30597.602]\n",
" [26294.727 26277.98 26262.047 ... 26344.418 26328.488 26311.74 ]\n",
" [22215.312 22198.574 22182.637 ... 22265.014 22249.072 22232.338]\n",
" ...\n",
" [43302.67 43285.926 43269.992 ... 43352.363 43336.43 43319.688]\n",
" [39223.258 39206.52 39190.582 ... 39272.957 39257.023 39240.28 ]\n",
" [34937.406 34920.66 34904.727 ... 34987.098 34971.164 34954.42 ]]\n",
"CPU times: user 814 ms, sys: 37.1 ms, total: 851 ms\n",
"Wall time: 844 ms\n",
"10 loops, best of 3: 59.3 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VNPX2iXPKmK_",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment