Skip to content

Instantly share code, notes, and snippets.

@shoyer
Last active June 8, 2020 20:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shoyer/98803242f9c0d3c4ddf442f9e063a8df to your computer and use it in GitHub Desktop.
Save shoyer/98803242f9c0d3c4ddf442f9e063a8df to your computer and use it in GitHub Desktop.
JAX einsum primitive .ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX einsum primitive .ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMR2zeT/AhQogvfYQwSUf2O",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/98803242f9c0d3c4ddf442f9e063a8df/jax-einsum-primitive.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "effqC7pgqK0j",
"colab_type": "code",
"outputId": "e28b151a-8483-4621-eac2-42906b12a629",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
}
},
"source": [
"! pip install -U jax jaxlib"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.69)\n",
"Requirement already up-to-date: jaxlib in /usr/local/lib/python3.6/dist-packages (0.1.47)\n",
"Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax) (0.9.0)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax) (1.18.5)\n",
"Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax) (3.2.1)\n",
"Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib) (1.4.1)\n",
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax) (1.12.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "E987RUewqIch",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2018 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\n",
"import collections\n",
"import functools\n",
"import itertools\n",
"import operator\n",
"import threading\n",
"\n",
"import numpy as onp\n",
"\n",
"from jax import api\n",
"from jax import core\n",
"from jax import dtypes\n",
"from jax.lax import lax\n",
"from jax import linear_util as lu\n",
"from jax.abstract_arrays import ShapedArray, raise_to_shaped\n",
"from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs\n",
"from jax.interpreters import ad\n",
"from jax.interpreters import partial_eval as pe\n",
"from jax.interpreters import xla\n",
"from jax.interpreters import batching\n",
"from jax.interpreters import masking\n",
"from jax.lib import xla_bridge as xb\n",
"from jax.lib import xla_client\n",
"from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,\n",
" split_dict, cache, extend_name_stack)\n",
"from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,\n",
" treedef_children, treedef_tuple)\n",
"from jax import ad_util\n",
"import jax.numpy as jnp\n",
"import jax.test_util as jtu\n",
"\n",
"map = safe_map\n",
"zip = safe_zip\n",
"\n",
"\n",
"def einsum(*operands):\n",
" input_string, output_string, operands = _parse_einsum_input(operands)\n",
" out, = einsum_p.bind(*operands, input_strings=input_string.split(','),\n",
" output_string=output_string)\n",
" return out\n",
"\n",
"def _einsum_impl(*operands, input_strings, output_string):\n",
" subscripts = ','.join(input_strings) + '->' + output_string\n",
" return [jnp.einsum(subscripts, *operands)]\n",
"\n",
"def sum_tangents(tangents):\n",
" return functools.reduce(ad.add_tangents, tangents, ad.zero)\n",
"\n",
"def _einsum_jvp(primals, tangents, *, input_strings, output_string):\n",
" subscripts = ','.join(input_strings) + '->' + output_string\n",
" this_einsum = functools.partial(einsum, subscripts)\n",
" operands_list = []\n",
" for index, tangent in enumerate(tangents):\n",
" if tangent is not ad.zero:\n",
" operands = list(primals)\n",
" operands[index] = tangent\n",
" operands_list.append(operands)\n",
" out_primal = this_einsum(*primals)\n",
" out_tangent = sum_tangents(this_einsum(*ops) for ops in operands_list)\n",
" return [out_primal], [out_tangent]\n",
"\n",
"def _einsum_transpose_rule(cotangent, *primals, input_strings, output_string):\n",
" index, = [i for i, p in enumerate(primals) if ad.is_undefined_primal(p)]\n",
" subscripts = (','.join(input_strings[:index] + input_strings[index+1:])\n",
" + ',' + output_string\n",
" + '->' + input_strings[index])\n",
" operands = primals[:index] + primals[index+1:] + tuple(cotangent)\n",
" out = [None] * len(primals)\n",
" out[index] = einsum(subscripts, *operands)\n",
" return out\n",
"\n",
"einsum_p = core.Primitive('einsum')\n",
"einsum_p.multiple_results = True\n",
"einsum_p.def_impl(_einsum_impl)\n",
"\n",
"def generic_abstract_eval(*avals, **params):\n",
" return pe.abstract_eval_fun(_einsum_impl, *avals, **params)\n",
"einsum_p.def_abstract_eval(generic_abstract_eval)\n",
"\n",
"ad.primitive_jvps[einsum_p] = _einsum_jvp\n",
"\n",
"xla.initial_style_translations[einsum_p] = xla.lower_fun_initial_style(_einsum_impl)\n",
"\n",
"ad.primitive_transposes[einsum_p] = _einsum_transpose_rule\n",
"\n",
"# TODO(shoyer): batching rule (should be pretty easy)\n",
"# batching.primitive_batchers[einsum_p] = _einsum_batching_rule\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jqWG8ZJbqfAa",
"colab_type": "code",
"colab": {}
},
"source": [
"#@title define `_parse_einsum_input` (from numpy) { display-mode: \"form\" }\n",
"# from numpy.core.einsumfunc\n",
"\n",
"einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'\n",
"einsum_symbols_set = set(einsum_symbols)\n",
"\n",
"# asarray = lambda x: x\n",
"asarray = jnp.asarray\n",
"\n",
"def _parse_einsum_input(operands):\n",
" \"\"\"\n",
" A reproduction of einsum c side einsum parsing in python.\n",
"\n",
" Returns\n",
" -------\n",
" input_strings : str\n",
" Parsed input strings\n",
" output_string : str\n",
" Parsed output string\n",
" operands : list of array_like\n",
" The operands to use in the numpy contraction\n",
"\n",
" Examples\n",
" --------\n",
" The operand list is simplified to reduce printing:\n",
"\n",
" >>> a = np.random.rand(4, 4)\n",
" >>> b = np.random.rand(4, 4, 4)\n",
" >>> __parse_einsum_input(('...a,...a->...', a, b))\n",
" ('za,xza', 'xz', [a, b])\n",
"\n",
" >>> __parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))\n",
" ('za,xza', 'xz', [a, b])\n",
" \"\"\"\n",
"\n",
" if len(operands) == 0:\n",
" raise ValueError(\"No input operands\")\n",
"\n",
" if isinstance(operands[0], str):\n",
" subscripts = operands[0].replace(\" \", \"\")\n",
" operands = [asarray(v) for v in operands[1:]]\n",
"\n",
" # Ensure all characters are valid\n",
" for s in subscripts:\n",
" if s in '.,->':\n",
" continue\n",
" if s not in einsum_symbols:\n",
" raise ValueError(\"Character %s is not a valid symbol.\" % s)\n",
"\n",
" else:\n",
" tmp_operands = list(operands)\n",
" operand_list = []\n",
" subscript_list = []\n",
" for p in range(len(operands) // 2):\n",
" operand_list.append(tmp_operands.pop(0))\n",
" subscript_list.append(tmp_operands.pop(0))\n",
"\n",
" output_list = tmp_operands[-1] if len(tmp_operands) else None\n",
" operands = [asarray(v) for v in operand_list]\n",
" subscripts = \"\"\n",
" last = len(subscript_list) - 1\n",
" for num, sub in enumerate(subscript_list):\n",
" for s in sub:\n",
" if s is Ellipsis:\n",
" subscripts += \"...\"\n",
" elif isinstance(s, int):\n",
" subscripts += einsum_symbols[s]\n",
" else:\n",
" raise TypeError(\"For this input type lists must contain \"\n",
" \"either int or Ellipsis\")\n",
" if num != last:\n",
" subscripts += \",\"\n",
"\n",
" if output_list is not None:\n",
" subscripts += \"->\"\n",
" for s in output_list:\n",
" if s is Ellipsis:\n",
" subscripts += \"...\"\n",
" elif isinstance(s, int):\n",
" subscripts += einsum_symbols[s]\n",
" else:\n",
" raise TypeError(\"For this input type lists must contain \"\n",
" \"either int or Ellipsis\")\n",
" # Check for proper \"->\"\n",
" if (\"-\" in subscripts) or (\">\" in subscripts):\n",
" invalid = (subscripts.count(\"-\") > 1) or (subscripts.count(\">\") > 1)\n",
" if invalid or (subscripts.count(\"->\") != 1):\n",
" raise ValueError(\"Subscripts can only contain one '->'.\")\n",
"\n",
" # Parse ellipses\n",
" if \".\" in subscripts:\n",
" used = subscripts.replace(\".\", \"\").replace(\",\", \"\").replace(\"->\", \"\")\n",
" unused = list(einsum_symbols_set - set(used))\n",
" ellipse_inds = \"\".join(unused)\n",
" longest = 0\n",
"\n",
" if \"->\" in subscripts:\n",
" input_tmp, output_sub = subscripts.split(\"->\")\n",
" split_subscripts = input_tmp.split(\",\")\n",
" out_sub = True\n",
" else:\n",
" split_subscripts = subscripts.split(',')\n",
" out_sub = False\n",
"\n",
" for num, sub in enumerate(split_subscripts):\n",
" if \".\" in sub:\n",
" if (sub.count(\".\") != 3) or (sub.count(\"...\") != 1):\n",
" raise ValueError(\"Invalid Ellipses.\")\n",
"\n",
" # Take into account numerical values\n",
" if operands[num].shape == ():\n",
" ellipse_count = 0\n",
" else:\n",
" ellipse_count = max(operands[num].ndim, 1)\n",
" ellipse_count -= (len(sub) - 3)\n",
"\n",
" if ellipse_count > longest:\n",
" longest = ellipse_count\n",
"\n",
" if ellipse_count < 0:\n",
" raise ValueError(\"Ellipses lengths do not match.\")\n",
" elif ellipse_count == 0:\n",
" split_subscripts[num] = sub.replace('...', '')\n",
" else:\n",
" rep_inds = ellipse_inds[-ellipse_count:]\n",
" split_subscripts[num] = sub.replace('...', rep_inds)\n",
"\n",
" subscripts = \",\".join(split_subscripts)\n",
" if longest == 0:\n",
" out_ellipse = \"\"\n",
" else:\n",
" out_ellipse = ellipse_inds[-longest:]\n",
"\n",
" if out_sub:\n",
" subscripts += \"->\" + output_sub.replace(\"...\", out_ellipse)\n",
" else:\n",
" # Special care for outputless ellipses\n",
" output_subscript = \"\"\n",
" tmp_subscripts = subscripts.replace(\",\", \"\")\n",
" for s in sorted(set(tmp_subscripts)):\n",
" if s not in (einsum_symbols):\n",
" raise ValueError(\"Character %s is not a valid symbol.\" % s)\n",
" if tmp_subscripts.count(s) == 1:\n",
" output_subscript += s\n",
" normal_inds = ''.join(sorted(set(output_subscript) -\n",
" set(out_ellipse)))\n",
"\n",
" subscripts += \"->\" + out_ellipse + normal_inds\n",
"\n",
" # Build output string if does not exist\n",
" if \"->\" in subscripts:\n",
" input_subscripts, output_subscript = subscripts.split(\"->\")\n",
" else:\n",
" input_subscripts = subscripts\n",
" # Build output subscripts\n",
" tmp_subscripts = subscripts.replace(\",\", \"\")\n",
" output_subscript = \"\"\n",
" for s in sorted(set(tmp_subscripts)):\n",
" if s not in einsum_symbols:\n",
" raise ValueError(\"Character %s is not a valid symbol.\" % s)\n",
" if tmp_subscripts.count(s) == 1:\n",
" output_subscript += s\n",
"\n",
" # Make sure output subscripts are in the input\n",
" for char in output_subscript:\n",
" if char not in input_subscripts:\n",
" raise ValueError(\"Output character %s did not appear in the input\"\n",
" % char)\n",
"\n",
" # Make sure number operands is equivalent to the number of terms\n",
" if len(input_subscripts.split(',')) != len(operands):\n",
" raise ValueError(\"Number of einsum subscripts must be equal to the \"\n",
" \"number of operands.\")\n",
"\n",
" return (input_subscripts, output_subscript, operands)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3VUxAr_HqWlE",
"colab_type": "code",
"colab": {}
},
"source": [
"import jax\n",
"import jax.test_util as jtu"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SncY3Y9RqhuT",
"colab_type": "code",
"outputId": "543b1105-b2c5-4e6d-8c39-27a29eb57b39",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"jax.make_jaxpr(partial(einsum, 'i,ij->ij'))(jnp.zeros((2,)), jnp.zeros((2, 3)))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda ; a b.\n",
" let c = einsum[ input_strings=['i', 'ij']\n",
" output_string=ij ] a b\n",
" in (c,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 19
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cYymXcbBqkeL",
"colab_type": "code",
"outputId": "1126afb9-88da-41bd-d741-d40156fa76e2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"jax.make_jaxpr(partial(einsum, 'i,ij,jk->ij'))(jnp.zeros((2,)), jnp.zeros((2, 3)), jnp.zeros((3, 4)))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda ; a b c.\n",
" let d = einsum[ input_strings=['i', 'ij', 'jk']\n",
" output_string=ij ] a b c\n",
" in (d,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "EARozZRGqlmL",
"colab_type": "code",
"colab": {}
},
"source": [
"def make_einsum_grad(subscripts):\n",
" @jax.grad\n",
" def f(*operands):\n",
" return jnp.sum(einsum(subscripts, *operands) ** 2)\n",
" return f"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "T2NGlsvSqm3E",
"colab_type": "code",
"outputId": "c707a6a3-e279-4380-bb8c-27f8212621d5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"jax.make_jaxpr(make_einsum_grad('ij,jk->ij'))(jnp.zeros((2, 3)), jnp.zeros((3, 4)))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda c ; a b.\n",
" let d = einsum[ input_strings=['ij', 'jk']\n",
" output_string=ij ] a b\n",
" e = mul 2.0 d\n",
" f = mul c e\n",
" g = einsum[ input_strings=['jk', 'ij']\n",
" output_string=ij ] b f\n",
" in (g,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Dg_RhNupqoaD",
"colab_type": "code",
"outputId": "8f6e7cea-372f-4e9d-8973-8e49fbf8afd3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"jax.jit(make_einsum_grad('ij,jk->ij'))(jnp.zeros((2, 3)), jnp.zeros((3, 4)))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[0., 0., 0.],\n",
" [0., 0., 0.]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 23
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Lt8vgNuSqxj7",
"colab_type": "code",
"outputId": "a1d2649a-6075-4c1f-d43f-79ef17ac510e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"make_einsum_grad('ij,jk->ij')(jnp.zeros((2, 3)), jnp.zeros((3, 4)))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[0., 0., 0.],\n",
" [0., 0., 0.]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 24
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "7BEZ6RkKrOCr",
"colab_type": "code",
"colab": {}
},
"source": [
"from functools import partial\n",
"import numpy as np\n",
"\n",
"rs = np.random.RandomState(0)\n",
"f = partial(einsum, 'i,ij,j->ij')\n",
"args = (rs.randn(2), rs.randn(2, 3), rs.randn(3,))\n",
"jtu.check_grads(f, args, order=2)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xAvXZH4XrjjF",
"colab_type": "code",
"colab": {}
},
"source": [
"from functools import partial\n",
"import numpy as np\n",
"\n",
"rs = np.random.RandomState(0)\n",
"operands = 'ijk,ij,jk->ij'\n",
"f = partial(einsum, operands)\n",
"args = (rs.randn(2, 3, 4), rs.randn(2, 3), rs.randn(3, 4))\n",
"jtu.check_grads(f, args, order=2)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PEvOPu8xr5j8",
"colab_type": "code",
"outputId": "c0bc0385-b451-447b-caec-5a9a4c1dfe2c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"jax.make_jaxpr(make_einsum_grad(operands))(*args)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda d ; a b c.\n",
" let e = einsum[ input_strings=['ijk', 'ij', 'jk']\n",
" output_string=ij ] a b c\n",
" f = mul 2.0 e\n",
" g = mul d f\n",
" h = einsum[ input_strings=['ij', 'jk', 'ij']\n",
" output_string=ijk ] b c g\n",
" in (h,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 35
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OFQ8PslmsclX",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment