Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created May 18, 2023 17:27
Show Gist options
  • Save shoyer/57107c2d2023d708683cdb5087810179 to your computer and use it in GitHub Desktop.
Save shoyer/57107c2d2023d708683cdb5087810179 to your computer and use it in GitHub Desktop.
JAX einsum primitive .ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyO0R2VDwSaxyTeVPyjiTCCE",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/57107c2d2023d708683cdb5087810179/jax-einsum-primitive.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "effqC7pgqK0j",
"outputId": "e28b151a-8483-4621-eac2-42906b12a629",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
}
},
"source": [
"# Pin the versions shown in the recorded output below: this notebook imports\n",
"# internal JAX modules (e.g. `jax.interpreters.xla`, `jax.abstract_arrays`)\n",
"# whose APIs are not stable across releases, so `-U` would break it.\n",
"! pip install jax==0.1.69 jaxlib==0.1.47"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.69)\n",
"Requirement already up-to-date: jaxlib in /usr/local/lib/python3.6/dist-packages (0.1.47)\n",
"Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax) (0.9.0)\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax) (1.18.5)\n",
"Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax) (3.2.1)\n",
"Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib) (1.4.1)\n",
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax) (1.12.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "E987RUewqIch"
},
"source": [
"# Copyright 2023 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\n",
"import collections\n",
"import functools\n",
"import itertools\n",
"import operator\n",
"import threading\n",
"\n",
"import numpy as onp\n",
"\n",
"from jax import api\n",
"from jax import core\n",
"from jax import dtypes\n",
"from jax.lax import lax\n",
"from jax import linear_util as lu\n",
"from jax.abstract_arrays import ShapedArray, raise_to_shaped\n",
"from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs\n",
"from jax.interpreters import ad\n",
"from jax.interpreters import partial_eval as pe\n",
"from jax.interpreters import xla\n",
"from jax.interpreters import batching\n",
"from jax.interpreters import masking\n",
"from jax.lib import xla_bridge as xb\n",
"from jax.lib import xla_client\n",
"from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,\n",
" split_dict, cache, extend_name_stack)\n",
"from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,\n",
" treedef_children, treedef_tuple)\n",
"from jax import ad_util\n",
"import jax.numpy as jnp\n",
"import jax.test_util as jtu\n",
"\n",
"map = safe_map\n",
"zip = safe_zip\n",
"\n",
"\n",
"def einsum(*operands):\n",
"  \"\"\"Evaluate an Einstein summation as a single `einsum` primitive.\n",
"\n",
"  Accepts the calling conventions understood by `_parse_einsum_input`: either\n",
"  a subscripts string followed by the arrays, or alternating arrays and\n",
"  integer sublists.\n",
"\n",
"  Returns:\n",
"    The single output produced by `einsum_p.bind`.\n",
"  \"\"\"\n",
"  input_string, output_string, operands = _parse_einsum_input(operands)\n",
"  # Pass the per-operand subscripts as a tuple rather than a list so that the\n",
"  # primitive's parameters are immutable and hashable.\n",
"  out, = einsum_p.bind(*operands,\n",
"                       input_strings=tuple(input_string.split(',')),\n",
"                       output_string=output_string)\n",
"  return out\n",
"\n",
"def _einsum_impl(*operands, input_strings, output_string):\n",
"  \"\"\"Implementation rule: delegate to `jnp.einsum`.\n",
"\n",
"  The explicit subscripts are rebuilt as the comma-joined input strings,\n",
"  then '->', then the output string. The result is wrapped in a one-element\n",
"  list because the primitive is declared with multiple results.\n",
"  \"\"\"\n",
"  lhs = ','.join(input_strings)\n",
"  result = jnp.einsum('{}->{}'.format(lhs, output_string), *operands)\n",
"  return [result]\n",
"\n",
"def sum_tangents(tangents):\n",
"  \"\"\"Fold an iterable of tangents into one value using `ad.add_tangents`.\"\"\"\n",
"  add = ad.add_tangents\n",
"  return functools.reduce(add, tangents)\n",
"\n",
"def _einsum_jvp(primals, tangents, *, input_strings, output_string):\n",
"  \"\"\"JVP rule: apply the product rule over the operands.\n",
"\n",
"  For every operand whose tangent is not an `ad.Zero`, the same einsum is\n",
"  evaluated with that operand replaced by its tangent; the results are summed\n",
"  with `sum_tangents`.\n",
"\n",
"  Returns:\n",
"    A pair of one-element lists, `([out_primal], [out_tangent])`.\n",
"  \"\"\"\n",
"  contract = functools.partial(\n",
"      einsum, ','.join(input_strings) + '->' + output_string)\n",
"  # One operand list per non-zero tangent, with that tangent substituted in.\n",
"  substituted = [\n",
"      [t if j == i else p for j, p in enumerate(primals)]\n",
"      for i, t in enumerate(tangents)\n",
"      if type(t) is not ad.Zero\n",
"  ]\n",
"  out_primal = contract(*primals)\n",
"  out_tangent = sum_tangents(contract(*ops) for ops in substituted)\n",
"  return [out_primal], [out_tangent]\n",
"\n",
"def _einsum_transpose_rule(cotangent, *primals, input_strings, output_string):\n",
"  \"\"\"Transpose rule: the cotangent of the linear operand is another einsum.\n",
"\n",
"  The remaining operands together with the output cotangent become the\n",
"  inputs, and the subscripts of the single undefined (linear) operand become\n",
"  the output.\n",
"\n",
"  NOTE(review): if an index appears only in the linear operand (i.e. it is\n",
"  summed out in the forward pass), the transposed subscripts have an output\n",
"  index missing from the inputs and `_parse_einsum_input` raises ValueError.\n",
"\n",
"  Returns:\n",
"    A list with one entry per primal: `None` everywhere except at the\n",
"    position of the undefined primal.\n",
"  \"\"\"\n",
"  index, = [i for i, p in enumerate(primals) if ad.is_undefined_primal(p)]\n",
"  other_strings = list(input_strings[:index]) + list(input_strings[index+1:])\n",
"  # Join the cotangent's subscripts in the same list as the other inputs so\n",
"  # that a single-operand einsum (no other inputs) does not yield a leading\n",
"  # ',' and with it a spurious empty input term.\n",
"  subscripts = (','.join(other_strings + [output_string])\n",
"                + '->' + input_strings[index])\n",
"  operands = primals[:index] + primals[index+1:] + tuple(cotangent)\n",
"  out = [None] * len(primals)\n",
"  out[index] = einsum(subscripts, *operands)\n",
"  return out\n",
"\n",
"einsum_p = core.Primitive('einsum')\n",
"einsum_p.multiple_results = True\n",
"einsum_p.def_impl(_einsum_impl)\n",
"\n",
"def generic_abstract_eval(*avals, **params):\n",
" return pe.abstract_eval_fun(_einsum_impl, *avals, **params)\n",
"einsum_p.def_abstract_eval(generic_abstract_eval)\n",
"\n",
"ad.primitive_jvps[einsum_p] = _einsum_jvp\n",
"\n",
"xla.initial_style_translations[einsum_p] = xla.lower_fun_initial_style(_einsum_impl)\n",
"\n",
"ad.primitive_transposes[einsum_p] = _einsum_transpose_rule\n",
"\n",
"# TODO(shoyer): batching rule (should be pretty easy)\n",
"# batching.primitive_batchers[einsum_p] = _einsum_batching_rule\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jqWG8ZJbqfAa"
},
"source": [
"#@title define `_parse_einsum_input` (from numpy) { display-mode: \"form\" }\n",
"# from numpy.core.einsumfunc\n",
"\n",
"einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'\n",
"einsum_symbols_set = set(einsum_symbols)\n",
"\n",
"# asarray = lambda x: x\n",
"asarray = jnp.asarray\n",
"\n",
"def _parse_einsum_input(operands):\n",
"  \"\"\"Normalise einsum arguments into explicit input and output subscripts.\n",
"\n",
"  A reproduction of einsum c side einsum parsing in python (adapted from\n",
"  `numpy.core.einsumfunc`).\n",
"\n",
"  Parameters\n",
"  ----------\n",
"  operands : sequence\n",
"      Either `(subscripts, *arrays)` with `subscripts` a string, or the\n",
"      sublist form `(array, sublist, array, sublist, ..., [output_sublist])`\n",
"      in which every sublist holds ints and/or `Ellipsis`.\n",
"\n",
"  Returns\n",
"  -------\n",
"  input_strings : str\n",
"      Comma-separated input subscripts with every ellipsis expanded\n",
"  output_string : str\n",
"      Output subscripts (inferred when no '->' was given)\n",
"  operands : list of array_like\n",
"      The operands, each converted with `asarray`\n",
"\n",
"  Raises\n",
"  ------\n",
"  ValueError\n",
"      If no operands are given, a subscript character is invalid, '->' or an\n",
"      ellipsis is malformed, an output index is missing from the inputs, or\n",
"      the number of subscript terms differs from the number of operands.\n",
"  TypeError\n",
"      If a sublist contains something other than an int or `Ellipsis`.\n",
"\n",
"  Examples\n",
"  --------\n",
"  The operand list is simplified to reduce printing:\n",
"\n",
"  >>> a = np.random.rand(4, 4)\n",
"  >>> b = np.random.rand(4, 4, 4)\n",
"  >>> _parse_einsum_input(('...a,...a->...', a, b))\n",
"  ('za,yza', 'yz', [a, b])\n",
"\n",
"  >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))\n",
"  ('za,yza', 'yz', [a, b])\n",
"  \"\"\"\n",
"\n",
"  if len(operands) == 0:\n",
"    raise ValueError(\"No input operands\")\n",
"\n",
"  if isinstance(operands[0], str):\n",
"    subscripts = operands[0].replace(\" \", \"\")\n",
"    operands = [asarray(v) for v in operands[1:]]\n",
"\n",
"    # Ensure all characters are valid\n",
"    for s in subscripts:\n",
"      if s in '.,->':\n",
"        continue\n",
"      if s not in einsum_symbols:\n",
"        raise ValueError(\"Character %s is not a valid symbol.\" % s)\n",
"\n",
"  else:\n",
"    # Sublist form: operands alternate with their index lists, optionally\n",
"    # followed by one trailing output index list.\n",
"    tmp_operands = list(operands)\n",
"    operand_list = []\n",
"    subscript_list = []\n",
"    for p in range(len(operands) // 2):\n",
"      operand_list.append(tmp_operands.pop(0))\n",
"      subscript_list.append(tmp_operands.pop(0))\n",
"\n",
"    output_list = tmp_operands[-1] if len(tmp_operands) else None\n",
"    operands = [asarray(v) for v in operand_list]\n",
"    subscripts = \"\"\n",
"    last = len(subscript_list) - 1\n",
"    for num, sub in enumerate(subscript_list):\n",
"      for s in sub:\n",
"        if s is Ellipsis:\n",
"          subscripts += \"...\"\n",
"        elif isinstance(s, int):\n",
"          subscripts += einsum_symbols[s]\n",
"        else:\n",
"          raise TypeError(\"For this input type lists must contain \"\n",
"                          \"either int or Ellipsis\")\n",
"      if num != last:\n",
"        subscripts += \",\"\n",
"\n",
"    if output_list is not None:\n",
"      subscripts += \"->\"\n",
"      for s in output_list:\n",
"        if s is Ellipsis:\n",
"          subscripts += \"...\"\n",
"        elif isinstance(s, int):\n",
"          subscripts += einsum_symbols[s]\n",
"        else:\n",
"          raise TypeError(\"For this input type lists must contain \"\n",
"                          \"either int or Ellipsis\")\n",
"  # Check for proper \"->\"\n",
"  if (\"-\" in subscripts) or (\">\" in subscripts):\n",
"    invalid = (subscripts.count(\"-\") > 1) or (subscripts.count(\">\") > 1)\n",
"    if invalid or (subscripts.count(\"->\") != 1):\n",
"      raise ValueError(\"Subscripts can only contain one '->'.\")\n",
"\n",
"  # Parse ellipses\n",
"  if \".\" in subscripts:\n",
"    used = subscripts.replace(\".\", \"\").replace(\",\", \"\").replace(\"->\", \"\")\n",
"    # Sort the unused labels so that the ones substituted for '...' are\n",
"    # deterministic; set iteration order is not guaranteed.\n",
"    unused = sorted(einsum_symbols_set - set(used))\n",
"    ellipse_inds = \"\".join(unused)\n",
"    longest = 0\n",
"\n",
"    if \"->\" in subscripts:\n",
"      input_tmp, output_sub = subscripts.split(\"->\")\n",
"      split_subscripts = input_tmp.split(\",\")\n",
"      out_sub = True\n",
"    else:\n",
"      split_subscripts = subscripts.split(',')\n",
"      out_sub = False\n",
"\n",
"    for num, sub in enumerate(split_subscripts):\n",
"      if \".\" in sub:\n",
"        if (sub.count(\".\") != 3) or (sub.count(\"...\") != 1):\n",
"          raise ValueError(\"Invalid Ellipses.\")\n",
"\n",
"        # Take into account numerical values\n",
"        if operands[num].shape == ():\n",
"          ellipse_count = 0\n",
"        else:\n",
"          ellipse_count = max(operands[num].ndim, 1)\n",
"          ellipse_count -= (len(sub) - 3)\n",
"\n",
"        if ellipse_count > longest:\n",
"          longest = ellipse_count\n",
"\n",
"        if ellipse_count < 0:\n",
"          raise ValueError(\"Ellipses lengths do not match.\")\n",
"        elif ellipse_count == 0:\n",
"          split_subscripts[num] = sub.replace('...', '')\n",
"        else:\n",
"          # Use the trailing labels so shorter ellipses align on the right.\n",
"          rep_inds = ellipse_inds[-ellipse_count:]\n",
"          split_subscripts[num] = sub.replace('...', rep_inds)\n",
"\n",
"    subscripts = \",\".join(split_subscripts)\n",
"    if longest == 0:\n",
"      out_ellipse = \"\"\n",
"    else:\n",
"      out_ellipse = ellipse_inds[-longest:]\n",
"\n",
"    if out_sub:\n",
"      subscripts += \"->\" + output_sub.replace(\"...\", out_ellipse)\n",
"    else:\n",
"      # Special care for outputless ellipses\n",
"      output_subscript = \"\"\n",
"      tmp_subscripts = subscripts.replace(\",\", \"\")\n",
"      for s in sorted(set(tmp_subscripts)):\n",
"        if s not in (einsum_symbols):\n",
"          raise ValueError(\"Character %s is not a valid symbol.\" % s)\n",
"        if tmp_subscripts.count(s) == 1:\n",
"          output_subscript += s\n",
"      normal_inds = ''.join(sorted(set(output_subscript) -\n",
"                                   set(out_ellipse)))\n",
"\n",
"      subscripts += \"->\" + out_ellipse + normal_inds\n",
"\n",
"  # Build output string if does not exist\n",
"  if \"->\" in subscripts:\n",
"    input_subscripts, output_subscript = subscripts.split(\"->\")\n",
"  else:\n",
"    input_subscripts = subscripts\n",
"    # Build output subscripts: every label that occurs exactly once, sorted.\n",
"    tmp_subscripts = subscripts.replace(\",\", \"\")\n",
"    output_subscript = \"\"\n",
"    for s in sorted(set(tmp_subscripts)):\n",
"      if s not in einsum_symbols:\n",
"        raise ValueError(\"Character %s is not a valid symbol.\" % s)\n",
"      if tmp_subscripts.count(s) == 1:\n",
"        output_subscript += s\n",
"\n",
"  # Make sure output subscripts are in the input\n",
"  for char in output_subscript:\n",
"    if char not in input_subscripts:\n",
"      raise ValueError(\"Output character %s did not appear in the input\"\n",
"                       % char)\n",
"\n",
"  # Make sure number operands is equivalent to the number of terms\n",
"  if len(input_subscripts.split(',')) != len(operands):\n",
"    raise ValueError(\"Number of einsum subscripts must be equal to the \"\n",
"                     \"number of operands.\")\n",
"\n",
"  return (input_subscripts, output_subscript, operands)\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3VUxAr_HqWlE"
},
"source": [
"import jax\n",
"import jax.test_util as jtu"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SncY3Y9RqhuT",
"outputId": "879f01db-15fc-4bdd-df98-4689e46ddb29",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 119
}
},
"source": [
"jax.make_jaxpr(partial(einsum, 'i,ij->ij'))(jnp.zeros((2,)), jnp.zeros((2, 3)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda ; a b.\n",
" let c = einsum[ input_strings=['i', 'ij']\n",
" output_string=ij ] a b\n",
" in (c,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cYymXcbBqkeL",
"outputId": "d1c24b47-d5b9-41d8-f1cf-dd1d7f12d79e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"jax.make_jaxpr(partial(einsum, 'i,ij,jk->ij'))(jnp.zeros((2,)), jnp.zeros((2, 3)), jnp.zeros((3, 4)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda ; a b c.\n",
" let d = einsum[ input_strings=['i', 'ij', 'jk']\n",
" output_string=ij ] a b c\n",
" in (d,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "EARozZRGqlmL"
},
"source": [
"def make_einsum_grad(subscripts, einsum_fun=einsum, argnums=0):\n",
"  \"\"\"Build the gradient of `jnp.sum(einsum_fun(subscripts, *operands) ** 2)`.\n",
"\n",
"  Args:\n",
"    subscripts: einsum subscripts string passed through to `einsum_fun`.\n",
"    einsum_fun: einsum implementation to differentiate through.\n",
"    argnums: forwarded to `jax.grad`.\n",
"\n",
"  Returns:\n",
"    A function of the operands that returns the gradient.\n",
"  \"\"\"\n",
"  def f(*operands):\n",
"    contracted = einsum_fun(subscripts, *operands)\n",
"    return jnp.sum(contracted ** 2)\n",
"  return jax.grad(f, argnums=argnums)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "T2NGlsvSqm3E",
"outputId": "d1b03be7-fc51-41d5-f545-fc5412b1cd37",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"jax.make_jaxpr(make_einsum_grad('ij,jk->ij'))(jnp.zeros((2, 3)), jnp.zeros((3, 4)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda c ; a b.\n",
" let d = einsum[ input_strings=['ij', 'jk']\n",
" output_string=ij ] a b\n",
" e = mul 2.0 d\n",
" f = mul c e\n",
" g = einsum[ input_strings=['jk', 'ij']\n",
" output_string=ij ] b f\n",
" in (g,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "CinG5HQOKEgM"
},
"source": [
"import opt_einsum"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Cb59R0xmKFHE"
},
"source": [
"opt_einsum"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "C2HW3evBKehE"
},
"source": [
"import collections"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MhcVxm5gJvKN",
"outputId": "03bf46ac-e2e1-49fb-83fc-e6c24334b3d5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"operands = 'abc,ad,be,cf,def,dg,eh,fi->ghi'\n",
"sizes = collections.defaultdict(lambda: 100)\n",
"arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n",
"jax.make_jaxpr(make_einsum_grad(operands))(*arrays)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda i ; a b c d e f g h.\n",
" let j = einsum[ input_strings=['abc', 'ad', 'be', 'cf', 'def', 'dg', 'eh', 'fi']\n",
" output_string=ghi ] a b c d e f g h\n",
" k = mul 2.0 j\n",
" l = mul i k\n",
" m = einsum[ input_strings=['ad', 'be', 'cf', 'def', 'dg', 'eh', 'fi', 'ghi']\n",
" output_string=abc ] b c d e f g h l\n",
" in (m,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 19
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QltFpuSnLCit",
"outputId": "a50ae749-5524-49e6-da5a-70d506ecf090",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 391
}
},
"source": [
"operands = 'ad,be,cf,def,dg,eh,fi,ghi->abc'\n",
"arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n",
"jax.make_jaxpr(partial(jnp.einsum, operands))(*arrays)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda ; a b c d e f g h.\n",
" let i = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; a b c d e f g h.\n",
" let i = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n",
" precision=None ] h e\n",
" j = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n",
" precision=None ] i f\n",
" k = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n",
" precision=None ] j g\n",
" l = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n",
" precision=None ] k d\n",
" m = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n",
" precision=None ] l a\n",
" n = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n",
" precision=None ] m b\n",
" o = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n",
" precision=None ] n c\n",
" in (o,) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False, False)\n",
" name=_einsum ] a b c d e f g h\n",
" in (i,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "mXEiJoeVK8eN",
"outputId": "ede223bf-9d72-4baf-d598-bab651161ddd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 887
}
},
"source": [
"operands = 'abc,ad,be,cf,def,dg,eh,fi->ghi'\n",
"sizes = collections.defaultdict(lambda: 100)\n",
"arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n",
"# Differentiate through `jnp.einsum` (selected via `einsum_fun`) for comparison.\n",
"jax.make_jaxpr(make_einsum_grad(operands, einsum_fun=jnp.einsum))(*arrays)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda i s ; a b c d e f g h.\n",
" let j k l m n o p q r = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; a b c d e f g h i.\n",
" let j = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n",
" precision=None ] b a\n",
" k = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
" precision=None ] j c\n",
" l = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n",
" precision=None ] k d\n",
" m = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n",
" precision=None ] l e\n",
" n = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n",
" precision=None ] m f\n",
" o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n",
" precision=None ] n g\n",
" p = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n",
" precision=None ] o h\n",
" in (p, *, b, c, d, e, f, g, h) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False, False, False)\n",
" name=jvp(_einsum) ] a b c d e f g h i\n",
" t = mul 2.0 j\n",
" u = mul s t\n",
" v = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; a b c d e f g h i j k l m n o.\n",
" let p = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n",
" precision=None ] o g\n",
" q = transpose[ permutation=(2, 0, 1) ] p\n",
" r = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n",
" precision=None ] q f\n",
" s = transpose[ permutation=(2, 0, 1) ] r\n",
" t = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n",
" precision=None ] s e\n",
" u = transpose[ permutation=(2, 0, 1) ] t\n",
" v = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n",
" precision=None ] u d\n",
" w = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n",
" precision=None ] v c\n",
" x = transpose[ permutation=(0, 2, 1) ] w\n",
" y = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n",
" precision=None ] x b\n",
" z = transpose[ permutation=(0, 2, 1) ] y\n",
" ba = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n",
" precision=None ] z a\n",
" bb = transpose[ permutation=(2, 0, 1) ] ba\n",
" in (bb,) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False)\n",
" name=transpose(jvp(_einsum)) ] l m n o p q r i i i i i i i u\n",
" in (v,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "T5nLpDnXL7H3",
"outputId": "58c81d37-b3ba-4019-d395-a721d3b08c68",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
}
},
"source": [
"operands = 'bdik,acaj,ikab,ajac,ikbd->'\n",
"sizes = collections.defaultdict(lambda: 10)\n",
"arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n",
"jax.make_jaxpr(make_einsum_grad(operands))(*arrays)\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda ; a b c d e.\n",
" let f = einsum[ input_strings=['bdik', 'acaj', 'ikab', 'ajac', 'ikbd']\n",
" output_string= ] a b c d e\n",
" g = mul 2.0 f\n",
" h = mul 1.0 g\n",
" i = einsum[ input_strings=['acaj', 'ikab', 'ajac', 'ikbd', '']\n",
" output_string=bdik ] b c d e h\n",
" in (i,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ikKb81hVRREY"
},
"source": [
"block_until_ready = partial(jax.tree_map, lambda x: x.block_until_ready())"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qn1xnP2VXj5H"
},
"source": [
"def make_einsum_grad2(subscripts, einsum_fun=einsum, argnums=0):\n",
"  \"\"\"Build the gradient of `einsum_fun(subscripts, *operands)` itself.\n",
"\n",
"  Unlike `make_einsum_grad`, the einsum output is differentiated directly,\n",
"  without squaring and summing it first.\n",
"\n",
"  Args:\n",
"    subscripts: einsum subscripts string passed through to `einsum_fun`.\n",
"    einsum_fun: einsum implementation to differentiate through.\n",
"    argnums: forwarded to `jax.grad`.\n",
"  \"\"\"\n",
"  def f(*operands):\n",
"    return einsum_fun(subscripts, *operands)\n",
"  return jax.grad(f, argnums=argnums)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wrMmF7xgQJQV",
"outputId": "cadad34c-65b0-4204-d54a-efab70d977af",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 170
}
},
"source": [
"operands = 'abcde,abfg,cdhi,ghjk,ielm,fjno,klpq,nopqm->'\n",
"dim_size = 8\n",
"print(f\"expression: {operands}\")\n",
"print(f\"dim_size: {dim_size}\")\n",
"sizes = collections.defaultdict(lambda: dim_size)\n",
"arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n",
"argnums = (1, 2, 3, 4, 5, 6, 7)\n",
"print(f\"gradient argnums: {argnums}\")\n",
"\n",
"print()\n",
"print(\"einsum primitive\")\n",
"f = jax.jit(make_einsum_grad(operands, einsum_fun=einsum, argnums=argnums))\n",
"# print(jax.make_jaxpr(f)(*arrays))\n",
"block_until_ready(f(*arrays)) # compile\n",
"%timeit block_until_ready(f(*arrays))\n",
"\n",
"print()\n",
"print(\"dot_general primitive\")\n",
"f = jax.jit(make_einsum_grad(operands, einsum_fun=jnp.einsum, argnums=argnums))\n",
"# print(jax.make_jaxpr(f)(*arrays))\n",
"block_until_ready(f(*arrays)) # compile\n",
"%timeit block_until_ready(f(*arrays))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"expression: abcde,abfg,cdhi,ghjk,ielm,fjno,klpq,nopqm->\n",
"dim_size: 8\n",
"gradient argnums: (1, 2, 3, 4, 5, 6, 7)\n",
"\n",
"einsum primitive\n",
"100 loops, best of 3: 4.88 ms per loop\n",
"\n",
"dot_general primitive\n",
"100 loops, best of 3: 3.93 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "icUau75pQV8M",
"outputId": "acfdb690-8fb5-45a7-bbd4-e23ae4ab4f99",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 972
}
},
"source": [
"operands = 'abcde,abfg,cdhi,ghjk,ielm,fjno,klpq->nopqm'\n",
"sizes = collections.defaultdict(lambda: 16)\n",
"arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n",
"# Differentiate through `jnp.einsum` (selected via `einsum_fun`) w.r.t. the\n",
"# first operand, then print the jaxpr and time the compiled function.\n",
"f = jax.jit(make_einsum_grad(operands, einsum_fun=jnp.einsum, argnums=(0,)))\n",
"print(jax.make_jaxpr(f)(*arrays))\n",
"block_until_ready(f(*arrays)) # compile\n",
"%timeit block_until_ready(f(*arrays))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"{ lambda h i ; a b c d e f g.\n",
" let j = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; h q a b c d e f g.\n",
" let i j k l m n o p = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; a b c d e f g h.\n",
" let i = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ()))\n",
" precision=None ] b a\n",
" j = dot_general[ dimension_numbers=(((2, 3), (0, 1)), ((), ()))\n",
" precision=None ] i c\n",
" k = dot_general[ dimension_numbers=(((1, 3), (0, 1)), ((), ()))\n",
" precision=None ] j d\n",
" l = dot_general[ dimension_numbers=(((1, 2), (1, 0)), ((), ()))\n",
" precision=None ] k e\n",
" m = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ()))\n",
" precision=None ] l f\n",
" n = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ()))\n",
" precision=None ] m g\n",
" o = transpose[ permutation=(1, 2, 3, 4, 0) ] n\n",
" in (o, *, b, c, d, e, f, g) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False, False)\n",
" name=jvp(_einsum) ] a b c d e f g h\n",
" r = mul 2.0 i\n",
" s = mul q r\n",
" t = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; a b c d e f g h i j k l m.\n",
" let n = transpose[ permutation=(4, 0, 1, 2, 3) ] m\n",
" o = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n",
" precision=None ] n f\n",
" p = transpose[ permutation=(3, 4, 0, 1, 2) ] o\n",
" q = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n",
" precision=None ] p e\n",
" r = transpose[ permutation=(3, 4, 0, 1, 2) ] q\n",
" s = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n",
" precision=None ] r d\n",
" t = transpose[ permutation=(0, 4, 3, 1, 2) ] s\n",
" u = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n",
" precision=None ] t c\n",
" v = transpose[ permutation=(0, 3, 1, 4, 2) ] u\n",
" w = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n",
" precision=None ] v b\n",
" x = transpose[ permutation=(0, 1, 3, 4, 2) ] w\n",
" y = dot_general[ dimension_numbers=(((0, 1), (2, 3)), ((), ()))\n",
" precision=None ] x a\n",
" z = transpose[ permutation=(3, 4, 0, 1, 2) ] y\n",
" in (z,) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False)\n",
" name=transpose(jvp(_einsum)) ] k l m n o p h h h h h h s\n",
" in (t,) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False, False, False)\n",
" name=f ] h i a b c d e f g\n",
" in (j,) }\n",
"10 loops, best of 3: 130 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "twzMuFgsNncy",
"outputId": "57f07c77-ca7a-40f5-fdc5-b89aa8cced02",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 442
}
},
"source": [
"operands = 'acaj,ikab,ajac,ikbd,->bdik'\n",
"arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n",
"jax.make_jaxpr(partial(jnp.einsum, operands))(*arrays)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda f g ; a b c d e.\n",
" let h = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; f j a b c d e.\n",
" let g = mul c f\n",
" h = reduce_sum[ axes=(0,) ] g\n",
" i = transpose[ permutation=(1, 0, 2) ] h\n",
" k = mul a j\n",
" l = reduce_sum[ axes=(0,) ] k\n",
" m = transpose[ permutation=(1, 0, 2) ] l\n",
" n = dot_general[ dimension_numbers=(((2, 1), (1, 2)), ((0,), (0,)))\n",
" precision=None ] i m\n",
" o = dot_general[ dimension_numbers=(((0,), (2,)), ((), ()))\n",
" precision=None ] n b\n",
" p = reshape[ dimensions=None\n",
" new_sizes=() ] e\n",
" q = dot_general[ dimension_numbers=(((), ()), ((), ()))\n",
" precision=None ] o p\n",
" r = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n",
" precision=None ] q d\n",
" s = transpose[ permutation=(2, 3, 0, 1) ] r\n",
" in (s,) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False)\n",
" name=_einsum ] f g a b c d e\n",
" in (h,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BXbdK8W1P638"
},
"source": [
"#"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8nQdioSXKo78",
"outputId": "059ee47d-77a5-467d-cfc2-872eed5029d8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 714
}
},
"source": [
"operands = 'bdik,acaj,ikab,ajac,ikbd->'\n",
"sizes = collections.defaultdict(lambda: 10)\n",
"arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n",
"# Differentiate through `jnp.einsum` (selected via `einsum_fun`) for comparison.\n",
"jax.make_jaxpr(make_einsum_grad(operands, einsum_fun=jnp.einsum))(*arrays)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda f g h ; a b c d e.\n",
" let i j k l m = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; l p a b c d e f.\n",
" let g = transpose[ permutation=(2, 0, 1, 3) ] e\n",
" h = transpose[ permutation=(0, 2, 3, 1) ] a\n",
" i = dot_general[ dimension_numbers=(((3,), (3,)), ((0, 1, 2), (0, 1, 2)))\n",
" precision=None ] g h\n",
" j = transpose[ permutation=(1, 2, 0) ] i\n",
" k = dot_general[ dimension_numbers=(((2, 0, 1), (3, 0, 1)), ((), ()))\n",
" precision=None ] j c\n",
" m = mul d l\n",
" n = reduce_sum[ axes=(0,) ] m\n",
" o = transpose[ permutation=(1, 0, 2) ] n\n",
" q = mul b p\n",
" r = reduce_sum[ axes=(0,) ] q\n",
" s = transpose[ permutation=(1, 0, 2) ] r\n",
" t = dot_general[ dimension_numbers=(((2, 1), (1, 2)), ((0,), (0,)))\n",
" precision=None ] o s\n",
" u = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n",
" precision=None ] k t\n",
" in (u, *, g, c, t) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False, False)\n",
" name=jvp(_einsum) ] f g a b c d e h\n",
" n = mul 2.0 i\n",
" o = mul 1.0 n\n",
" p = xla_call[ backend=None\n",
" call_jaxpr={ lambda ; a b c d e f g h.\n",
" let i = dot_general[ dimension_numbers=(((), ()), ((), ()))\n",
" precision=None ] h c\n",
" j = dot_general[ dimension_numbers=(((0,), (2,)), ((), ()))\n",
" precision=None ] i b\n",
" k = transpose[ permutation=(2, 0, 1) ] j\n",
" l = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n",
" precision=None ] k a\n",
" m = transpose[ permutation=(0, 3, 1, 2) ] l\n",
" in (m,) }\n",
" device=None\n",
" donated_invars=(False, False, False, False, False, False, False, False)\n",
" name=transpose(jvp(_einsum)) ] k l m h h h h o\n",
" in (p,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 28
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Dg_RhNupqoaD",
"outputId": "3b089faa-59da-4a2c-98da-8363f2c42054",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
},
"source": [
"jax.jit(make_einsum_grad('ij,jk->ij'))(jnp.zeros((2, 3)), jnp.zeros((3, 4)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[0., 0., 0.],\n",
" [0., 0., 0.]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Lt8vgNuSqxj7",
"outputId": "cc23622e-1946-4b33-eddb-b85c7f7a96a4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
},
"source": [
"make_einsum_grad('ij,jk->ij')(jnp.zeros((2, 3)), jnp.zeros((3, 4)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[0., 0., 0.],\n",
" [0., 0., 0.]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "7BEZ6RkKrOCr",
"outputId": "b23acff3-4f61-4b56-f26e-943e4ae7d6c5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
},
"source": [
"from functools import partial\n",
"import numpy as np\n",
"\n",
"rs = np.random.RandomState(0)\n",
"f = partial(einsum, 'i,ij,j->ij')\n",
"args = (rs.randn(2), rs.randn(2, 3), rs.randn(3,))\n",
"jtu.check_grads(f, args, order=2)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:116: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xAvXZH4XrjjF",
"outputId": "d6774b36-088a-48bc-aecb-61dd23b604e8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"from functools import partial\n",
"import numpy as np\n",
"\n",
"rs = np.random.RandomState(0)\n",
"operands = 'ijk,ij,jk->ij'\n",
"f = partial(einsum, operands)\n",
"args = (rs.randn(2, 3, 4), rs.randn(2, 3), rs.randn(3, 4))\n",
"jtu.check_grads(f, args, order=2)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU.\n",
" warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "PEvOPu8xr5j8",
"outputId": "f1ccef7a-4e8b-4aa2-e193-ec645b1f859b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 155
}
},
"source": [
"jax.make_jaxpr(make_einsum_grad(operands))(*args)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda d ; a b c.\n",
" let e = einsum[ input_strings=['ijk', 'ij', 'jk']\n",
" output_string=ij ] a b c\n",
" f = mul 2.0 e\n",
" g = mul d f\n",
" h = einsum[ input_strings=['ij', 'jk', 'ij']\n",
" output_string=ijk ] b c g\n",
" in (h,) }"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "8_zcLutJujVl",
"outputId": "35ebfc26-58d1-4189-80c8-391d9090c169",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 953
}
},
"source": [
"print(jax.xla_computation(make_einsum_grad(operands, einsum=jnp.einsum))(*args).as_hlo_text())"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"HloModule xla_computation_f__3.44\n",
"\n",
"jit_pe_jvp__einsum__.8 {\n",
" parameter.12 = pred[] parameter(3)\n",
" parameter.11 = f32[3,4]{1,0} parameter(2)\n",
" parameter.9 = f32[2,3,4]{2,1,0} parameter(0)\n",
" transpose.14 = f32[3,2,4]{2,0,1} transpose(parameter.9), dimensions={1,0,2}\n",
" dot.15 = f32[3,2]{1,0} dot(parameter.11, transpose.14), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={2}\n",
" parameter.10 = f32[2,3]{1,0} parameter(1)\n",
" transpose.16 = f32[3,2]{0,1} transpose(parameter.10), dimensions={1,0}\n",
" dot.17 = f32[3,2]{1,0} dot(dot.15, transpose.16), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={}\n",
" transpose.18 = f32[2,3]{0,1} transpose(dot.17), dimensions={1,0}\n",
" constant.13 = pred[] constant(false)\n",
" ROOT tuple.19 = (f32[2,3]{0,1}, pred[], f32[3,4]{1,0}, f32[3,2]{0,1}) tuple(transpose.18, constant.13, parameter.11, transpose.16)\n",
"}\n",
"\n",
"jit_transpose_pe_jvp__einsum___.29 {\n",
" parameter.32 = pred[] parameter(2)\n",
" parameter.33 = pred[] parameter(3)\n",
" constant.35 = pred[] constant(false)\n",
" parameter.34 = f32[2,3]{1,0} parameter(4)\n",
" transpose.36 = f32[3,2]{0,1} transpose(parameter.34), dimensions={1,0}\n",
" parameter.31 = f32[3,2]{0,1} parameter(1)\n",
" dot.37 = f32[3,2]{1,0} dot(transpose.36, parameter.31), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={}\n",
" parameter.30 = f32[3,4]{1,0} parameter(0)\n",
" dot.38 = f32[3,2,4]{2,1,0} dot(dot.37, parameter.30), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}\n",
" transpose.39 = f32[2,3,4]{2,0,1} transpose(dot.38), dimensions={1,0,2}\n",
" ROOT tuple.40 = (f32[2,3,4]{2,0,1}) tuple(transpose.39)\n",
"}\n",
"\n",
"ENTRY xla_computation_f__3.44 {\n",
" constant.7 = pred[] constant(false)\n",
" parameter.4 = f32[2,3,4]{2,1,0} parameter(0)\n",
" parameter.5 = f32[2,3]{1,0} parameter(1)\n",
" parameter.6 = f32[3,4]{1,0} parameter(2)\n",
" constant.1 = pred[] constant(false)\n",
" call.20 = (f32[2,3]{0,1}, pred[], f32[3,4]{1,0}, f32[3,2]{0,1}) call(parameter.4, parameter.5, parameter.6, constant.1), to_apply=jit_pe_jvp__einsum__.8\n",
" get-tuple-element.22 = pred[] get-tuple-element(call.20), index=1\n",
" get-tuple-element.23 = f32[3,4]{1,0} get-tuple-element(call.20), index=2\n",
" get-tuple-element.24 = f32[3,2]{0,1} get-tuple-element(call.20), index=3\n",
" constant.2 = f32[] constant(1)\n",
" broadcast.3 = f32[2,3]{1,0} broadcast(constant.2), dimensions={}\n",
" constant.25 = f32[] constant(2)\n",
" broadcast.26 = f32[2,3]{1,0} broadcast(constant.25), dimensions={}\n",
" get-tuple-element.21 = f32[2,3]{0,1} get-tuple-element(call.20), index=0\n",
" multiply.27 = f32[2,3]{1,0} multiply(broadcast.26, get-tuple-element.21)\n",
" multiply.28 = f32[2,3]{1,0} multiply(broadcast.3, multiply.27)\n",
" call.41 = (f32[2,3,4]{2,0,1}) call(get-tuple-element.23, get-tuple-element.24, constant.1, constant.1, multiply.28), to_apply=jit_transpose_pe_jvp__einsum___.29\n",
" get-tuple-element.42 = f32[2,3,4]{2,0,1} get-tuple-element(call.41), index=0\n",
" ROOT tuple.43 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.42)\n",
"}\n",
"\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OFQ8PslmsclX",
"outputId": "8c4644d0-fc11-4a5c-8ee8-1377c8f39da9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 935
}
},
"source": [
"print(jax.xla_computation(make_einsum_grad(operands))(*args).as_hlo_text())"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"HloModule xla_computation_f__2.43\n",
"\n",
"jit__einsum__260.8 {\n",
" constant.12 = pred[] constant(false)\n",
" parameter.11 = f32[3,4]{1,0} parameter(2)\n",
" parameter.9 = f32[2,3,4]{2,1,0} parameter(0)\n",
" transpose.13 = f32[3,2,4]{2,0,1} transpose(parameter.9), dimensions={1,0,2}\n",
" dot.14 = f32[3,2]{1,0} dot(parameter.11, transpose.13), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={2}\n",
" parameter.10 = f32[2,3]{1,0} parameter(1)\n",
" transpose.15 = f32[3,2]{0,1} transpose(parameter.10), dimensions={1,0}\n",
" dot.16 = f32[3,2]{1,0} dot(dot.14, transpose.15), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={}\n",
" transpose.17 = f32[2,3]{0,1} transpose(dot.16), dimensions={1,0}\n",
" ROOT tuple.18 = (f32[2,3]{0,1}) tuple(transpose.17)\n",
"}\n",
"\n",
"jit__einsum__261.28 {\n",
" constant.32 = pred[] constant(false)\n",
" parameter.31 = f32[2,3]{1,0} parameter(2)\n",
" parameter.29 = f32[2,3]{1,0} parameter(0)\n",
" dot.33 = f32[2,3]{1,0} dot(parameter.31, parameter.29), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={}\n",
" transpose.34 = f32[3,2]{0,1} transpose(dot.33), dimensions={1,0}\n",
" parameter.30 = f32[3,4]{1,0} parameter(1)\n",
" dot.35 = f32[3,2,4]{2,1,0} dot(transpose.34, parameter.30), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}\n",
" transpose.36 = f32[2,3,4]{2,0,1} transpose(dot.35), dimensions={1,0,2}\n",
" ROOT tuple.37 = (f32[2,3,4]{2,0,1}) tuple(transpose.36)\n",
"}\n",
"\n",
"ENTRY xla_computation_f__2.43 {\n",
" constant.6 = pred[] constant(false)\n",
" constant.7 = pred[] constant(false)\n",
" constant.27 = pred[] constant(false)\n",
" parameter.4 = f32[2,3]{1,0} parameter(1)\n",
" parameter.5 = f32[3,4]{1,0} parameter(2)\n",
" constant.1 = f32[] constant(1)\n",
" broadcast.2 = f32[2,3]{1,0} broadcast(constant.1), dimensions={}\n",
" constant.23 = f32[] constant(2)\n",
" broadcast.24 = f32[2,3]{1,0} broadcast(constant.23), dimensions={}\n",
" parameter.3 = f32[2,3,4]{2,1,0} parameter(0)\n",
" call.19 = (f32[2,3]{0,1}) call(parameter.3, parameter.4, parameter.5), to_apply=jit__einsum__260.8\n",
" get-tuple-element.20 = f32[2,3]{0,1} get-tuple-element(call.19), index=0\n",
" tuple.21 = (f32[2,3]{0,1}) tuple(get-tuple-element.20)\n",
" get-tuple-element.22 = f32[2,3]{0,1} get-tuple-element(tuple.21), index=0\n",
" multiply.25 = f32[2,3]{1,0} multiply(broadcast.24, get-tuple-element.22)\n",
" multiply.26 = f32[2,3]{1,0} multiply(broadcast.2, multiply.25)\n",
" call.38 = (f32[2,3,4]{2,0,1}) call(parameter.4, parameter.5, multiply.26), to_apply=jit__einsum__261.28\n",
" get-tuple-element.39 = f32[2,3,4]{2,0,1} get-tuple-element(call.38), index=0\n",
" tuple.40 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.39)\n",
" get-tuple-element.41 = f32[2,3,4]{2,0,1} get-tuple-element(tuple.40), index=0\n",
" ROOT tuple.42 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.41)\n",
"}\n",
"\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DrAru3d7ueL-"
},
"source": [],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment