# scaomath/einsum.md

Created July 6, 2021 21:23
Einstein sum reference

# Einstein sum refs

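Copied from `torch`'s linear algebra test file. The operand names (`x`, `y`, `A` through `I`) are tensors created in that test; they are not defined in the snippets here.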
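Since the snippets use the operand names without defining them, here is a minimal setup sketch. The shapes are an assumption, chosen only so that every expression in this reference is well-formed; `make`, `device`, and `dtype` are illustrative names, not something the source shows.

```python
import torch

# Assumed setup: the snippets in this reference never show these definitions.
device = torch.device("cpu")
dtype = torch.float64


def make(*shape):
    """Hypothetical helper: random tensor on the chosen device and dtype."""
    return torch.randn(*shape, device=device, dtype=dtype)


x = make(5)            # vectors
y = make(7)
A = make(3, 5)         # matrices
B = make(2, 5)
E = make(7, 9)
H = make(4, 4)         # square, for trace/diagonal
C = make(2, 3, 5)      # higher-order tensors
D = make(2, 5, 7)
F = make(2, 3, 3, 5)
G = make(5, 4, 6)
I = make(2, 3, 2)
```

With these shapes, each one-liner below can be pasted into the same session and run as is.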

## Vector operations

```
torch.einsum('i->', x)                     # sum
torch.einsum('i,i->', x, x)                # dot
torch.einsum('i,i->i', x, x)               # vector element-wise mul
torch.einsum('i,j->ij', x, y)              # outer
```
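As a sanity check, the vector expressions above can be compared with their dedicated `torch` counterparts. This is a sketch under assumed shapes (`x` of length 5, `y` of length 7); the dictionary name `vector_checks` is illustrative.

```python
import torch

x = torch.randn(5, dtype=torch.float64)   # assumed shape
y = torch.randn(7, dtype=torch.float64)   # assumed shape

total = torch.einsum("i->", x)
dot = torch.einsum("i,i->", x, x)
hadamard = torch.einsum("i,i->i", x, x)
outer = torch.einsum("i,j->ij", x, y)

# Each einsum result should agree with the conventional call.
vector_checks = {
    "sum": torch.allclose(total, x.sum()),
    "dot": torch.allclose(dot, torch.dot(x, x)),
    "element-wise mul": torch.allclose(hadamard, x * x),
    "outer": torch.allclose(outer, torch.outer(x, y)),
}
print(vector_checks)
```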

## Matrix operations

```
torch.einsum("ij->ji", A)                  # transpose
torch.einsum("ij->j", A)                   # sum over rows i (column sums)
torch.einsum("ij->i", A)                   # sum over columns j (row sums)
torch.einsum("ij,ij->ij", A, A)            # matrix element-wise mul
torch.einsum("ij,j->i", A, x)              # matrix vector multiplication
torch.einsum("ij,kj->ik", A, B)            # matmul with B transposed (A @ B.T)
torch.einsum("ij,ab->ijab", A, E)          # matrix outer product
```
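The matrix expressions can be checked the same way. Note which axis each reduction removes, and that `"ij,kj->ik"` contracts the second index of both operands, so it matches `A @ B.T` rather than `A @ B`. Shapes and the name `matrix_checks` are assumptions for illustration.

```python
import torch

dt = torch.float64
A = torch.randn(3, 5, dtype=dt)   # assumed shapes
B = torch.randn(2, 5, dtype=dt)
E = torch.randn(7, 9, dtype=dt)
x = torch.randn(5, dtype=dt)

matrix_checks = {
    "transpose": torch.allclose(torch.einsum("ij->ji", A), A.t()),
    # "ij->j" drops index i, i.e. it sums down each column.
    "ij->j": torch.allclose(torch.einsum("ij->j", A), A.sum(dim=0)),
    # "ij->i" drops index j, i.e. it sums along each row.
    "ij->i": torch.allclose(torch.einsum("ij->i", A), A.sum(dim=1)),
    "element-wise mul": torch.allclose(torch.einsum("ij,ij->ij", A, A), A * A),
    "matvec": torch.allclose(torch.einsum("ij,j->i", A, x), A @ x),
    "A @ B.T": torch.allclose(torch.einsum("ij,kj->ik", A, B), A @ B.t()),
}
outer = torch.einsum("ij,ab->ijab", A, E)  # no shared index, so nothing is summed
print(matrix_checks, tuple(outer.shape))
```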

## Tensor operations

```
torch.einsum("Aij,Ajk->Aik", C, D)         # batch matmul
torch.einsum("ijk,jk->i", C, A)            # tensor matrix contraction
torch.einsum("aij,jk->aik", D, E)          # tensor matrix contraction
torch.einsum("abCd,dFg->abCFg", F, G)      # tensor tensor contraction
torch.einsum("ijk,jk->ik", C, A)           # tensor matrix contraction with double indices
torch.einsum("ijk,jk->ij", C, A)           # tensor matrix contraction with double indices
torch.einsum("ijk,ik->j", C, B)            # non contiguous
torch.einsum("ijk,ik->jk", C, B)           # non contiguous with double indices
```
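To make the contractions above concrete, the sketch below rewrites several of them with `torch.bmm`, `@`, `torch.tensordot`, and broadcasting followed by `sum`. The shapes are assumed, and `tensor_checks` is an illustrative name.

```python
import torch

dt = torch.float64
A = torch.randn(3, 5, dtype=dt)        # assumed shapes
B = torch.randn(2, 5, dtype=dt)
C = torch.randn(2, 3, 5, dtype=dt)
D = torch.randn(2, 5, 7, dtype=dt)
E = torch.randn(7, 9, dtype=dt)
F = torch.randn(2, 3, 3, 5, dtype=dt)
G = torch.randn(5, 4, 6, dtype=dt)

tensor_checks = {
    "batch matmul": torch.allclose(
        torch.einsum("Aij,Ajk->Aik", C, D), torch.bmm(C, D)),
    "ijk,jk->i": torch.allclose(
        torch.einsum("ijk,jk->i", C, A), (C * A).sum(dim=(1, 2))),
    "aij,jk->aik": torch.allclose(
        torch.einsum("aij,jk->aik", D, E), D @ E),
    "abCd,dFg->abCFg": torch.allclose(
        torch.einsum("abCd,dFg->abCFg", F, G),
        torch.tensordot(F, G, dims=([3], [0]))),
    # k appears in both inputs and the output: only j is summed.
    "ijk,jk->ik": torch.allclose(
        torch.einsum("ijk,jk->ik", C, A), (C * A).sum(dim=1)),
    # The shared indices i and k are not adjacent in C.
    "ijk,ik->j": torch.allclose(
        torch.einsum("ijk,ik->j", C, B), (C * B[:, None, :]).sum(dim=(0, 2))),
    "ijk,ik->jk": torch.allclose(
        torch.einsum("ijk,ik->jk", C, B), (C * B[:, None, :]).sum(dim=0)),
}
print(tensor_checks)
```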

## Test diagonals

```
torch.einsum("ii", H)                      # trace
torch.einsum("ii->i", H)                   # diagonal
torch.einsum('iji->j', I)                  # non-contiguous trace
torch.einsum('ngrg...->nrg...', torch.randn((2, 1, 3, 1, 4), device=device, dtype=dtype))  # diagonal over the two g dims, ellipsis kept
```
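A repeated index within one operand selects a diagonal. The sketch below checks the four expressions above against `torch.trace` and `torch.diagonal`. The shapes are assumed, `T` stands in for the inline random tensor, and `diagonal_checks` is an illustrative name.

```python
import torch

dt = torch.float64
H = torch.randn(4, 4, dtype=dt)            # assumed shapes
I = torch.randn(2, 3, 2, dtype=dt)
T = torch.randn(2, 1, 3, 1, 4, dtype=dt)   # stand-in for the inline randn tensor

trace = torch.einsum("ii", H)
diag = torch.einsum("ii->i", H)
strided_trace = torch.einsum("iji->j", I)
batched_diag = torch.einsum("ngrg...->nrg...", T)

diagonal_checks = {
    "trace": torch.allclose(trace, torch.trace(H)),
    "diagonal": torch.allclose(diag, torch.diagonal(H)),
    # Diagonal over dims 0 and 2 (appended as the last dim), then summed.
    "iji->j": torch.allclose(
        strided_trace, torch.diagonal(I, dim1=0, dim2=2).sum(-1)),
    # torch.diagonal puts g last; move it to position 2 to get n, r, g, ...
    "ngrg...->nrg...": torch.allclose(
        batched_diag, torch.diagonal(T, dim1=1, dim2=3).movedim(-1, 2)),
}
print(diagonal_checks, tuple(batched_diag.shape))
```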

## Test ellipsis

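```
torch.einsum("i...->...", H)                # sum over the first dim, keep the rest
torch.einsum("ki,...k->i...", A.t(), B)     # contract k, batch dims moved last
torch.einsum("k...,jk->...", A.t(), B)      # contract k and sum over j
torch.einsum('...ik, ...j -> ...ij', C, x)  # sum over k, broadcast batch dims
torch.einsum('Bik,k...j->i...j', C, torch.randn((5, 3), device=device, dtype=dtype))  # sum over B and k
torch.einsum('i...j, ij... -> ...ij', C, torch.randn((2, 5, 2, 3), device=device, dtype=dtype))  # ellipsis in different positions
```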
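One way to read an ellipsis expression is to replace `...` with explicit letters for the dimensions it covers. The sketch below pairs each expression above with such an expansion and checks that both give the same result. The shapes, the explicit equations, and the names `R`, `S`, and `ellipsis_checks` are assumptions for illustration.

```python
import torch

dt = torch.float64
x = torch.randn(5, dtype=dt)            # assumed shapes
A = torch.randn(3, 5, dtype=dt)
B = torch.randn(2, 5, dtype=dt)
C = torch.randn(2, 3, 5, dtype=dt)
H = torch.randn(4, 4, dtype=dt)
R = torch.randn(5, 3, dtype=dt)         # stand-in for the (5, 3) inline tensor
S = torch.randn(2, 5, 2, 3, dtype=dt)   # stand-in for the (2, 5, 2, 3) inline tensor

# (ellipsis form, explicit form for these particular shapes, operands)
cases = [
    ("i...->...", "ij->j", (H,)),
    ("ki,...k->i...", "ki,bk->ib", (A.t(), B)),
    ("k...,jk->...", "ka,jk->a", (A.t(), B)),
    ("...ik, ...j -> ...ij", "bik,j->bij", (C, x)),    # x has no batch dims
    ("Bik,k...j->i...j", "Bik,kj->ij", (C, R)),         # ellipsis is empty here
    ("i...j, ij... -> ...ij", "ibj,ijab->abij", (C, S)),  # (b) broadcasts with (a, b)
]

ellipsis_checks = {}
for with_dots, explicit, operands in cases:
    got = torch.einsum(with_dots, *operands)
    want = torch.einsum(explicit, *operands)
    ellipsis_checks[with_dots] = (
        got.shape == want.shape and torch.allclose(got, want)
    )
print(ellipsis_checks)
```

The explicit forms only hold for the shapes used here. The point of the ellipsis is that the same equation keeps working when the number of batch dimensions changes.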