@kyscg
Last active July 13, 2023 15:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kyscg/fe6bfe5ddb0e34c918c06242f7979c87 to your computer and use it in GitHub Desktop.
Save kyscg/fe6bfe5ddb0e34c918c06242f7979c87 to your computer and use it in GitHub Desktop.
Einstein Summation Notation Exercises
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "nO4U9fWoa0BI"
},
"source": [
"# Einsum Notation\n",
"\n",
"See discussion on [Twitter](https://twitter.com/kyscg7/status/1679508714141937664), [Mathstodon](https://mathstodon.xyz/@kyscg/110707391200337881).\n",
"\n",
"---\n",
"\n",
"This notebook has exercises to understand how Einstein Summation Notation works for Deep Learning algorithms. I initially tested my understanding by taking an operation (say matrix-vector multiplication), writing the corresponding einsum notation, and checking it against the NumPy implementation using `.all()`. I then moved on to PyTorch and implemented an MLP class and a Multi-head Self-Attention block using einsum notation.\n",
"\n",
"Sources: [A basic introduction to NumPy's einsum | ajcr.net](https://ajcr.net/Basic-guide-to-einsum/) by Alex Riley, and [Einsum is All you Need - Einstein Summation in Deep Learning | rockt.github.io](https://rockt.github.io/2018/04/30/einsum) by Tim Rocktäschel. The following rules from Alex Riley's blog post are a nice introduction to the idea.\n",
"\n",
"- Repeating letters between input arrays means that values along those axes will be multiplied together. The products make up the values for the output array.\n",
"- Omitting a letter from the output means that values along that axis will be summed.\n",
"- We can return the unsummed axes in any order we like.\n",
"\n",
"I felt, however, that a better way to understand einsum notation is to decompose an expression into loops. Say we have the following matrix multiplication:\n",
"\n",
"```python\n",
"C = np.einsum(\"ij,jk -> ik\", A, B)\n",
"```\n",
"\n",
"This is essentially the same as looping over the output indices `i` and `k` and starting a fresh sum for each output element `C[i, k]`, then looping over the index that does not appear in the output, `j`, and accumulating the products over it. The loop looks like:\n",
"\n",
"```python\n",
"C = np.zeros((A.shape[0], B.shape[1]), dtype=np.result_type(A, B))\n",
"for i in range(A.shape[0]):          # output index\n",
"    for k in range(B.shape[1]):      # output index\n",
"        C[i, k] = 0                  # fresh sum for this output element\n",
"        for j in range(A.shape[1]):  # summed index, absent from the output\n",
"            C[i, k] += A[i, j] * B[j, k]\n",
"```\n",
"\n",
"One example that bothered me was an einsum expression that returns the leading diagonal of a matrix, so I'll write it out here:\n",
"\n",
"```python\n",
"diag = np.einsum(\"ii -> i\", A)\n",
"```\n",
"\n",
"Here there is only one index, `i`, and it appears in the output, so we loop over it once and add in only the elements on the leading diagonal, `A[i, i]`.\n",
"\n",
"```python\n",
"diag = np.zeros(A.shape[0], dtype=A.dtype)\n",
"for i in range(A.shape[0]):\n",
"    diag[i] = 0\n",
"    diag[i] += A[i, i]\n",
"```\n",
"\n",
"This makes it easy to work out more complex expressions such as batched matrix multiplication, inner products, and tensor contractions. Note that if the right-hand side of the expression is empty, the sum is initialized once, outside all the loops, so every product is accumulated into a single scalar; for a single operand this gives the sum of the complete tensor.\n",
"\n",
"### Log\n",
"\n",
"- Further reading for when it's not so late at night: https://en.wikipedia.org/wiki/Einstein_notation, https://www.continuummechanics.org/tensornotationbasic.html\n",
"- I don't understand how `A[:, None]` works, nor what it is doing.\n",
"  - Why is `A[:, None].shape` giving me `(3, 1, 3)` when I expected `(3, 3, 1)`? `A` is a 2D matrix of shape `(3, 3)`.\n",
"  - More confusion: `A[:, None, None].shape` is `(3, 1, 1, 3)` and `A[:, :, None, None].shape` is `(3, 3, 1, 1)`.\n",
"- Amazing what a good night's sleep will do. I instantly realized that `A[:, None]` adds a new length-1 axis in second position because we've placed the `None` second. It is equivalent to `A[:, None, :]`.\n",
"  - If we want the shape to be `(3, 3, 1)`, we put the `None` in third place, like `A[:, :, None]`.\n",
"- https://stackoverflow.com/questions/tagged/numpy-einsum has many real-world examples.\n",
"- [Tensor Contraction](https://en.wikipedia.org/wiki/Tensor_contraction) is the general form of Batch Matrix Multiplication.\n",
"- `torch.nn.Bilinear` doesn't give the same result as the einsum expression because the layer initializes its own random weights, which differ from the random tensor `l` passed to einsum. I settled for checking whether the shapes match."
]
},
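{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop recipe from the introduction can be written once for any explicit einsum expression. The sketch below is deliberately slow and only meant for checking intuition: `einsum_loops` is a hypothetical helper, not a NumPy function, and it only handles explicit specs containing `->` (no `...` broadcasting). It loops over every combination of index values, multiplies the indexed elements of each operand, and accumulates the product into the output position, so labels missing from the output are summed and an empty output collects everything into one scalar.\n",
"\n",
"```python\n",
"import itertools\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def einsum_loops(spec, *operands):\n",
"    # Hypothetical teaching helper: naive einsum for explicit specs such as 'ij,jk->ik'\n",
"    inputs, output = spec.replace(' ', '').split('->')\n",
"    input_labels = inputs.split(',')\n",
"\n",
"    # Record the size of every label and check that repeated labels agree\n",
"    sizes = {}\n",
"    for labels, array in zip(input_labels, operands):\n",
"        for label, size in zip(labels, array.shape):\n",
"            if sizes.setdefault(label, size) != size:\n",
"                raise ValueError(f'size mismatch for label {label}')\n",
"\n",
"    all_labels = sorted(sizes)\n",
"    result = np.zeros([sizes[label] for label in output],\n",
"                      dtype=np.result_type(*operands))\n",
"\n",
"    # One loop per label: output labels select the cell, all other labels are summed\n",
"    for values in itertools.product(*(range(sizes[label]) for label in all_labels)):\n",
"        index = dict(zip(all_labels, values))\n",
"        term = 1\n",
"        for labels, array in zip(input_labels, operands):\n",
"            term = term * array[tuple(index[label] for label in labels)]\n",
"        result[tuple(index[label] for label in output)] += term\n",
"    return result\n",
"\n",
"\n",
"A = np.arange(6).reshape(2, 3)\n",
"B = np.arange(12).reshape(3, 4)\n",
"M = np.arange(9).reshape(3, 3)\n",
"\n",
"print(np.array_equal(einsum_loops('ij,jk->ik', A, B), np.einsum('ij,jk->ik', A, B)))  # True\n",
"print(np.array_equal(einsum_loops('ii->i', M), np.diag(M)))  # True\n",
"print(einsum_loops('ij->', M))  # 36\n",
"```\n",
"\n",
"Passing a spec such as `'bij,bjk->bik'` with 3D arrays walks through batched matrix multiplication in exactly the same way."
]
},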
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oN9B3B-wayg8"
},
"outputs": [],
"source": [
"import numpy as np"
]
},
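{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The `None` indexing questions from the log can be settled by printing shapes directly. This is a small sketch on hypothetical `3 x 3` arrays: `None` is the same object as `np.newaxis`, and it inserts a length-1 axis at the position where it appears in the index. The last line shows why `A[:, None] * B` matches `'ij,kj->ikj'` in the 2D table further down.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"A = np.arange(9).reshape(3, 3)\n",
"B = np.arange(9, 18).reshape(3, 3)\n",
"\n",
"print(np.newaxis is None)         # True\n",
"print(A[:, None].shape)           # (3, 1, 3), same as A[:, None, :]\n",
"print(A[:, :, None].shape)        # (3, 3, 1)\n",
"print(A[:, None, None].shape)     # (3, 1, 1, 3)\n",
"print(A[:, :, None, None].shape)  # (3, 3, 1, 1)\n",
"\n",
"# Broadcasting (3, 1, 3) against (3, 3) pairs A[i, j] with B[k, j]\n",
"print(np.array_equal(A[:, None] * B, np.einsum('ij,kj->ikj', A, B)))  # True\n",
"```"
]
},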
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "Iq4nMIufkuMb"
},
"source": [
"### Multiply `A` and `B` and sum along the rows"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K3WbwWDwd05z"
},
"outputs": [],
"source": [
"A = np.array([0, 1, 2])\n",
"B = np.array([[ 0,  1,  2,  3],\n",
"              [ 4,  5,  6,  7],\n",
"              [ 8,  9, 10, 11]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dADrYpfIeOUQ",
"outputId": "27627d5c-cfeb-41bf-d343-6309dc05a0c4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape of A: (3,)\n",
"Shape of B: (3, 4)\n"
]
}
],
"source": [
"print(\"Shape of A: \", A.shape)\n",
"print(\"Shape of B: \", B.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vxMV3nrseQ58",
"outputId": "8456e069-5389-4cd7-cca3-6c18fe2d6332"
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(np.einsum(\"i,ij -> i\", A, B) == (A[:, np.newaxis] * B).sum(axis=1)).all()"
]
},
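{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `j` appears only in `B`, the sum over `j` can be pulled inside: `sum_j A[i] * B[i, j] == A[i] * sum_j B[i, j]`. A small sketch re-using the same values as above shows that this einsum is just `A` scaled by the row sums of `B`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"A = np.array([0, 1, 2])\n",
"B = np.array([[ 0,  1,  2,  3],\n",
"              [ 4,  5,  6,  7],\n",
"              [ 8,  9, 10, 11]])\n",
"\n",
"# j is summed away and only appears in B, so it can be summed first\n",
"print(np.einsum('i,ij->i', A, B))  # [ 0 22 76]\n",
"print(A * B.sum(axis=1))           # [ 0 22 76]\n",
"```"
]
},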
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "CXFLZLRulXiT"
},
"source": [
"### Matrix Multiplication\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FHvs2DlflWwC"
},
"outputs": [],
"source": [
"A = np.array([[1, 1, 1],\n",
"              [2, 2, 2],\n",
"              [5, 5, 5]])\n",
"B = np.array([[0, 1, 0],\n",
"              [1, 1, 0],\n",
"              [1, 1, 1]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eaN7ra_ZkE2B",
"outputId": "88eed5ec-68a6-4834-b674-2742a94246ef"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape of A: (3, 3)\n",
"Shape of B: (3, 3)\n"
]
}
],
"source": [
"print(\"Shape of A: \", A.shape)\n",
"print(\"Shape of B: \", B.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bGJ0PuIdlsXB",
"outputId": "6b61746b-f21f-4b75-b7ea-1611794a09be"
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(np.einsum(\"ij,jk -> ik\", A, B) == np.matmul(A, B)).all()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "YMaLOXhbbcYT"
},
"source": [
"### Operations on 1D arrays\n",
"\n",
"| Call signature | NumPy equivalent | Description |\n",
"|:---|:---:|---:|\n",
"| `('i', A)` | `A` | Returns a view of A |\n",
"| `('i->', A)` | `sum(A)` | Sums the values of A |\n",
"| `('i,i->i', A, B)` | `A * B` | Element-wise multiplication of A and B |\n",
"| `('i,i', A, B)` | `inner(A, B)` | Inner product of A and B |\n",
"| `('i,j->ij', A, B)` | `outer(A, B)` | Outer product of A and B |\n"
]
},
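{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The table mixes explicit calls (with `->`) and implicit ones such as `('i,i', A, B)`. In implicit mode NumPy builds the output from the labels that appear exactly once, sorted alphabetically, and sums over repeated labels. A short sketch with the same `A` and `B` as the next cell:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"A = np.array([1, 2, 3])\n",
"B = np.array([4, 5, 6])\n",
"\n",
"# 'i' is repeated, so it is summed and the output is a scalar (same as 'i,i->')\n",
"print(np.einsum('i,i', A, B))  # 32\n",
"\n",
"# 'i' and 'j' each appear once, so the output is 'ij' (same as 'i,j->ij')\n",
"print(np.array_equal(np.einsum('i,j', A, B), np.outer(A, B)))  # True\n",
"\n",
"# Output labels are sorted, so 'ba' behaves like 'ba->ab', a transpose\n",
"C = np.arange(6).reshape(2, 3)\n",
"print(np.einsum('ba', C).shape)  # (3, 2)\n",
"```"
]
},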
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5jAM9Lfznkso"
},
"outputs": [],
"source": [
"A = np.array([1, 2, 3])\n",
"B = np.array([4, 5, 6])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dInLim65cOhF",
"outputId": "55b18506-3027-4d70-e485-0d342f7bdc21"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape of A: (3,)\n",
"Shape of B: (3,)\n"
]
}
],
"source": [
"print(\"Shape of A: \", A.shape)\n",
"print(\"Shape of B: \", B.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oDOUCLWlcTM2",
"outputId": "0092729d-486e-40f0-adc2-85bdf0f2d700"
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Jupyter only displays the value of the last expression, so the other checks are asserts\n",
"assert (np.einsum(\"i -> i\", A) == A).all()\n",
"assert (np.einsum(\"i -> \", A) == np.sum(A)).all()\n",
"\n",
"assert (np.einsum(\"i,i -> i\", A, B) == A * B).all()\n",
"assert (np.einsum(\"i,i -> \", A, B) == np.matmul(A, B)).all()  # matmul of two 1D arrays is their inner product\n",
"assert (np.einsum(\"i,i -> \", A, B) == np.inner(A, B)).all()\n",
"(np.einsum(\"i,j -> ij\", A, B) == np.outer(A, B)).all()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "DOO-kfAXf5gA"
},
"source": [
"### Operations on 2D arrays\n",
"\n",
"| Call signature | NumPy equivalent | Description |\n",
"|:---|:---:|---:|\n",
"| `('ij', A)` | `A` | Returns a view of A |\n",
"| `('ji', A)` | `A.T` | View transpose of A |\n",
"| `('ii->i', A)` | `diag(A)` | View main diagonal of A |\n",
"| `('ii', A)` | `trace(A)` | Sums main diagonal of A |\n",
"| `('ij->', A)` | `sum(A)` | Sums the values of A |\n",
"| `('ij->j', A)` | `sum(A, axis=0)` | Sum down the columns of A (across rows) |\n",
"| `('ij->i', A)` | `sum(A, axis=1)` | Sum horizontally along the rows of A |\n",
"| `('ij,ij->ij', A, B)` | `A * B` | Element-wise multiplication of A and B |\n",
"| `('ij,ji->ij', A, B)` | `A * B.T` | Element-wise multiplication of A and B.T |\n",
"| `('ij,jk', A, B)` | `dot(A, B)` | Matrix multiplication of A and B |\n",
"| `('ij,kj->ik', A, B)` | `inner(A, B)` | Inner product of A and B |\n",
"| `('ij,kj->ikj', A, B)` | `A[:, None] * B` | Each row of A multiplied by B |\n",
"| `('ij,kl->ijkl', A, B)` | `A[:, :, None, None] * B` | Each value of A multiplied by B |\n"
]
},
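{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The table describes several one-operand calls as views. According to the NumPy documentation, einsum returns a writeable view (not a copy) when the expression only relabels axes or takes a diagonal without summing, so writing through it changes the original array. A small sketch to check this, assuming a reasonably recent NumPy:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"A = np.arange(9).reshape(3, 3)\n",
"\n",
"diag = np.einsum('ii->i', A)      # main diagonal, returned as a view\n",
"print(np.shares_memory(diag, A))  # True\n",
"\n",
"diag[0] = 100                     # writing through the view modifies A\n",
"print(A[0, 0])                    # 100\n",
"\n",
"# Summing over the diagonal produces a new value instead of a view\n",
"print(np.einsum('ii', A))         # 112, i.e. 100 + 4 + 8\n",
"```"
]
},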
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S2ibzmXtcaYL"
},
"outputs": [],
"source": [
"A = np.array([[1, 2, 3],\n",
"              [4, 5, 6],\n",
"              [7, 8, 9]])\n",
"B = np.array([[10, 11, 12],\n",
"              [13, 14, 15],\n",
"              [16, 17, 18]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4KADjITcfD11",
"outputId": "744ef885-2c28-41b0-f234-954ab409e02e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape of A: (3, 3)\n",
"Shape of B: (3, 3)\n"
]
}
],
"source": [
"print(\"Shape of A: \", A.shape)\n",
"print(\"Shape of B: \", B.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TGpUUtq8lkMH",
"outputId": "2482195e-bae5-430b-c163-4038f8fe399f"
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Jupyter only displays the value of the last expression, so the other checks are asserts\n",
"assert (np.einsum(\"ij -> ij\", A) == A).all()\n",
"assert (np.einsum(\"ji -> ij\", A) == A.T).all()\n",
"assert (np.einsum(\"ii -> i\", A) == np.diag(A)).all()\n",
"assert (np.einsum(\"ii -> \", A) == np.trace(A)).all()\n",
"assert (np.einsum(\"ij -> \", A) == np.sum(A)).all()\n",
"assert (np.einsum(\"ij -> j\", A) == np.sum(A, axis=0)).all()\n",
"assert (np.einsum(\"ij -> i\", A) == np.sum(A, axis=1)).all()\n",
"\n",
"assert (np.einsum(\"ij,ij -> ij\", A, B) == A * B).all()\n",
"assert (np.einsum(\"ij,ji -> ij\", A, B) == A * B.T).all()\n",
"assert (np.einsum(\"ij,jk -> ik\", A, B) == np.matmul(A, B)).all()\n",
"assert (np.einsum(\"ij,jk -> \", A, B) == np.sum(np.matmul(A, B))).all()\n",
"assert (np.einsum(\"ij,kj -> ik\", A, B) == np.matmul(A, B.T)).all()\n",
"assert (np.einsum(\"ij,kj -> ikj\", A, B) == A[:, None, :] * B).all()\n",
"(np.einsum(\"ij,kl -> ijkl\", A, B) == A[:, :, None, None] * B).all()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "jpRWXYvxqkB_"
},
"source": [
"### PyTorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Oym4zTo2qk8P"
},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "stIRExCAqnOj",
"outputId": "c43ac9fd-8913-485e-8658-13f36f72d1a1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape of a: torch.Size([2, 3])\n",
"Shape of b: torch.Size([3])\n",
"Shape of c: torch.Size([3, 5])\n",
"Shape of d: torch.Size([3])\n",
"Shape of e: torch.Size([2, 3])\n",
"Shape of f: torch.Size([4])\n",
"Shape of g: torch.Size([3, 2, 5])\n",
"Shape of h: torch.Size([3, 5, 3])\n",
"Shape of i: torch.Size([2, 3, 5, 7])\n",
"Shape of j: torch.Size([11, 13, 3, 17, 5])\n",
"Shape of k: torch.Size([2, 3])\n",
"Shape of l: torch.Size([5, 3, 7])\n",
"Shape of m: torch.Size([2, 7])\n"
]
}
],
"source": [
"a = torch.arange(6).reshape(2, 3)\n",
"b = torch.arange(3)\n",
"c = torch.arange(15).reshape(3, 5)\n",
"d = torch.arange(3, 6)\n",
"e = torch.arange(6, 12).reshape(2, 3)\n",
"f = torch.arange(3, 7)\n",
"g = torch.randn(3, 2, 5)\n",
"h = torch.randn(3, 5, 3)\n",
"i = torch.randn(2, 3, 5, 7)\n",
"j = torch.randn(11, 13, 3, 17, 5)\n",
"k = torch.randn(2, 3)\n",
"l = torch.randn(5, 3, 7)\n",
"m = torch.randn(2, 7)\n",
"\n",
"print(\"Shape of a: \", a.shape)\n",
"print(\"Shape of b: \", b.shape)\n",
"print(\"Shape of c: \", c.shape)\n",
"print(\"Shape of d: \", d.shape)\n",
"print(\"Shape of e: \", e.shape)\n",
"print(\"Shape of f: \", f.shape)\n",
"print(\"Shape of g: \", g.shape)\n",
"print(\"Shape of h: \", h.shape)\n",
"print(\"Shape of i: \", i.shape)\n",
"print(\"Shape of j: \", j.shape)\n",
"print(\"Shape of k: \", k.shape)\n",
"print(\"Shape of l: \", l.shape)\n",
"print(\"Shape of m: \", m.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bX_MrZb2q4AZ",
"outputId": "a100d48e-bb36-49e7-c376-8cadf22645f9"
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Jupyter only displays the value of the last expression, so the other checks are asserts\n",
"assert (torch.einsum(\"ij -> ji\", a) == a.T).all()\n",
"assert (torch.einsum(\"ij -> \", a) == torch.sum(a)).all()\n",
"assert (torch.einsum(\"ij -> j\", a) == torch.sum(a, dim=0)).all()\n",
"assert (torch.einsum(\"ij -> i\", a) == torch.sum(a, dim=1)).all()\n",
"\n",
"assert (torch.einsum(\"ij,j -> i\", a, b) == torch.matmul(a, b)).all()\n",
"assert (torch.einsum(\"ij,jk -> ik\", a, c) == torch.matmul(a, c)).all()\n",
"assert (torch.einsum(\"i,i -> \", b, d) == torch.dot(b, d)).all()\n",
"assert (torch.einsum(\"ij,ij -> \", a, e) == torch.sum(a * e)).all()\n",
"assert (torch.einsum(\"ij,ij -> ij\", a, e) == a * e).all()\n",
"assert (torch.einsum(\"i,j -> ij\", b, f) == torch.outer(b, f)).all()\n",
"\n",
"# Batched Matrix Multiplication (float inputs, so compare with a tolerance)\n",
"assert torch.allclose(torch.einsum(\"bmn,bno -> bmo\", g, h), torch.bmm(g, h), atol=1e-5)\n",
"\n",
"# Tensor Contraction over the shared labels j and k: dims (1, 2) of tensor i pair with dims (2, 4) of tensor j\n",
"assert torch.allclose(torch.einsum(\"ijkl,mnjok -> ilmno\", i, j),\n",
"                      torch.tensordot(i, j, dims=([1, 2], [2, 4])), atol=1e-5)\n",
"\n",
"# Bilinear Transformation: the layer has its own random weights, so only the shapes are compared\n",
"(torch.einsum(\"ij,kjl,il -> ik\", k, l, m).shape == torch.nn.Bilinear(k.shape[1], m.shape[1], 5, bias=False)(k, m).shape)"
]
},
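{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The log notes that `torch.nn.Bilinear` only matches the einsum expression in shape, because the layer draws its own random weights. A sketch of a value-level check: pass the layer's own `weight` (stored with shape `(out_features, in1_features, in2_features)`) to einsum instead of a separate random tensor. The seed and the fresh `k`, `m` tensors are assumptions of this sketch.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"k = torch.randn(2, 3)\n",
"m = torch.randn(2, 7)\n",
"\n",
"bilinear = torch.nn.Bilinear(3, 7, 5, bias=False)\n",
"W = bilinear.weight  # shape (5, 3, 7), the same layout as l above\n",
"\n",
"with torch.no_grad():\n",
"    via_einsum = torch.einsum('ij,kjl,il->ik', k, W, m)\n",
"    via_layer = bilinear(k, m)\n",
"\n",
"print(via_einsum.shape)  # torch.Size([2, 5])\n",
"print(torch.allclose(via_einsum, via_layer, atol=1e-5))\n",
"```"
]
},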
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "eu98UuPy27pW"
},
"source": [
"### Multi-layer Perceptron\n",
"\n",
"An MLP class that uses `einsum` instead of `matmul`, that can be used to classify MNIST images. Note the intuitive notation that helps us label axes appropriately. Backpropagation works in the usual manner after loading data and choosing an optimizer and loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bRq88eanMWwP"
},
"outputs": [],
"source": [
"batch_size = 32\n",
"input_size = 784\n",
"hidden_size = 500\n",
"output_size = 10\n",
"\n",
"class MLP(torch.nn.Module):\n",
"    def __init__(self):\n",
"\n",
"        super().__init__()\n",
"\n",
"        self.W = torch.nn.Parameter(torch.randn(input_size, hidden_size))\n",
"        self.b = torch.nn.Parameter(torch.randn(hidden_size))\n",
"        self.V = torch.nn.Parameter(torch.randn(hidden_size, output_size))\n",
"        self.c = torch.nn.Parameter(torch.randn(output_size))\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"\n",
"        h = sigmoid(xW + b)\n",
"        y = softmax(hV + c)\n",
"        \"\"\"\n",
"\n",
"        assert x.shape == (batch_size, input_size)  # x: (batch_size, input_size)\n",
"\n",
"        # einsum only multiplies and sums over labels, so the biases are added\n",
"        # outside the einsum call (putting them inside would multiply by them)\n",
"        h = torch.sigmoid(torch.einsum(\"bi,ih -> bh\", x, self.W) + self.b)\n",
"        y = torch.softmax(torch.einsum(\"bh,ho -> bo\", h, self.V) + self.c,\n",
"                          dim=-1)\n",
"        return y"
]
},
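{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The markdown above says backpropagation works in the usual way. A minimal sketch of a few training steps, assuming random tensors as a hypothetical stand-in for an MNIST batch and SGD with cross-entropy (neither is specified in this notebook). It keeps the same einsum layout but returns logits, since `torch.nn.functional.cross_entropy` applies log-softmax itself, and it adds the biases outside the einsum calls.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Hypothetical stand-in for one MNIST batch: 32 flattened 28x28 images and labels\n",
"x = torch.randn(32, 784)\n",
"targets = torch.randint(0, 10, (32,))\n",
"\n",
"W = torch.nn.Parameter(0.01 * torch.randn(784, 500))\n",
"b = torch.nn.Parameter(torch.zeros(500))\n",
"V = torch.nn.Parameter(0.01 * torch.randn(500, 10))\n",
"c = torch.nn.Parameter(torch.zeros(10))\n",
"optimizer = torch.optim.SGD([W, b, V, c], lr=0.01)\n",
"\n",
"\n",
"def forward(x):\n",
"    h = torch.sigmoid(torch.einsum('bi,ih->bh', x, W) + b)\n",
"    return torch.einsum('bh,ho->bo', h, V) + c  # logits, shape (batch, 10)\n",
"\n",
"\n",
"losses = []\n",
"for step in range(5):\n",
"    loss = F.cross_entropy(forward(x), targets)\n",
"    optimizer.zero_grad()\n",
"    loss.backward()  # gradients flow back through both einsum calls\n",
"    optimizer.step()\n",
"    losses.append(loss.item())\n",
"\n",
"print(losses)\n",
"```"
]
},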
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "Gwix8iVFAT2h"
},
"source": [
"### Self-Attention\n",
"\n",
"A class that applies multi-head self-attention to an input sequence. Note how einsum lets us skip the calls to `transpose` and `contiguous`: the only reshapes left are the ones that split the embedding into heads and merge the heads back. Work out the multiplications yourself to see how the notation helps in writing the code."
]
},
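{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The two einsum calls at the heart of the class below can be checked in isolation against a version that uses explicit `transpose` calls. In this sketch the random `queries`, `keys`, and `values` tensors (already split into heads) and their sizes are assumptions, and the scores are scaled by the per-head size `s`, as in standard scaled dot-product attention.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"b, t, h, s = 2, 4, 8, 16\n",
"queries = torch.randn(b, t, h, s)\n",
"keys = torch.randn(b, t, h, s)\n",
"values = torch.randn(b, t, h, s)\n",
"\n",
"# einsum: the labels say which axes are multiplied and summed, no transposes needed\n",
"dot = torch.einsum('bths,bdhs->bhtd', queries, keys) / math.sqrt(s)\n",
"out = torch.einsum('bhtd,bdhs->bths', dot.softmax(dim=-1), values)\n",
"\n",
"# Reference: move the head axis forward, use matmul, then move it back\n",
"q, k, v = (x.transpose(1, 2) for x in (queries, keys, values))  # (b, h, t, s)\n",
"ref = (q @ k.transpose(-2, -1) / math.sqrt(s)).softmax(dim=-1) @ v\n",
"ref = ref.transpose(1, 2)  # back to (b, t, h, s)\n",
"\n",
"print(out.shape)  # torch.Size([2, 4, 8, 16])\n",
"print(torch.allclose(out, ref, atol=1e-5))\n",
"```"
]
},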
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gme92KisAVFr"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class SelfAttention(nn.Module):\n",
"    def __init__(self, k, heads=8):\n",
"        \"\"\"\n",
"        Initialize the SelfAttention layer.\n",
"\n",
"        Args:\n",
"            k    : The embedding dimension.\n",
"            heads: The number of attention heads.\n",
"\n",
"        Raises:\n",
"            ValueError: If k is not divisible by heads.\n",
"        \"\"\"\n",
"\n",
"        super().__init__()\n",
"\n",
"        if k % heads != 0:\n",
"            raise ValueError(\"k must be divisible by the number of heads\")\n",
"\n",
"        self.k = k\n",
"        self.heads = heads\n",
"\n",
"        self.tokeys = nn.Linear(k, k, bias=False)\n",
"        self.toqueries = nn.Linear(k, k, bias=False)\n",
"        self.tovalues = nn.Linear(k, k, bias=False)\n",
"        self.unifyheads = nn.Linear(heads * (k // heads), k)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"\n",
"        Applies multi-head attention to the input sequence.\n",
"\n",
"        Args:\n",
"            x (torch.Tensor): The input sequence, of shape (b, t, k).\n",
"\n",
"        Returns:\n",
"            torch.Tensor: The output sequence, of shape (b, t, k).\n",
"\n",
"        Raises:\n",
"            ValueError: If the input embedding dimension does not match the layer embedding dimension.\n",
"        \"\"\"\n",
"\n",
"        b, t, k = x.size()\n",
"        h = self.heads\n",
"        if k != self.k:\n",
"            raise ValueError(f'Input embedding dimension={k} does not match layer embedding dimension={self.k}')\n",
"\n",
"        # (b, t, k)\n",
"        queries = self.toqueries(x)\n",
"        keys = self.tokeys(x)\n",
"        values = self.tovalues(x)\n",
"\n",
"        # per-head dimension\n",
"        s = k // h\n",
"\n",
"        # (b, t, k) --> (b, t, h, s): split the embedding into h heads\n",
"        queries = queries.reshape(b, t, h, s)\n",
"        keys    = keys.reshape(b, t, h, s)\n",
"        values  = values.reshape(b, t, h, s)\n",
"\n",
"        # (b, h, t, t): every query position t against every key position d,\n",
"        # scaled by the per-head dimension as in scaled dot-product attention\n",
"        dot = torch.einsum(\"bths,bdhs -> bhtd\", queries, keys) / math.sqrt(s)\n",
"        dot = F.softmax(dot, dim=3)  # normalize over key positions\n",
"\n",
"        # (b, t, h, s): weighted sum of the values over key positions l\n",
"        out = torch.einsum(\"bhtl,blhs -> bths\", dot, values)\n",
"\n",
"        # (b, t, k): concatenate the heads and mix them\n",
"        out = out.reshape(b, t, h * s)\n",
"        return self.unifyheads(out)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}