

@scaomath
Created Jul 6, 2021
Einstein sum reference

Einstein sum refs

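Copied from PyTorch's linear algebra test file. The operands (x, y, A through I, device, dtype) are defined in that file and are not reproduced here.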
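Because no operand is defined on this page, here is a minimal setup sketch for trying the calls. The shapes are an assumption (not copied from the test file), chosen so that every call on this page is valid. The helper rand is hypothetical.

```python
import torch

# Hypothetical setup: shapes are assumed, picked so every call on this page works
device, dtype = "cpu", torch.float64
torch.manual_seed(0)

def rand(*shape):
    """Hypothetical helper: random tensor with the shared device/dtype."""
    return torch.randn(shape, device=device, dtype=dtype)

x, y = rand(5), rand(7)                # vectors
A, B = rand(3, 5), rand(2, 5)          # matrices sharing the column size 5
C, D = rand(2, 3, 5), rand(2, 5, 7)    # 3-D tensors with a leading batch dim of 2
E = rand(7, 9)
F, G = rand(2, 3, 3, 5), rand(5, 4, 6)
H = rand(4, 4)                         # square, for trace/diagonal
I = rand(2, 3, 2)                      # first and last dims match, for 'iji->j'

# A few calls from the list, with the output shape each one produces
print(torch.einsum("ij,kj->ik", A, B).shape)          # torch.Size([3, 2])
print(torch.einsum("Aij,Ajk->Aik", C, D).shape)       # torch.Size([2, 3, 7])
print(torch.einsum("abCd,dFg->abCFg", F, G).shape)    # torch.Size([2, 3, 3, 4, 6])
```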

Vector operations

torch.einsum('i->', x)                     # sum
torch.einsum('i,i->', x, x)                # dot
torch.einsum('i,i->i', x, x)               # vector element-wise mul
torch.einsum('i,j->ij', x, y)              # outer
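The vector rows all follow one rule: an index that appears in the output is kept, and an index missing from the output is summed over. A small sketch checking each row against the equivalent plain PyTorch call (the sizes 5 and 7 are arbitrary assumptions):

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, dtype=torch.float64)
y = torch.randn(7, dtype=torch.float64)

# 'i->'     : i is absent from the output, so it is summed away
assert torch.allclose(torch.einsum("i->", x), x.sum())
# 'i,i->'   : multiply matching entries, then sum over i
assert torch.allclose(torch.einsum("i,i->", x, x), torch.dot(x, x))
# 'i,i->i'  : i is kept, so nothing is summed (element-wise product)
assert torch.allclose(torch.einsum("i,i->i", x, x), x * x)
# 'i,j->ij' : two distinct indices, both kept, give every pairwise product
outer = torch.einsum("i,j->ij", x, y)
assert outer.shape == (5, 7)
assert torch.allclose(outer, torch.outer(x, y))
print("vector identities hold")
```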

Matrix operations

torch.einsum("ij->ji", A)                  # transpose
torch.einsum("ij->j", A)                   # column sums (sum over rows, i)
torch.einsum("ij->i", A)                   # row sums (sum over columns, j)
torch.einsum("ij,ij->ij", A, A)            # matrix element-wise mul
torch.einsum("ij,j->i", A, x)              # matrix vector multiplication
torch.einsum("ij,kj->ik", A, B)            # matmul with B transposed (A @ B.T)
torch.einsum("ij,ab->ijab", A, E)          # matrix outer product
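A similar sketch for the matrix rows, with assumed shapes A (3, 5), B (2, 5), E (7, 9) and x (5,). It also makes two easily misread rows explicit: 'ij->j' returns one total per column, and 'ij,kj->ik' multiplies by B transposed rather than by B.

```python
import torch

torch.manual_seed(0)
f64 = torch.float64
A = torch.randn(3, 5, dtype=f64)
B = torch.randn(2, 5, dtype=f64)
E = torch.randn(7, 9, dtype=f64)
x = torch.randn(5, dtype=f64)

assert torch.allclose(torch.einsum("ij->ji", A), A.t())
# "ij->j" sums over i: one total per column, shape (5,)
assert torch.allclose(torch.einsum("ij->j", A), A.sum(dim=0))
# "ij->i" sums over j: one total per row, shape (3,)
assert torch.allclose(torch.einsum("ij->i", A), A.sum(dim=1))
assert torch.allclose(torch.einsum("ij,ij->ij", A, A), A * A)
assert torch.allclose(torch.einsum("ij,j->i", A, x), A @ x)
# "ij,kj->ik" contracts the second index of both operands, i.e. A @ B.T
assert torch.allclose(torch.einsum("ij,kj->ik", A, B), A @ B.t())
# "ij,ab->ijab": no shared index, so every A[i, j] multiplies every E[a, b]
outer = torch.einsum("ij,ab->ijab", A, E)
assert outer.shape == (3, 5, 7, 9)
assert torch.allclose(outer, A[:, :, None, None] * E[None, None, :, :])
print("matrix identities hold")
```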

Tensor operations

torch.einsum("Aij,Ajk->Aik", C, D)         # batch matmul
torch.einsum("ijk,jk->i", C, A)            # tensor matrix contraction
torch.einsum("aij,jk->aik", D, E)          # tensor matrix contraction
torch.einsum("abCd,dFg->abCFg", F, G)      # tensor tensor contraction
torch.einsum("ijk,jk->ik", C, A)           # tensor matrix contraction with double indices
torch.einsum("ijk,jk->ij", C, A)           # tensor matrix contraction with double indices
torch.einsum("ijk,ik->j", C, B)            # non-contiguous
torch.einsum("ijk,ik->jk", C, B)           # non-contiguous with double indices
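For the tensor rows, a sketch with assumed shapes that spells one double contraction out as explicit loops (every index not in the output is summed) and checks the others against torch.bmm, broadcasting, and torch.tensordot:

```python
import torch

torch.manual_seed(0)
f64 = torch.float64
A = torch.randn(3, 5, dtype=f64)
B = torch.randn(2, 5, dtype=f64)
C = torch.randn(2, 3, 5, dtype=f64)
D = torch.randn(2, 5, 7, dtype=f64)
F = torch.randn(2, 3, 3, 5, dtype=f64)
G = torch.randn(5, 4, 6, dtype=f64)

# Batch matmul: the batch index A appears in every term, j is summed
assert torch.allclose(torch.einsum("Aij,Ajk->Aik", C, D), torch.bmm(C, D))

# "ijk,jk->i" as explicit loops: j and k are missing from the output, so both are summed
loop = torch.zeros(2, dtype=f64)
for i in range(2):
    for j in range(3):
        for k in range(5):
            loop[i] += C[i, j, k] * A[j, k]
assert torch.allclose(torch.einsum("ijk,jk->i", C, A), loop)

# Keeping one of the two shared indices sums over the other one only
assert torch.allclose(torch.einsum("ijk,jk->ik", C, A), (C * A).sum(dim=1))
assert torch.allclose(torch.einsum("ijk,jk->ij", C, A), (C * A).sum(dim=2))

# Shared indices need not be adjacent: i and k are matched, j sits between them
assert torch.allclose(torch.einsum("ijk,ik->j", C, B), (C * B[:, None, :]).sum(dim=(0, 2)))

# Contract the last axis of F with the first axis of G
assert torch.allclose(torch.einsum("abCd,dFg->abCFg", F, G), torch.tensordot(F, G, dims=1))
print("tensor identities hold")
```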

Test diagonals

torch.einsum("ii", H)                      # trace
torch.einsum("ii->i", H)                   # diagonal
torch.einsum('iji->j', I)                  # non-contiguous trace
torch.einsum('ngrg...->nrg...', torch.randn((2, 1, 3, 1, 4), device=device, dtype=dtype))  # repeated index with ellipsis
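A sketch of the diagonal rows, with assumed shapes. A letter repeated within one operand reads that operand's diagonal along those axes; the letter is then summed if it is absent from the output ("ii" in implicit mode gives the trace) or kept once if it is present.

```python
import torch

torch.manual_seed(0)
f64 = torch.float64
H = torch.randn(4, 4, dtype=f64)
I = torch.randn(2, 3, 2, dtype=f64)

# Implicit output: i appears twice, so it is not an output index and gets summed
assert torch.allclose(torch.einsum("ii", H), torch.trace(H))
assert torch.allclose(torch.einsum("ii->i", H), torch.diagonal(H))

# 'iji->j': the repeated i sits on dims 0 and 2, j is kept
diag = torch.diagonal(I, dim1=0, dim2=2)     # shape (3, 2): diag[j, i] = I[i, j, i]
assert torch.allclose(torch.einsum("iji->j", I), diag.sum(dim=-1))

# 'ngrg...->nrg...': g (size 1 here) is repeated, and its diagonal is kept once
T = torch.randn(2, 1, 3, 1, 4, dtype=f64)
out = torch.einsum("ngrg...->nrg...", T)
assert torch.allclose(out, T[:, 0, :, 0, :].unsqueeze(2))
print(out.shape)  # torch.Size([2, 3, 1, 4])
```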

Test ellipsis

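torch.einsum("i...->...", H)               # sum out the first dim
torch.einsum("ki,...k->i...", A.t(), B)    # same result as A @ B.T
torch.einsum("k...,jk->...", A.t(), B)     # j and k both summed
torch.einsum('...ik, ...j -> ...ij', C, x) # spaces ignored; ellipsis dims broadcast
torch.einsum('Bik,k...j->i...j', C, torch.randn((5, 3), device=device, dtype=dtype))  # empty ellipsis
torch.einsum('i...j, ij... -> ...ij', C, torch.randn((2, 5, 2, 3), device=device, dtype=dtype))  # ellipsis mid-subscript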
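Finally, a sketch for the ellipsis rows, again with assumed shapes. Each comparison shows which dimensions the "..." absorbs and how the ellipsis dimensions of different operands broadcast against each other (aligned from the right, as in ordinary broadcasting).

```python
import torch

torch.manual_seed(0)
f64 = torch.float64
A = torch.randn(3, 5, dtype=f64)
B = torch.randn(2, 5, dtype=f64)
C = torch.randn(2, 3, 5, dtype=f64)
H = torch.randn(4, 4, dtype=f64)
x = torch.randn(5, dtype=f64)

# "..." stands in for every dimension not named by a letter
assert torch.allclose(torch.einsum("i...->...", H), H.sum(dim=0))
assert torch.allclose(torch.einsum("ki,...k->i...", A.t(), B), A @ B.t())
# j and k are both missing from the output, so both are summed
assert torch.allclose(torch.einsum("k...,jk->...", A.t(), B), A @ B.sum(dim=0))

# C contributes ellipsis dims (2,), x contributes none; k is summed
out = torch.einsum("...ik, ...j -> ...ij", C, x)
assert torch.allclose(out, C.sum(dim=-1).unsqueeze(-1) * x)

# Here the ellipsis matches no dims at all; B and k are both summed
M = torch.randn(5, 3, dtype=f64)
assert torch.allclose(torch.einsum("Bik,k...j->i...j", C, M), C.sum(dim=0) @ M)

# Ellipsis in the middle: (3,) from C broadcasts against (2, 3) from R
R = torch.randn(2, 5, 2, 3, dtype=f64)
out2 = torch.einsum("i...j, ij... -> ...ij", C, R)
assert torch.allclose(out2, R.permute(2, 3, 0, 1) * C.permute(1, 0, 2))
print(out.shape, out2.shape)  # torch.Size([2, 3, 5]) torch.Size([2, 3, 2, 5])
```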