# Last active Sep 16, 2019
# Multidimensional (batched) SVD for PyTorch
import itertools
import torch
def svd(x):
    """Compute the SVD of the trailing two dimensions of ``x``.

    Every leading dimension of ``x`` is treated as a batch dimension and
    ``torch.svd`` is applied independently to each ``(n, m)`` matrix.

    Args:
        x: Tensor of shape ``(*batches, n, m)``. With no batch dimensions
            the call is forwarded directly to ``torch.svd``.

    Returns:
        A tuple ``(U, d, V)`` with ``k = min(n, m)``:
        ``U`` has shape ``(*batches, n, k)``, ``d`` has shape
        ``(*batches, k)`` and ``V`` has shape ``(*batches, m, k)``, so that
        ``x[idx] == U[idx] @ diag(d[idx]) @ V[idx].T`` up to rounding.
        For a plain 2-D input the result is whatever ``torch.svd`` returns.
    """
    batches = x.shape[:-2]
    if not batches:
        # Plain matrix: nothing to iterate over.
        return torch.svd(x)

    n, m = x.shape[-2:]
    k = min(n, m)

    # Output buffers share dtype and device with the input; every slot is
    # overwritten in the loop below, so uninitialized memory is fine.
    U = x.new_empty(*batches, n, k)
    d = x.new_empty(*batches, k)
    V = x.new_empty(*batches, m, k)

    # Visit every combination of batch indices and decompose that matrix.
    for idx in itertools.product(*map(range, batches)):
        U[idx], d[idx], V[idx] = torch.svd(x[idx])
    return U, d, V
