Created March 3, 2022 23:04
Save Micky774/5a2c3ecb16f749ad089fcd9d643bac91 to your computer and use it in GitHub Desktop.
Test the correctness of the `eigh` solver in `FastICA`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from scipy import linalg | |
import numpy as np | |
from sklearn.utils._testing import assert_array_almost_equal | |
import warnings | |
def assert_sign_redundant(x, y):
    """Assert that two arrays are equal up to a sign flip of each column.

    Eigenvector / singular-vector solvers determine each column only up to
    its sign. So each column of ``y`` is first aligned with the matching
    column of ``x``, and then the two are compared with
    ``assert_array_almost_equal``. Neither input is modified.

    Parameters
    ----------
    x, y : array-like of shape (n_rows, n_cols)
        Matrices whose columns are compared.

    Raises
    ------
    AssertionError
        If the shapes differ, or if the arrays are not almost equal after
        sign alignment.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape != y.shape:
        raise AssertionError(f"Shape mismatch: {x.shape} != {y.shape}")
    # Pick each column's sign from the inner product of the two columns
    # rather than from the first entry. When x[0, c] is ~0, rounding noise
    # can give it either sign, and a first-entry rule would then flip the
    # wrong column and report a spurious mismatch.
    dots = np.sum(x * y, axis=0)
    # Columns with a non-negative inner product (including orthogonal or
    # all-zero ones) are left unflipped.
    signs = np.where(dots < 0, -1.0, 1.0)
    assert_array_almost_equal(x, y * signs)
def svd_via_eigh(X):
    """Compute the thin SVD factors of ``X.T`` from ``linalg.eigh(X.T @ X)``.

    This is the ``eigh`` solver path from ``FastICA`` that this script
    checks. The eigenvalues of ``X.T @ X`` are the squared singular values
    of ``X.T``, and its eigenvectors are the left singular vectors of ``X.T``.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Input data.

    Returns
    -------
    u : ndarray of shape (n_features, K)
        Left singular vectors of ``X.T``, one per column, with
        ``K = min(X.shape)``.
    d : ndarray of shape (K,)
        Singular values in decreasing order.

    Together these match ``linalg.svd(X.T, full_matrices=False)[:2]`` up to
    the sign of each column of ``u``.
    """
    K = min(X.shape)
    d, u = linalg.eigh(X.T.dot(X))
    eps = np.finfo(np.double).eps
    degenerate_idx = d < eps
    if np.any(degenerate_idx):
        warnings.warn(
            "There are some small singular values, using "
            "svd_solver = 'svd' might lead to more "
            "accurate results."
        )
    # Clamp tiny (possibly slightly negative) eigenvalues so the square root
    # below stays finite.
    d[degenerate_idx] = eps
    d = np.sqrt(d)
    # eigh returns eigenvalues in ascending order with the eigenvectors in
    # the *columns* of u. Reorder both to descending order, as SVD does.
    sort_indices = np.argsort(d)[::-1]
    d, u = d[sort_indices], u[:, sort_indices]
    # Keep only the K leading components, matching svd(full_matrices=False).
    # Without this, d has the wrong length when n_samples < n_features.
    return u[:, :K], d[:K]


def main(n_iter=100, n_samples=2000, n_features=300, seed=0):
    """Compare ``svd_via_eigh`` against ``linalg.svd`` on random matrices.

    Parameters
    ----------
    n_iter : int, default=100
        Number of random matrices to check.
    n_samples : int, default=2000
        Number of rows of each random matrix ``X``.
    n_features : int, default=300
        Number of columns of each random matrix ``X``.
    seed : int or None, default=0
        Seed for the ``RandomState`` that draws ``X``. It is fixed by default
        so failures are reproducible; pass ``None`` for fresh data each run.

    Raises
    ------
    AssertionError
        If the singular vectors (up to sign) or singular values disagree.
    """
    rng = np.random.RandomState(seed)
    for _ in range(n_iter):
        X = rng.rand(n_samples, n_features)
        u, d = svd_via_eigh(X)
        u2, d2 = linalg.svd(X.T, full_matrices=False, check_finite=False)[:2]
        assert_sign_redundant(u, u2)
        assert_array_almost_equal(d, d2)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.