Skip to content

Instantly share code, notes, and snippets.

@norabelrose
Last active December 27, 2023 21:31
Show Gist options
  • Save norabelrose/3f7a553f4d69de3cf5bda93e2264a9c9 to your computer and use it in GitHub Desktop.
Save norabelrose/3f7a553f4d69de3cf5bda93e2264a9c9 to your computer and use it in GitHub Desktop.
Fast, optimal Kronecker decomposition
from einops import rearrange
from torch import Tensor
import torch
def kronecker_decompose(
    A: Tensor,
    m: int,
    n: int,
    *,
    k: int = 1,
    niter: int = 10,
    oversample: int = 5,
) -> tuple[Tensor, Tensor]:
    """Frobenius-optimal decomposition of `A` into a sum of `k` Kronecker products.

    Algorithm from Van Loan and Pitsianis (1993), "Approximation with Kronecker
    Products" <https://bit.ly/46hT5aY>. The rearrangement R(A) maps every
    Kronecker product ``kron(B, C)`` to the rank-one matrix ``vec(B) vec(C)^T``,
    so the best sum of `k` Kronecker products (in Frobenius norm) is read off the
    top-`k` singular triplets of R(A).

    The SVD is computed with :func:`torch.svd_lowrank`, a randomized algorithm, so
    the result is optimal only up to the accuracy of that approximation and
    depends on the state of torch's random number generator.

    Args:
        A: Matrix or batch of matrices to decompose, of shape (..., m * m2, n * n2)
        m: Desired number of rows in the left Kronecker factor(s)
        n: Desired number of columns in the left Kronecker factor(s)
        k: Number of Kronecker factors
        niter: Number of iterations for the low rank SVD algorithm
        oversample: Number of extra random directions the low rank SVD uses
            beyond `k` (capped at the maximum possible rank). Oversampling makes
            the randomized SVD's top-`k` singular vectors accurate; the extra
            components are discarded before returning.

    Returns:
        Tuple of Kronecker factors (`left`, `right`) of shape `(..., k, m, n)` and
        `(..., k, A.shape[-2] // m, A.shape[-1] // n)` respectively, such that
        ``A ≈ sum_i kron(left[..., i, :, :], right[..., i, :, :])``.

    Raises:
        AssertionError: If the dimensions of `A` are not compatible with the desired
            number of rows and columns in the left Kronecker factor.
        ValueError: If `k` is not between 1 and the maximum Kronecker rank
            ``min(m * n, m2 * n2)``, or if `oversample` is negative.
    """
    rows, cols = A.shape[-2], A.shape[-1]
    # Explicit check rather than `assert` so validation survives `python -O`;
    # AssertionError is kept because callers were documented to expect it.
    if m <= 0 or n <= 0 or rows % m or cols % n:
        raise AssertionError(
            f"Dimensions do not match: cannot split a ({rows}, {cols}) matrix "
            f"into Kronecker factors with a ({m}, {n}) left factor"
        )
    m2, n2 = rows // m, cols // n

    # R(A) below is (m*n) x (m2*n2), so no more than this many Kronecker terms exist.
    max_rank = min(m * n, m2 * n2)
    if not 1 <= k <= max_rank:
        raise ValueError(f"k must be between 1 and {max_rank}, got {k}")
    if oversample < 0:
        raise ValueError(f"oversample must be non-negative, got {oversample}")

    batch = A.shape[:-2]
    # Van Loan-Pitsianis rearrangement (..., m*m2, n*n2) -> (..., m*n, m2*n2):
    # row index i*n + j selects the block, column index p*n2 + q the entry
    # inside that block.
    R = (
        A.reshape(*batch, m, m2, n, n2)
        .transpose(-3, -2)
        .reshape(*batch, m * n, m2 * n2)
    )

    q = min(k + oversample, max_rank)
    u, s, v = torch.svd_lowrank(R, q=q, niter=niter)
    # Singular values are returned in descending order, so the leading k
    # columns are the top-k triplets; drop the oversampled extras.
    u, s, v = u[..., :k], s[..., :k], v[..., :k]

    # Unflatten the singular vectors into (..., k, m, n) and (..., k, m2, n2).
    left = u.transpose(-2, -1).reshape(*batch, k, m, n)
    right = v.transpose(-2, -1).reshape(*batch, k, m2, n2)

    # Split each singular value evenly between the two factors.
    scale = s[..., None, None].sqrt()
    return left * scale, right * scale
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment