Skip to content

Instantly share code, notes, and snippets.

@ferrine
Last active September 16, 2019 12:11
Show Gist options
Save ferrine/0c0e03bd21323a048baab8dadc83cdcc to your computer and use it in GitHub Desktop.
Multidimensional (batched) SVD for PyTorch
import itertools
import torch
def svd(x, some=True):
    """Compute the SVD of every matrix in a batched tensor.

    ``torch.svd`` is applied independently to each matrix formed by the last
    two dimensions of ``x``; any leading dimensions are treated as batch
    dimensions. Based on
    https://discuss.pytorch.org/t/multidimensional-svd/4366/2

    Args:
        x: Tensor of shape ``(*batch, n, m)``. Inputs with fewer than two
            dimensions are passed straight to ``torch.svd``, which rejects
            them.
        some: Forwarded to ``torch.svd``. ``True`` (the default) returns the
            reduced factors; ``False`` returns full square ``U`` and ``V``.

    Returns:
        ``(U, d, V)`` such that each matrix equals ``U @ diag(d) @ V^H``.
        With ``k = min(n, m)``, the shapes are:

        - ``U``: ``(*batch, n, k)``, or ``(*batch, n, n)`` if ``some=False``.
        - ``d``: ``(*batch, k)``, real-valued even for complex ``x``.
        - ``V``: ``(*batch, m, k)``, or ``(*batch, m, m)`` if ``some=False``.
    """
    batch_shape = x.shape[:-2]
    if not batch_shape:
        # Plain matrix (or invalid rank): defer entirely to torch.svd.
        return torch.svd(x, some=some)

    n, m = x.shape[-2:]
    if 0 in batch_shape:
        # No matrices to decompose: return correctly shaped empty tensors.
        k = min(n, m)
        real_dtype = x.abs().dtype  # singular values are real even for complex x
        return (
            x.new_empty(batch_shape + (n, k if some else n)),
            x.new_empty(batch_shape + (k,), dtype=real_dtype),
            x.new_empty(batch_shape + (m, k if some else m)),
        )

    # Collapse all batch dims into one so a single loop visits every matrix in
    # row-major order; stacking (rather than writing into pre-allocated
    # buffers) lets torch.svd decide each factor's shape and dtype.
    flat = x.flatten(0, -3)
    results = [torch.svd(mat, some=some) for mat in flat]
    U, d, V = (
        torch.stack(parts).reshape(batch_shape + parts[0].shape)
        for parts in zip(*results)
    )
    return U, d, V
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment