Skip to content

Instantly share code, notes, and snippets.

@refraction-ray
Created May 30, 2020 01:44
Show Gist options
  • Save refraction-ray/dc22288a0d9e22bb263e59487ad8f5ea to your computer and use it in GitHub Desktop.
Save refraction-ray/dc22288a0d9e22bb263e59487ad8f5ea to your computer and use it in GitHub Desktop.
einsum_symbol.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "einsum_symbol.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNRnzWO/H0U0kPQ3nLBDKsC",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/refraction-ray/dc22288a0d9e22bb263e59487ad8f5ea/einsum_symbol.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "figBFYiINilo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 154
},
"outputId": "7a4cbe58-b951-4c93-f135-1501f070b8bc"
},
"source": [
"!pip install qop\n",
"!pip install opt_einsum"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting qop\n",
" Downloading https://files.pythonhosted.org/packages/d3/56/8182fa684d0dc249ced5c3a219f9ae4b34e898dc8dca504ccba3cc848111/qop-0.0.2-py3-none-any.whl\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from qop) (1.18.4)\n",
"Installing collected packages: qop\n",
"Successfully installed qop-0.0.2\n",
"Requirement already satisfied: opt_einsum in /usr/local/lib/python3.6/dist-packages (3.2.1)\n",
"Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.6/dist-packages (from opt_einsum) (1.18.4)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aoJP-IYTNnmk",
"colab_type": "code",
"colab": {}
},
"source": [
"from qop.symbol import Symbols\n",
"from qop.base import simplify\n",
"from opt_einsum import contract\n",
"import numpy as np"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Xi5utt9yNsF6",
"colab_type": "code",
"colab": {}
},
"source": [
"a,b,c = Symbols(\"abc\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "x89q05U_NwWJ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "0bdcf524-8ad4-4aa7-b92c-1e2505265804"
},
"source": [
"# a, b, c are custom symbolic objects that overload arithmetic operators such as __add__ and __mul__\n",
"((a*a+3*b-c**2)**2).simplify()"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"1.0*a*a*a*a + 6.0*a*a*b + -2.0*a*a*c*c + 9.0*b*b + -6.0*b*c*c + 1.0*c*c*c*c"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dmabx2_6Nw7S",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "db7a05a4-9bea-4602-b329-aacfcc1790fd"
},
"source": [
"# some numpy utilities work perfectly with symbols\n",
"simplify(a*np.ones([2,2]))"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[1.0*a, 1.0*a],\n",
" [1.0*a, 1.0*a]], dtype=object)"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "vvlADgQPNzIy",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "87cf4641-4eab-4cab-c4ca-d88a32a3db7c"
},
"source": [
"simplify(np.array([[a,b],[b,a]])@np.array([[1,2],[c,c]]))"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[1.0*a + 1.0*b*c, 2.0*a + 1.0*b*c],\n",
" [1.0*b + 1.0*a*c, 2.0*b + 1.0*a*c]], dtype=object)"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Iz3rbMF_N4Xr",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
},
"outputId": "9dd11605-93d0-4dc8-e90a-67ac3de16b0b"
},
"source": [
"## but einsum doesn't work as well: it fails when the call is dispatched to backend.einsum rather than tensordot\n",
"## successful case\n",
"simplify(contract(\"ij,jk->ik\", np.array([[a,b],[1.1,2]]), np.array([[c+1,c],[c,0.]])))"
],
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[1.0*a*c + 1.0*a + 1.0*b*c, 1.0*a*c],\n",
" [3.1*c + 1.1*I, 1.1*c]], dtype=object)"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3Q-eP7CzOM3u",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 391
},
"outputId": "4b6bc573-efc2-47a1-8dd7-365586460484"
},
"source": [
"# the above example succeeds because it only calls tensordot, which supports symbolic objects at the Python level\n",
"# if there is no BLAS-able contraction in the path, the call falls back to numpy's c_einsum, which is implemented purely in C and cannot handle Python duck types\n",
"contract(\"ij->\", np.array([[a,a]]))"
],
"execution_count": 11,
"outputs": [
{
"output_type": "error",
"ename": "TypeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-11-8d2ca37c8f33>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# the above example succeed because it only call tensordot which supports symbolic objects in python level\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# and if there is no blas contraction in the middle, the fallback call to c_einsum is invoked which is purely in C and cannot handle python duck types\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mcontract\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"ij->\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/opt_einsum/contract.py\u001b[0m in \u001b[0;36mcontract\u001b[0;34m(*operands, **kwargs)\u001b[0m\n\u001b[1;32m 481\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mContractExpression\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontraction_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconstants_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meinsum_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 483\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_core_contract\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moperands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontraction_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meinsum_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 484\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/opt_einsum/contract.py\u001b[0m in \u001b[0;36m_core_contract\u001b[0;34m(operands, contraction_list, backend, evaluate_constants, **einsum_kwargs)\u001b[0m\n\u001b[1;32m 565\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 566\u001b[0m \u001b[0;31m# Do the contraction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 567\u001b[0;31m \u001b[0mnew_view\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meinsum_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtmp_operands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meinsum_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 568\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0;31m# Append new items and dereference what we can\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/opt_einsum/sharing.py\u001b[0m in \u001b[0;36mcached_einsum\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcached_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcurrently_sharing\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 151\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 152\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0;31m# hash modulo commutativity by computing a canonical ordering and names\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/opt_einsum/contract.py\u001b[0m in \u001b[0;36m_einsum\u001b[0;34m(*operands, **kwargs)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0meinsum_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_to_valid_einsum_chars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meinsum_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 340\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meinsum_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0moperands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 341\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/einsumfunc.py\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(*operands, **kwargs)\u001b[0m\n\u001b[1;32m 1354\u001b[0m \u001b[0;31m# If no optimization, run pure einsum\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1355\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0moptimize_arg\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1356\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mc_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0moperands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1357\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1358\u001b[0m \u001b[0mvalid_einsum_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'out'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'dtype'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'order'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'casting'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: invalid data type for einsum"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "37gbOpwIPM89",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment