Skip to content

Instantly share code, notes, and snippets.

@Micky774
Created March 3, 2022 23:04
Show Gist options
  • Save Micky774/5a2c3ecb16f749ad089fcd9d643bac91 to your computer and use it in GitHub Desktop.
Save Micky774/5a2c3ecb16f749ad089fcd9d643bac91 to your computer and use it in GitHub Desktop.
Test correctness of `eigh` solver in `FastICA`
from scipy import linalg
import numpy as np
from sklearn.utils._testing import assert_array_almost_equal
import warnings
def assert_sign_redundant(x, y):
    """Assert that ``x`` and ``y`` are almost equal up to a per-column sign.

    Singular/eigen vectors are only defined up to sign, so each column of
    ``y`` is flipped to point the same way as the matching column of ``x``
    before the element-wise comparison.

    Parameters
    ----------
    x, y : array-like of shape (n_rows, n_columns)
        Arrays whose columns are compared. A 1-D input is treated as a
        single column.

    Raises
    ------
    AssertionError
        If the shapes differ, or if the sign-aligned arrays are not almost
        equal according to ``assert_array_almost_equal`` (default decimals).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape != y.shape:
        raise AssertionError(f"shape mismatch: {x.shape} != {y.shape}")
    # Align on the sign of the column-wise inner product rather than on the
    # first entry: a first entry that is ~0 has a round-off-driven sign and
    # would make per-array normalization flip columns inconsistently.
    signs = np.where(np.sum(x * y, axis=0) < 0, -1, 1)
    assert_array_almost_equal(x, y * signs)
def svd_via_eigh(X):
    """Compute the thin SVD factors of ``X.T`` from an eigendecomposition.

    This mirrors the ``eigh`` whitening path being tested for ``FastICA``:
    the eigenvalues of the Gram matrix ``X.T @ X`` are the squared singular
    values of ``X.T``, and its eigenvectors are the left singular vectors.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Input data.

    Returns
    -------
    d : ndarray of shape (K,)
        Singular values in descending order, where ``K = min(X.shape)``.
        Eigenvalues below machine epsilon are clamped to epsilon before the
        square root is taken.
    u : ndarray of shape (n_features, K)
        Matching singular vectors stored as columns (defined up to sign).

    Warns
    -----
    UserWarning
        If any eigenvalue of ``X.T @ X`` is below machine epsilon.
    """
    K = min(X.shape)
    d, u = linalg.eigh(X.T.dot(X))
    # Fix the ordering on the raw eigenvalues: after clamping, all degenerate
    # values tie at eps and argsort would order them arbitrarily.
    sort_indices = np.argsort(d)[::-1]
    eps = np.finfo(d.dtype).eps
    degenerate_idx = d < eps
    if np.any(degenerate_idx):
        warnings.warn(
            "There are some small singular values, using "
            "svd_solver = 'svd' might lead to more "
            "accurate results."
        )
    d[degenerate_idx] = eps  # For numerical issues
    d = np.sqrt(d)
    # eigh stores eigenvectors as *columns*, so permute columns, not rows.
    # Permuting rows and undoing it with flip-matrix products only works
    # when sort_indices happens to be an exact reversal.
    d, u = d[sort_indices], u[:, sort_indices]
    # Keep the K leading components, matching svd(full_matrices=False).
    return d[:K], u[:, :K]


def main(n_iter=100, shape=(2000, 300), seed=None):
    """Check ``svd_via_eigh`` against ``scipy.linalg.svd`` on random data.

    Parameters
    ----------
    n_iter : int, default=100
        Number of random matrices to test.
    shape : tuple of int, default=(2000, 300)
        Shape ``(n_samples, n_features)`` of each random matrix.
    seed : int or None, default=None
        Seed for ``np.random.RandomState``; pass an int to make a failing
        run reproducible.

    Raises
    ------
    AssertionError
        If the singular values or the sign-aligned vectors disagree.
    """
    rng = np.random.RandomState(seed)
    for _ in range(n_iter):
        X = rng.rand(*shape)
        d, u = svd_via_eigh(X)
        u2, d2 = linalg.svd(X.T, full_matrices=False, check_finite=False)[:2]
        assert_sign_redundant(u, u2)
        assert_array_almost_equal(d, d2)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment