# Copied from torch's linear algebra test file.
import torch

# Self-contained setup so the examples below actually run.  The original
# snippet referenced x, y, A, ..., device, dtype without defining them, and
# passed device/dtype positionally to torch.randn (they are keyword-only,
# so that raised TypeError).  Shapes are chosen to satisfy every equation.
device = torch.device("cpu")
dtype = torch.float32

x = torch.randn(5, device=device, dtype=dtype)           # vector
y = torch.randn(7, device=device, dtype=dtype)           # vector
A = torch.randn(3, 5, device=device, dtype=dtype)        # matrix
B = torch.randn(2, 5, device=device, dtype=dtype)        # matrix
C = torch.randn(2, 3, 5, device=device, dtype=dtype)     # 3-D tensor
D = torch.randn(2, 5, 4, device=device, dtype=dtype)     # 3-D tensor (batch dim matches C)
E = torch.randn(4, 6, device=device, dtype=dtype)        # matrix (rows match D's last dim)
F = torch.randn(2, 1, 3, 4, device=device, dtype=dtype)  # 4-D tensor
G = torch.randn(4, 2, 5, device=device, dtype=dtype)     # 3-D tensor (first dim matches F's last)
H = torch.randn(4, 4, device=device, dtype=dtype)        # square matrix (trace/diagonal)
I = torch.randn(2, 3, 2, device=device, dtype=dtype)     # 3-D, dims 0 and 2 equal

torch.einsum('i->', x)                 # sum
torch.einsum('i,i->', x, x)            # dot product
torch.einsum('i,i->i', x, x)           # vector element-wise mul
torch.einsum('i,j->ij', x, y)          # outer product
torch.einsum("ij->ji", A)              # transpose
torch.einsum("ij->j", A)               # sum over rows (per-column totals)
torch.einsum("ij->i", A)               # sum over columns (per-row totals)
torch.einsum("ij,ij->ij", A, A)        # matrix element-wise mul
torch.einsum("ij,j->i", A, x)          # matrix-vector multiplication
torch.einsum("ij,kj->ik", A, B)        # matmul (A @ B.T)
torch.einsum("ij,ab->ijab", A, E)      # matrix outer product
torch.einsum("Aij,Ajk->Aik", C, D)     # batch matmul
torch.einsum("ijk,jk->i", C, A)        # tensor-matrix contraction
torch.einsum("aij,jk->aik", D, E)      # tensor-matrix contraction
torch.einsum("abCd,dFg->abCFg", F, G)  # tensor-tensor contraction
torch.einsum("ijk,jk->ik", C, A)       # tensor-matrix contraction with double indices
torch.einsum("ijk,jk->ij", C, A)       # tensor-matrix contraction with double indices
torch.einsum("ijk,ik->j", C, B)        # non contiguous
torch.einsum("ijk,ik->jk", C, B)       # non contiguous with double indices
torch.einsum("ii", H)                  # trace
torch.einsum("ii->i", H)               # diagonal
torch.einsum('iji->j', I)              # non-contiguous trace (sums the dim-0/dim-2 diagonal)
torch.einsum('ngrg...->nrg...', torch.randn(2, 1, 3, 1, 4, device=device, dtype=dtype))
torch.einsum("i...->...", H)           # ellipsis: sum over leading dim
torch.einsum("ki,...k->i...", A.t(), B)
torch.einsum("k...,jk->...", A.t(), B)
torch.einsum('...ik, ...j -> ...ij', C, x)
torch.einsum('Bik,k...j->i...j', C, torch.randn(5, 3, device=device, dtype=dtype))
torch.einsum('i...j, ij... -> ...ij', C, torch.randn(2, 5, 2, 3, device=device, dtype=dtype))