Skip to content

Instantly share code, notes, and snippets.

@cemlyn007
Created October 31, 2022 15:08
Show Gist options
  • Save cemlyn007/78873b2d37647d0b64fcec599d6b8aad to your computer and use it in GitHub Desktop.
Save cemlyn007/78873b2d37647d0b64fcec599d6b8aad to your computer and use it in GitHub Desktop.
JitRlaxLambdaReturnsBenchmark.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMvg/5R6yDGp3kaFijZN9aM",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/cemlyn007/78873b2d37647d0b64fcec599d6b8aad/jitrlaxlambdareturnsbenchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"# Install/upgrade JAX and RLax into this kernel's environment (%pip, unlike !pip, targets the running kernel).\n",
"# Versions are not pinned; the benchmark cell prints the JAX version it actually ran with.\n",
"%pip install -q -U pip jax\n",
"%pip install -q rlax"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mqC6Zd9FV03w",
"outputId": "f226da0e-61e7-4336-b9a5-6b48d1ca9688"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JLJjYbYwVleB",
"outputId": "ed778f68-0026-4e42-e67e-dfed9061773e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"JAX Version: 0.3.23\n",
"Size, Time taken to execute 100 times\n",
"2 0.004910496920001606 0.061564857000121265 5.2999499985162405e-06\n",
"4 0.007591661210003622 0.08067618299992318 6.602229996133247e-06\n",
"8 0.013084104059998936 0.13883926699963922 8.437299998149683e-06\n",
"16 0.023296493190000548 0.2841104499998437 5.7832109996525104e-05\n",
"32 0.043934581620001155 0.49321855699963635 5.813510001644317e-06\n",
"64 0.08431943933000184 1.075515901000017 6.881499998598883e-06\n",
"128 0.16692044445000193 2.5870816419997027 7.299479998437164e-06\n",
"256 0.3538112200400019 8.227743565999845 1.0447419999763951e-05\n",
"512 0.6520403084700002 40.382920166000076 1.3444420001178513e-05\n",
"1024 1.4777235191299996 269.793594882 2.5668820003375005e-05\n"
]
}
],
"source": [
"# Benchmark rlax.lambda_returns called eagerly vs. under jax.jit, for sequence lengths 2**1 .. 2**10.\n",
"# Each printed row is: size, eager mean per call (s), first JIT call on a cleared cache (s),\n",
"# JIT mean per call once compiled (s).\n",
"import functools\n",
"import gc\n",
"import timeit\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import rlax\n",
"\n",
"print(f'JAX Version: {jax.__version__}', flush=True)\n",
"\n",
"number = 100  # timed calls per measurement; reported eager/JIT-exec times are per-call means\n",
"lambda_returns = functools.partial(rlax.lambda_returns, stop_target_gradients=True)\n",
"jitted_lambda_returns = jax.jit(lambda_returns)\n",
"\n",
"\n",
"def time_per_call(fn, arr, calls, setup='pass'):\n",
"    \"\"\"Return the mean wall time (s) of `calls` blocking calls fn(arr, arr, arr, 1.).\n",
"\n",
"    `arr` is passed as r_t, discount_t and v_t. `setup` runs once before the timed loop and is not timed.\n",
"    \"\"\"\n",
"    env = {'fn': fn, 'r_t': arr, 'discount_t': arr, 'v_t': arr}\n",
"    # block_until_ready() waits for JAX's asynchronous dispatch so the timer covers the actual work.\n",
"    stmt = 'fn(r_t, discount_t, v_t, 1.).block_until_ready()'\n",
"    return timeit.timeit(stmt, setup=setup, number=calls, globals=env) / calls\n",
"\n",
"\n",
"sizes = []\n",
"mean_time_takens = []\n",
"jit_first_time_takens = []\n",
"mean_jit_exec_time_takens = []\n",
"print('Size, eager mean per call (s), JIT first call incl. compile (s), JIT mean per call (s)')\n",
"for n in range(1, 11):\n",
"    gc.collect()\n",
"    size = 2 ** n\n",
"    arr = jnp.zeros((size,), float)\n",
"    # Warm-up call so one-off first-call costs are not folded into the eager mean.\n",
"    lambda_returns(arr, arr, arr, 1.).block_until_ready()\n",
"    mean_time_taken = time_per_call(lambda_returns, arr, number)\n",
"    # Clear the JIT cache in timeit's untimed setup, so the single timed call includes compilation\n",
"    # but not the cost of clearing the cache itself.\n",
"    jit_first_time_taken = time_per_call(jitted_lambda_returns, arr, 1, setup='fn.clear_cache()')\n",
"    # The function was just compiled for this shape, so these calls measure cached execution.\n",
"    mean_jit_exec_time_taken = time_per_call(jitted_lambda_returns, arr, number)\n",
"    print(str(size).ljust(6), mean_time_taken, jit_first_time_taken, mean_jit_exec_time_taken, flush=True)\n",
"    sizes.append(size)\n",
"    mean_time_takens.append(mean_time_taken)\n",
"    jit_first_time_takens.append(jit_first_time_taken)\n",
"    mean_jit_exec_time_takens.append(mean_jit_exec_time_taken)\n"
]
},
{
"cell_type": "code",
"source": [
"# Left: per-call means from the benchmark cell. Right: a single JIT call on a cleared cache (compile + execute).\n",
"# Sizes are powers of two and times span several orders of magnitude, so both axes are log-scaled.\n",
"fig, (ax_mean, ax_first) = plt.subplots(1, 2, figsize=(12, 6))\n",
"ax_mean.plot(sizes, mean_time_takens, label='Mean Normal', color='blue')\n",
"ax_mean.plot(sizes, mean_jit_exec_time_takens, label='Mean JIT Exec', color='red')\n",
"ax_mean.set_xscale('log')\n",
"ax_mean.set_yscale('log')\n",
"ax_mean.set(title='Mean time per call', xlabel='Size', ylabel=f'Mean time per call over {number} calls (s)')\n",
"ax_mean.legend()\n",
"ax_first.plot(sizes, jit_first_time_takens, label='JIT Compile & Exec', color='red')\n",
"ax_first.set_xscale('log')\n",
"ax_first.set_yscale('log')\n",
"ax_first.set(title='First JIT call (compile + execute)', xlabel='Size', ylabel='Time for one call (s)')\n",
"ax_first.legend()\n",
"fig.tight_layout()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 441
},
"id": "YLScDcrfV2Bp",
"outputId": "c0d742dd-6f69-45b2-8e33-c891ac0b88ea"
},
"execution_count": 30,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 864x432 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
}
]
}
@cemlyn007
Copy link
Author

I should have also plotted the memory usage.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment