@shoyer
Created January 16, 2020 01:06
jax odeint benchmark.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "jax odeint benchmark.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMMZtCAuWMpqpLEjbP0HboQ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/a44120aa308cb74e87797e8050df7a6b/jax-odeint-benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "FXob83P53SRD",
"colab_type": "code",
"colab": {}
},
"source": [
"! pip install -U -q jax jaxlib"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pmqk-rOT3K8g",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "e5817734-02e9-4e23-afcb-6f9ce020b14b"
},
"source": [
"import jax.numpy as jnp\n",
"from jax.experimental.ode import odeint\n",
"\n",
"def f(u, t, sigma, rho, beta):\n",
"    # Lorenz system right-hand side: du/dt = f(u, t).\n",
"    x, y, z = u\n",
"    return jnp.array([sigma * (y - x),\n",
"                      x * (rho - z) - y,\n",
"                      x * y - beta * z])\n",
"\n",
"u0 = jnp.array([1.0, 0.0, 0.0])\n",
"t = jnp.linspace(0, 100, 1001)  # 1001 output times over [0, 100]\n",
"# The first call also compiles the solver; later calls reuse the compiled code.\n",
"sol = odeint(f, u0, t, 10.0, 28.0, 8/3, rtol=1e-8, atol=1e-8)\n",
"\n",
"%timeit odeint(f, u0, t, 10.0, 28.0, 8/3, rtol=1e-8, atol=1e-8).block_until_ready()"
],
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"text": [
"100 loops, best of 3: 3.66 ms per loop\n"
],
"name": "stdout"
}
]
}
]
}
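For running the same benchmark outside Colab, where the IPython `%timeit` magic is unavailable, here is a minimal standalone sketch. It assumes a recent JAX release in which `jax.experimental.ode.odeint` still takes extra positional arguments after `t` followed by `rtol`/`atol` keywords. The wrapper name `solve` and the choice of 10 timed repetitions are illustrative assumptions, not part of the original notebook. The sketch times the first (compiling) call separately from steady-state calls, and uses `block_until_ready()` so that JAX's asynchronous dispatch does not hide the actual compute time.

```python
import timeit

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def lorenz(u, t, sigma, rho, beta):
    """Lorenz system right-hand side, du/dt = f(u, t)."""
    x, y, z = u
    return jnp.array([sigma * (y - x),
                      x * (rho - z) - y,
                      x * y - beta * z])


u0 = jnp.array([1.0, 0.0, 0.0])
t = jnp.linspace(0.0, 100.0, 1001)  # 1001 output times over [0, 100]


@jax.jit
def solve(u0, t):
    # Hypothetical wrapper: same parameters and tolerances as the notebook cell.
    return odeint(lorenz, u0, t, 10.0, 28.0, 8.0 / 3.0, rtol=1e-8, atol=1e-8)


# First call: includes tracing and compilation, so it is timed on its own.
start = timeit.default_timer()
sol = solve(u0, t).block_until_ready()
first_call = timeit.default_timer() - start

# Steady-state calls: reuse the compiled function; wait for results each time.
n = 10  # assumed repetition count
per_call = timeit.timeit(lambda: solve(u0, t).block_until_ready(), number=n) / n

print(f"first call (incl. compile): {first_call * 1e3:.1f} ms")
print(f"per call after compile:     {per_call * 1e3:.2f} ms")
```

The per-call figure is the one comparable to the notebook's `%timeit` output; absolute numbers will vary with hardware and JAX version.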