Skip to content

Instantly share code, notes, and snippets.

@shoyer
Last active December 11, 2021 16:43
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/9c6593ef6f65ddfcb394e96b90f87a72 to your computer and use it in GitHub Desktop.
Save shoyer/9c6593ef6f65ddfcb394e96b90f87a72 to your computer and use it in GitHub Desktop.
JAX harmonic oscillator odeint.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX harmonic oscillator odeint.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMxr/n/Hphj+rk7PVD7eCj9",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/9c6593ef6f65ddfcb394e96b90f87a72/jax-harmonic-oscillator-odeint.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wvDXYIeEB-js",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "92dcc90f-3f76-4c63-9e0d-4c5bcc7bb0ec"
},
"source": [
"! pip install -U -q git+https://github.com/google/jax.git"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
" Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Z594yUSsCCw4",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2020 Google LLC.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"\n",
"from jax.experimental.ode import odeint\n",
"from jax.experimental import host_callback\n",
"import jax.numpy as jnp\n",
"import jax\n",
"from functools import partial, wraps\n",
"from collections import namedtuple\n",
"import time\n",
"\n",
"CallRecord = namedtuple('CallRecord', ['args', 'kwargs', 'result', 'transforms'])\n",
"\n",
"def recorded(fun):\n",
"  \"\"\"Wrap `fun` so that each of its evaluations is logged on the host.\n",
"\n",
"  Returns:\n",
"    A pair `(wrapper, records)`. `wrapper` computes `fun(*args, **kwargs)` and\n",
"    sends `(args, kwargs, result)` to the host with `host_callback.id_tap`;\n",
"    every tap appends one `CallRecord` to the list `records`.\n",
"\n",
"  NOTE(review): taps may arrive asynchronously, which is presumably why callers\n",
"  run `host_callback.barrier_wait()` before reading `records`. `host_callback`\n",
"  is experimental and may be missing from newer JAX releases -- confirm.\n",
"  \"\"\"\n",
"  call_log = []\n",
"\n",
"  def _on_host(payload, transforms=None):\n",
"    call_args, call_kwargs, call_result = payload\n",
"    call_log.append(CallRecord(call_args, call_kwargs, call_result, transforms))\n",
"\n",
"  @wraps(fun)\n",
"  def recording_fun(*args, **kwargs):\n",
"    output = fun(*args, **kwargs)\n",
"    # With `result=output`, id_tap returns `output` itself, so the tap is tied\n",
"    # to the value the caller actually consumes.\n",
"    return host_callback.id_tap(_on_host, (args, kwargs, output), result=output)\n",
"\n",
"  return recording_fun, call_log\n",
"\n",
"def oscillator(state, t, k=1.0, c=1.0, m=1.0):\n",
"  \"\"\"Time derivative of a damped harmonic oscillator, in odeint's `func(y, t, *args)` form.\n",
"\n",
"  Models m x'' + c x' + k x = 0, see\n",
"  https://en.wikipedia.org/wiki/Harmonic_oscillator#Damped_harmonic_oscillator\n",
"\n",
"  Args:\n",
"    state: pair `(x, x_t)` of position and velocity.\n",
"    t: time; unused (the right-hand side has no explicit time dependence) but\n",
"      required by odeint's calling convention.\n",
"    k: spring constant.\n",
"    c: damping coefficient.\n",
"    m: mass.\n",
"\n",
"  Returns:\n",
"    The pair `(x_t, x_tt)`: velocity and acceleration.\n",
"  \"\"\"\n",
"  position, velocity = state\n",
"  # Same operation order as -k/m * x - c/m * x_t, so results are bit-identical.\n",
"  spring = -k / m * position\n",
"  damping = c / m * velocity\n",
"  return velocity, spring - damping\n",
"\n",
"def ode_loss(fun, init_state, t_max, *params):\n",
"  \"\"\"Integrate `fun` from t=0 to `t_max` and return the first state component at `t_max`.\n",
"\n",
"  Args:\n",
"    fun: ODE right-hand side, called by odeint as `fun(state, t, *params)`.\n",
"    init_state: initial state with two leaves, e.g. `(x, x_t)` for `oscillator`.\n",
"      NOTE(review): a length-2 *array* state would unpack along the time axis\n",
"      instead of the state axis -- pass a tuple.\n",
"    t_max: end of the integration interval.\n",
"    *params: extra arguments forwarded to `fun`.\n",
"\n",
"  Returns:\n",
"    The first state component (the position, for `oscillator`) at `t_max`.\n",
"  \"\"\"\n",
"  # Only the solution at t=0 and t=t_max is requested.\n",
"  times = jnp.linspace(0, t_max, num=2)\n",
"  first_component, _ = odeint(fun, init_state, times, *params)\n",
"  return first_component[-1]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QEF_ei9SD3_h",
"colab_type": "code",
"colab": {}
},
"source": [
"def _run_recorded(make_fn, args):\n",
"  \"\"\"Run `make_fn(f)(*args)`, where `f` is a freshly recorded `oscillator`.\n",
"\n",
"  Returns:\n",
"    `(result, num_evaluations)`: the computed value and the number of\n",
"    `oscillator` calls recorded while computing it.\n",
"  \"\"\"\n",
"  f, records = recorded(oscillator)\n",
"  result = make_fn(f)(*args)\n",
"  # Host taps may arrive asynchronously; wait for all of them before counting.\n",
"  host_callback.barrier_wait()\n",
"  return result, len(records)\n",
"\n",
"\n",
"def benchmark(\n",
"    k=1.0,\n",
"    c=1.0,\n",
"    m=1.0,\n",
"    t_max=1.0,\n",
"    init_state=(1.0, 0.0),\n",
"):\n",
"  \"\"\"Compare ODE right-hand-side evaluations in a forward solve vs. its gradient.\n",
"\n",
"  Solves the damped oscillator with `ode_loss` under `jax.jit`, then takes\n",
"  `jax.grad` with respect to all of `(init_state, t_max, k, c, m)`. Prints both\n",
"  results, the number of `oscillator` evaluations in each run, and the\n",
"  gradient/forward evaluation ratio.\n",
"  \"\"\"\n",
"  args = (init_state, t_max, k, c, m)\n",
"\n",
"  def make_forward(f):\n",
"    return jax.jit(partial(ode_loss, f))\n",
"\n",
"  def make_gradient(f):\n",
"    return jax.jit(jax.grad(partial(ode_loss, f), argnums=(0, 1, 2, 3, 4)))\n",
"\n",
"  forward, num_forward = _run_recorded(make_forward, args)\n",
"  print('forward result: ', forward)\n",
"  print('forward evaluations: ', num_forward)\n",
"\n",
"  gradient, num_gradient = _run_recorded(make_gradient, args)\n",
"  print('gradient result: ', gradient)\n",
"  print('gradient evaluations:', num_gradient)\n",
"\n",
"  # Avoid ZeroDivisionError if no forward evaluations were recorded.\n",
"  ratio = num_gradient / num_forward if num_forward else float('nan')\n",
"  print(f'evaluation ratio: {ratio:2.1f}')"
],
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kaOVBJQWDh-J",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"outputId": "c110d511-d56c-48ac-811e-0389e0bda713"
},
"source": [
"benchmark(k=1.0, c=1.0, t_max=1.0)"
],
"execution_count": 36,
"outputs": [
{
"output_type": "stream",
"text": [
"forward result: 0.65969956\n",
"forward evaluations: 62\n",
"gradient result: ((DeviceArray(0.65970004, dtype=float32), DeviceArray(0.53350735, dtype=float32)), DeviceArray(-0.5335077, dtype=float32), DeviceArray(-0.31360716, dtype=float32), DeviceArray(0.09370755, dtype=float32), DeviceArray(0.21989964, dtype=float32))\n",
"gradient evaluations: 187\n",
"evaluation ratio: 3.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "q3oSX5ktFowU",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"outputId": "09cd189a-4092-4cf9-e9ac-3a4d3978f667"
},
"source": [
"benchmark(k=1.0, c=100.0, t_max=1.0)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"forward result: 0.99014837\n",
"forward evaluations: 332\n",
"gradient result: ((DeviceArray(0.9901495, dtype=float32), DeviceArray(0.00990249, dtype=float32)), DeviceArray(-0.00990246, dtype=float32), DeviceArray(2.7833288e+29, dtype=float32), DeviceArray(-2.7830507e+31, dtype=float32), DeviceArray(2.7827725e+33, dtype=float32))\n",
"gradient evaluations: 8737\n",
"evaluation ratio: 26.3\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "O_MM1M_QDuJg",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"outputId": "2506af8f-1c28-4bd1-875c-d7b3bf7339a5"
},
"source": [
"benchmark(k=1.0, c=1.0, t_max=100.0)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"forward result: 7.762907e-09\n",
"forward evaluations: 698\n",
"gradient result: ((DeviceArray(-6.91281e-23, dtype=float32), DeviceArray(-2.1801066e-22, dtype=float32)), DeviceArray(1.2263421e-10, dtype=float32), DeviceArray(2.628874e-07, dtype=float32), DeviceArray(-5.232058e-07, dtype=float32), DeviceArray(2.6031813e-07, dtype=float32))\n",
"gradient evaluations: 9730\n",
"evaluation ratio: 13.9\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AG8dga8_D3C5",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 122
},
"outputId": "eae50e54-0111-46dc-999e-649862ab4e3e"
},
"source": [
"benchmark(k=1e4, c=1.0, t_max=1.0)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"forward result: 0.52108455\n",
"forward evaluations: 6416\n",
"gradient result: ((DeviceArray(0.52106625, dtype=float32), DeviceArray(-0.00307838, dtype=float32)), DeviceArray(30.780766, dtype=float32), DeviceArray(0.00155218, dtype=float32), DeviceArray(-0.26285663, dtype=float32), DeviceArray(-15.258909, dtype=float32))\n",
"gradient evaluations: 8027\n",
"evaluation ratio: 1.3\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "vmYg_Mj-Fd2K",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment