Last active: December 11, 2021 16:43
Save shoyer/9c6593ef6f65ddfcb394e96b90f87a72 to your computer and use it in GitHub Desktop.
JAX harmonic oscillator odeint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "JAX harmonic oscillator odeint.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyMxr/n/Hphj+rk7PVD7eCj9", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shoyer/9c6593ef6f65ddfcb394e96b90f87a72/jax-harmonic-oscillator-odeint.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wvDXYIeEB-js", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "92dcc90f-3f76-4c63-9e0d-4c5bcc7bb0ec" | |
}, | |
"source": [ | |
"# NOTE(review): installing unpinned JAX HEAD is not reproducible, and the next\n", | |
"# cell imports jax.experimental.host_callback, which JAX deprecated (0.4.26)\n", | |
"# and later removed -- pin a jax/jaxlib release that still ships it\n", | |
"# (TODO: confirm the exact version). %pip installs into the kernel's env.\n", | |
"%pip install -U -q git+https://github.com/google/jax.git" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
" Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Z594yUSsCCw4", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Copyright 2020 Google LLC.\n", | |
"# SPDX-License-Identifier: Apache-2.0\n", | |
"\n", | |
"from jax.experimental.ode import odeint\n", | |
"from jax.experimental import host_callback\n", | |
"import jax.numpy as jnp\n", | |
"import jax\n", | |
"from functools import partial, wraps\n", | |
"from collections import namedtuple\n", | |
"import time\n", | |
"\n", | |
"# One entry per executed call of a `recorded` function: the call's positional\n", | |
"# args, keyword args, return value, and the transforms reported by host_callback.\n", | |
"CallRecord = namedtuple(\n", | |
"    'CallRecord', ['args', 'kwargs', 'result', 'transforms'])\n", | |
"\n", | |
"def recorded(fun):\n", | |
"  \"\"\"Instrument `fun` so that every executed call is logged via host_callback.\n", | |
"\n", | |
"  Args:\n", | |
"    fun: the function to wrap.\n", | |
"\n", | |
"  Returns:\n", | |
"    A pair `(wrapper, records)`. `wrapper` computes `fun(*args, **kwargs)` and\n", | |
"    hands `(args, kwargs, result)` to `host_callback.id_tap`, whose tap appends\n", | |
"    a `CallRecord` to the list `records`. Callers read `records` after\n", | |
"    `host_callback.barrier_wait()` (as `benchmark` does).\n", | |
"  \"\"\"\n", | |
"  call_log = []\n", | |
"\n", | |
"  # Parameter names `arg` / `transforms` are kept as host_callback calls them.\n", | |
"  def _append_record(arg, transforms=None):\n", | |
"    call_args, call_kwargs, output = arg\n", | |
"    call_log.append(CallRecord(call_args, call_kwargs, output, transforms))\n", | |
"\n", | |
"  @wraps(fun)\n", | |
"  def wrapper(*args, **kwargs):\n", | |
"    output = fun(*args, **kwargs)\n", | |
"    payload = (args, kwargs, output)\n", | |
"    # With result=output, id_tap hands back the wrapped function's value.\n", | |
"    return host_callback.id_tap(_append_record, payload, result=output)\n", | |
"\n", | |
"  return wrapper, call_log\n", | |
"\n", | |
"def oscillator(state, t, k=1.0, c=1.0, m=1.0):\n", | |
"  \"\"\"Time derivative of a damped harmonic oscillator's state.\n", | |
"\n", | |
"  Models m * x'' + c * x' + k * x = 0, see\n", | |
"  https://en.wikipedia.org/wiki/Harmonic_oscillator#Damped_harmonic_oscillator\n", | |
"\n", | |
"  Args:\n", | |
"    state: pair (x, x_t) of position and velocity.\n", | |
"    t: time; not used by the dynamics but part of odeint's calling convention.\n", | |
"    k: spring constant.\n", | |
"    c: damping coefficient.\n", | |
"    m: mass.\n", | |
"\n", | |
"  Returns:\n", | |
"    Pair (x_t, x_tt): velocity and acceleration.\n", | |
"  \"\"\"\n", | |
"  position, velocity = state\n", | |
"  acceleration = -k / m * position - c / m * velocity\n", | |
"  return velocity, acceleration\n", | |
"\n", | |
"def ode_loss(fun, init_state, t_max, *params):\n", | |
"  \"\"\"Integrate `fun` from t=0 to t=t_max and return the final position.\n", | |
"\n", | |
"  Args:\n", | |
"    fun: ODE right-hand side `fun(state, t, *params)` returning a\n", | |
"      (velocity, acceleration) pair, e.g. `oscillator`.\n", | |
"    init_state: initial (position, velocity) pair.\n", | |
"    t_max: end time of the integration.\n", | |
"    *params: extra positional arguments that odeint forwards to `fun`.\n", | |
"\n", | |
"  Returns:\n", | |
"    The position component of the solution at t_max.\n", | |
"  \"\"\"\n", | |
"  # Only the two endpoints are requested as output times.\n", | |
"  times = jnp.linspace(0, t_max, num=2)\n", | |
"  positions, _velocities = odeint(fun, init_state, times, *params)\n", | |
"  return positions[-1]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QEF_ei9SD3_h", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def benchmark(\n", | |
"    k=1.0,\n", | |
"    c=1.0,\n", | |
"    m=1.0,\n", | |
"    t_max=1.0,\n", | |
"    init_state=(1.0, 0.0),\n", | |
"):\n", | |
"  \"\"\"Count `oscillator` evaluations for a jitted solve and for its gradient.\n", | |
"\n", | |
"  Prints the value of `ode_loss` for the given parameters and how many times\n", | |
"  `oscillator` ran, then the gradient of `ode_loss` with respect to\n", | |
"  (init_state, t_max, k, c, m) and its evaluation count, and finally the ratio\n", | |
"  of the two counts. Returns None.\n", | |
"  \"\"\"\n", | |
"  loss_args = (init_state, t_max, k, c, m)\n", | |
"\n", | |
"  # Forward solve with a freshly instrumented copy of `oscillator`.\n", | |
"  forward_fun, forward_records = recorded(oscillator)\n", | |
"  jit_loss = jax.jit(partial(ode_loss, forward_fun))\n", | |
"  print('forward result: ', jit_loss(*loss_args))\n", | |
"  host_callback.barrier_wait()  # collect outstanding taps before counting\n", | |
"  num_forward = len(forward_records)\n", | |
"  print('forward evaluations: ', num_forward)\n", | |
"\n", | |
"  # Gradient w.r.t. every argument of ode_loss after `fun`; a second\n", | |
"  # instrumented copy keeps the two counts separate.\n", | |
"  gradient_fun, gradient_records = recorded(oscillator)\n", | |
"  all_argnums = tuple(range(len(loss_args)))  # (0, 1, 2, 3, 4)\n", | |
"  jit_grad = jax.jit(jax.grad(partial(ode_loss, gradient_fun), argnums=all_argnums))\n", | |
"  print('gradient result: ', jit_grad(*loss_args))\n", | |
"  host_callback.barrier_wait()\n", | |
"  num_gradient = len(gradient_records)\n", | |
"  print('gradient evaluations:', num_gradient)\n", | |
"\n", | |
"  print(f'evaluation ratio: {num_gradient / num_forward :2.1f}')" | |
], | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kaOVBJQWDh-J", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 122 | |
}, | |
"outputId": "c110d511-d56c-48ac-811e-0389e0bda713" | |
}, | |
"source": [ | |
"# Baseline: k = c = m = 1 over t in [0, 1]. Saved output: 62 forward vs 187\n", | |
"# gradient evaluations (ratio 3.0).\n", | |
"benchmark(k=1.0, c=1.0, t_max=1.0)" | |
], | |
"execution_count": 36, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"forward result: 0.65969956\n", | |
"forward evaluations: 62\n", | |
"gradient result: ((DeviceArray(0.65970004, dtype=float32), DeviceArray(0.53350735, dtype=float32)), DeviceArray(-0.5335077, dtype=float32), DeviceArray(-0.31360716, dtype=float32), DeviceArray(0.09370755, dtype=float32), DeviceArray(0.21989964, dtype=float32))\n", | |
"gradient evaluations: 187\n", | |
"evaluation ratio: 3.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "q3oSX5ktFowU", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 122 | |
}, | |
"outputId": "09cd189a-4092-4cf9-e9ac-3a4d3978f667" | |
}, | |
"source": [ | |
"# Heavy damping (c = 100). In the saved output the gradient needs ~26x the\n", | |
"# forward evaluations, and the derivatives w.r.t. k, c and m come out around\n", | |
"# 1e29-1e33 -- presumably numerical blow-up rather than a real sensitivity\n", | |
"# (TODO: investigate).\n", | |
"benchmark(k=1.0, c=100.0, t_max=1.0)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"forward result: 0.99014837\n", | |
"forward evaluations: 332\n", | |
"gradient result: ((DeviceArray(0.9901495, dtype=float32), DeviceArray(0.00990249, dtype=float32)), DeviceArray(-0.00990246, dtype=float32), DeviceArray(2.7833288e+29, dtype=float32), DeviceArray(-2.7830507e+31, dtype=float32), DeviceArray(2.7827725e+33, dtype=float32))\n", | |
"gradient evaluations: 8737\n", | |
"evaluation ratio: 26.3\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "O_MM1M_QDuJg", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 122 | |
}, | |
"outputId": "2506af8f-1c28-4bd1-875c-d7b3bf7339a5" | |
}, | |
"source": [ | |
"# Long horizon (t_max = 100). Saved output: ~14x more evaluations for the\n", | |
"# gradient than for the forward solve.\n", | |
"benchmark(k=1.0, c=1.0, t_max=100.0)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"forward result: 7.762907e-09\n", | |
"forward evaluations: 698\n", | |
"gradient result: ((DeviceArray(-6.91281e-23, dtype=float32), DeviceArray(-2.1801066e-22, dtype=float32)), DeviceArray(1.2263421e-10, dtype=float32), DeviceArray(2.628874e-07, dtype=float32), DeviceArray(-5.232058e-07, dtype=float32), DeviceArray(2.6031813e-07, dtype=float32))\n", | |
"gradient evaluations: 9730\n", | |
"evaluation ratio: 13.9\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AG8dga8_D3C5", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 122 | |
}, | |
"outputId": "eae50e54-0111-46dc-999e-649862ab4e3e" | |
}, | |
"source": [ | |
"# Very stiff spring (k = 1e4, so sqrt(k/m) = 100). The forward solve alone\n", | |
"# needs 6416 evaluations in the saved output; the gradient only ~1.3x that.\n", | |
"benchmark(k=1e4, c=1.0, t_max=1.0)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"forward result: 0.52108455\n", | |
"forward evaluations: 6416\n", | |
"gradient result: ((DeviceArray(0.52106625, dtype=float32), DeviceArray(-0.00307838, dtype=float32)), DeviceArray(30.780766, dtype=float32), DeviceArray(0.00155218, dtype=float32), DeviceArray(-0.26285663, dtype=float32), DeviceArray(-15.258909, dtype=float32))\n", | |
"gradient evaluations: 8027\n", | |
"evaluation ratio: 1.3\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "vmYg_Mj-Fd2K", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.