""" Computing zeroth matrix powers via Lakic 1998. | |
paper: "On the Computation of the Matrix k-th Root" | |
Suppose we have a matrix G = USV^T and we want to compute | |
G^0 defined via G^0 = UV^T. We might want to do this to run | |
"stochastic spectral descent" of Carlson et al 2015. The | |
naive way to do this is via the SVD. But we can also just do | |
(GG^T)^(-1/2) G or alternatively G (G^TG)^(-1/2) and apply | |
the iterative method from Lakic 1998. | |
In particular, we implement the first special case of Alg 1 | |
in that paper. | |
""" | |
import torch


def zeroth_power_via_newton(G, steps=20):
    device = G.device
    d1, d2 = G.shape
    d = min(d1, d2)
    # store the smaller of the two Gram matrices as S
    S = G @ G.t() if d1 < d2 else G.t() @ G
    S_norm = torch.linalg.matrix_norm(S, ord='fro')  # there is freedom here. See Lakic (1998) Thm 2.3
    S /= S_norm
    # Now let's set up the state for the Lakic (1998) method
    N = S
    X = torch.eye(d, device=device)
    I = torch.eye(d, device=device)
    # Now let's run the iteration
    for _ in range(steps):
        U = (3 * I - N) / 2
        X = X @ U
        N = N @ U @ U
    # X now approximates (S / S_norm)^(-1/2); dividing by sqrt(S_norm)
    # recovers (G G^T)^(-1/2) or (G^T G)^(-1/2)
    return X @ G / S_norm.sqrt() if d1 < d2 else G @ X / S_norm.sqrt()

# Reference implementation: compute UV^T directly from the SVD
def zeroth_power_via_svd(G):
    U, S, V = G.svd()
    return U @ V.t()

# Let's test it on a random Gaussian matrix
G = torch.randn(100, 100)
G_zero_newton = zeroth_power_via_newton(G)
G_zero_svd = zeroth_power_via_svd(G)

# Check that the singular values are all close to one
print(G_zero_newton.svd()[1])
print(G_zero_svd.svd()[1])

# Print the error
# Seems like relative Frobenius error is sensible here
error = torch.linalg.matrix_norm(G_zero_newton - G_zero_svd, ord='fro')
error /= torch.linalg.matrix_norm(G_zero_svd, ord='fro')
print(error.item())
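
# We can also exercise the rectangular branch (d1 < d2), where the smaller
# Gram matrix G G^T is formed; the 50 x 200 shape here is an arbitrary choice.
G_rect = torch.randn(50, 200)
G_rect_zero = zeroth_power_via_newton(G_rect)
# All 50 singular values should again be close to one
print(G_rect_zero.svd()[1])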