Last active Sep 16, 2019
multidimensional svd pytorch
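```python
import itertools

import torch


def svd(x):
    """SVD over the last two dims of x, looping over any leading batch dims."""
    batches = x.shape[:-2]
    if batches:
        n, m = x.shape[-2:]
        k = min(n, m)
        # Preallocate outputs with x's dtype and device:
        # U is (..., n, k), d is (..., k), V is (..., m, k)
        U = x.new_empty((*batches, n, k))
        d = x.new_empty((*batches, k))
        V = x.new_empty((*batches, m, k))
        # Visit every index tuple of the batch dims and decompose that matrix
        for idx in itertools.product(*map(range, batches)):
            U[idx], d[idx], V[idx] = torch.svd(x[idx])
        return U, d, V
    # A plain 2-D matrix needs no looping
    return torch.svd(x)
```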
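The loop works because `itertools.product(*map(range, batches))` yields every index tuple over the batch dimensions in row-major order, so each `x[idx]` is exactly one 2-D matrix. A minimal illustration, using a plain tuple `(2, 3)` as a stand-in for `x.shape[:-2]` (a `torch.Size` is a tuple subclass, so it behaves the same way here):

```python
import itertools

# For a tensor of shape (2, 3, n, m) the batch shape is (2, 3)
batches = (2, 3)

# map(range, batches) gives range(2), range(3); product walks every combination
indices = list(itertools.product(*map(range, batches)))
print(indices)
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```

Each tuple indexes one matrix, so a batch shape of `(2, 3)` means six separate `torch.svd` calls.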
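Recent PyTorch releases provide `torch.linalg.svd`, which decomposes over any leading batch dimensions without a Python loop, and `torch.svd` is documented as deprecated in its favor. The sketch below assumes a PyTorch version that ships `torch.linalg`; `batched_svd` is a hypothetical name, not part of the gist. It returns the same `(U, d, V)` layout as the function above. One difference to watch: `torch.linalg.svd` returns V^H rather than V, so the wrapper transposes it back.

```python
import torch


def batched_svd(x: torch.Tensor):
    """Reduced SVD over the last two dims; leading dims are treated as a batch.

    Returns (U, S, V) in torch.svd's layout: U is (..., n, k), S is (..., k),
    and V is (..., m, k), with k = min(n, m).
    """
    U, S, Vh = torch.linalg.svd(x, full_matrices=False)
    # torch.linalg.svd returns V^H; conjugate-transpose it back to V
    return U, S, Vh.transpose(-2, -1).conj()


x = torch.randn(2, 3, 4, 5, dtype=torch.float64)
U, S, V = batched_svd(x)
print(U.shape, S.shape, V.shape)
# torch.Size([2, 3, 4, 4]) torch.Size([2, 3, 4]) torch.Size([2, 3, 5, 4])

# Reconstruct every matrix in the batch: x ≈ U @ diag(S) @ V^T
recon = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
print(torch.allclose(recon, x))  # True
```

Singular vectors are only unique up to sign, so results from this version and the looped version are best compared through the reconstruction rather than element by element.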