Skip to content

Instantly share code, notes, and snippets.

@zhangqiaorjc
Created November 10, 2023 05:27
Show Gist options
  • Save zhangqiaorjc/73cc154f22cbae474f6959d9a9fc8589 to your computer and use it in GitHub Desktop.
Save zhangqiaorjc/73cc154f22cbae474f6959d9a9fc8589 to your computer and use it in GitHub Desktop.
sincos remat example.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zhangqiaorjc/73cc154f22cbae474f6959d9a9fc8589/sincos-remat-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wgXLGxV2-nfI"
},
"outputs": [],
"source": [
"import jax\n",
"from jax import core\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "code",
"source": [
"sincos_p = core.Primitive('sincos')\n",
"sincos_p.multiple_results = True"
],
"metadata": {
"id": "OXui8YA2-uF4"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@sincos_p.def_impl\n",
"def sincos_impl(x):\n",
" return jnp.sin(x), jnp.cos(x)"
],
"metadata": {
"id": "eeeD56dS-8Lc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"@sincos_p.def_abstract_eval\n",
"def sincos_abstract_eval(x):\n",
" return x, x"
],
"metadata": {
"id": "m0L5ow22_AmN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def sincos(x):\n",
" return sincos_p.bind(x)"
],
"metadata": {
"id": "0xyg3Jg__Q8C"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"jax.make_jaxpr(sincos)(5)"
],
"metadata": {
"id": "m1cFiNCb_LbK",
"outputId": "5f464fa0-b521-4f16-98d3-4bea39ec3850"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m b\u001b[35m:i32[]\u001b[39m c\u001b[35m:i32[]\u001b[39m = sincos a \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(b, c) }"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"source": [
"@jax.custom_vjp\n",
"def sin(x):\n",
" return jnp.sin(x)\n",
"\n",
"def sin_fwd(x):\n",
" return sincos(x)\n",
"\n",
"def sin_bwd(res, g):\n",
" return (res * g,)\n",
"\n",
"sin.defvjp(sin_fwd, sin_bwd)"
],
"metadata": {
"id": "w6pG4Q0__PdI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"jax.make_jaxpr(jax.grad(sin))(5.0)"
],
"metadata": {
"id": "MUZNly7x_vWW",
"outputId": "34e730b9-6b0f-4794-c2ba-a384f2fc9813"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22m_\u001b[35m:f32[]\u001b[39m b\u001b[35m:f32[]\u001b[39m = sincos a\n",
" c\u001b[35m:f32[]\u001b[39m = mul b 1.0\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(c,) }"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"source": [
"def loss(x):\n",
" return jnp.exp(sin(x))"
],
"metadata": {
"id": "5tgmUiTH_2aH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"jax.make_jaxpr(jax.grad(loss))(5.0)"
],
"metadata": {
"id": "NW80ucjxAK1z",
"outputId": "21839d74-de71-4e48-d91e-222b199b4705"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m c\u001b[35m:f32[]\u001b[39m = sincos a\n",
" d\u001b[35m:f32[]\u001b[39m = exp b\n",
" e\u001b[35m:f32[]\u001b[39m = mul 1.0 d\n",
" f\u001b[35m:f32[]\u001b[39m = mul c e\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(f,) }"
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"source": [
"@jax.checkpoint\n",
"def loss(x):\n",
" return jnp.exp(sin(x))\n",
"\n",
"jax.make_jaxpr(jax.grad(loss))(5.0)"
],
"metadata": {
"id": "C1swFN6MAMRa",
"outputId": "5a81c4bd-d466-422d-b52a-762441006c91"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m _\u001b[35m:f32[]\u001b[39m = sincos a\n",
" _\u001b[35m:f32[]\u001b[39m = exp b\n",
" c\u001b[35m:f32[]\u001b[39m = remat2[\n",
" differentiated=True\n",
" jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; d\u001b[35m:f32[]\u001b[39m e\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22mf\u001b[35m:f32[]\u001b[39m g\u001b[35m:f32[]\u001b[39m = sincos d\n",
" h\u001b[35m:f32[]\u001b[39m = exp f\n",
" i\u001b[35m:f32[]\u001b[39m = mul e h\n",
" j\u001b[35m:f32[]\u001b[39m = mul g i\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(j,) }\n",
" policy=None\n",
" prevent_cse=True\n",
" ] a 1.0\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(c,) }"
]
},
"metadata": {},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"source": [
"def loss(x):\n",
" x = jax._src.ad_checkpoint.checkpoint_name(sin(x), 'sin(x)')\n",
" return jnp.exp(x)\n",
"\n",
"loss = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_only_these_names('sin(x)'))\n",
"\n",
"jax.make_jaxpr(jax.grad(loss))(5.0)"
],
"metadata": {
"id": "tZP1R5t7AWU8",
"outputId": "48d1919c-7d02-4c98-c169-d454239b8ed4"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m _\u001b[35m:f32[]\u001b[39m = sincos a\n",
" c\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] b\n",
" _\u001b[35m:f32[]\u001b[39m = exp c\n",
" d\u001b[35m:f32[]\u001b[39m = remat2[\n",
" differentiated=True\n",
" jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; e\u001b[35m:f32[]\u001b[39m f\u001b[35m:f32[]\u001b[39m g\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22m_\u001b[35m:f32[]\u001b[39m h\u001b[35m:f32[]\u001b[39m = sincos f\n",
" i\u001b[35m:f32[]\u001b[39m = exp e\n",
" j\u001b[35m:f32[]\u001b[39m = mul g i\n",
" k\u001b[35m:f32[]\u001b[39m = mul h j\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(k,) }\n",
" policy=<function save_only_these_names.<locals>.policy at 0x7f531ed7b130>\n",
" prevent_cse=True\n",
" ] c a 1.0\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(d,) }"
]
},
"metadata": {},
"execution_count": 23
}
]
},
{
"cell_type": "code",
"source": [
"@jax.custom_vjp\n",
"def sin(x):\n",
" return jnp.sin(x)\n",
"\n",
"def sin_fwd(x):\n",
" sinx, cosx = sincos(x)\n",
" sinx = jax._src.ad_checkpoint.checkpoint_name(sinx, 'sin(x)')\n",
" cosx = jax._src.ad_checkpoint.checkpoint_name(cosx, 'sin(x)')\n",
" return sinx, cosx\n",
"\n",
"def sin_bwd(res, g):\n",
" return (res * g,)\n",
"\n",
"sin.defvjp(sin_fwd, sin_bwd)\n",
"\n",
"def loss(x):\n",
" return jnp.exp(sin(x))\n",
"\n",
"loss = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_only_these_names('sin(x)'))\n",
"\n",
"jax.make_jaxpr(jax.grad(loss))(5.0)"
],
"metadata": {
"id": "QbiqHCecAr_l",
"outputId": "31d4a105-df41-4e7b-f17c-814a19ad6a2c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m c\u001b[35m:f32[]\u001b[39m = sincos a\n",
" d\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] b\n",
" e\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] c\n",
" _\u001b[35m:f32[]\u001b[39m = exp d\n",
" f\u001b[35m:f32[]\u001b[39m = remat2[\n",
" differentiated=True\n",
" jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; g\u001b[35m:f32[]\u001b[39m h\u001b[35m:f32[]\u001b[39m i\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
" \u001b[39m\u001b[22m\u001b[22mj\u001b[35m:f32[]\u001b[39m = exp h\n",
" k\u001b[35m:f32[]\u001b[39m = mul i j\n",
" l\u001b[35m:f32[]\u001b[39m = mul g k\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(l,) }\n",
" policy=<function save_only_these_names.<locals>.policy at 0x7f531ed7ba30>\n",
" prevent_cse=True\n",
" ] e d 1.0\n",
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(f,) }"
]
},
"metadata": {},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "t0IEhLtICZY4"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment