Skip to content

Instantly share code, notes, and snippets.

@Palpatineli
Created April 22, 2020 21:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Palpatineli/b5b7733634c96224b766ae7116443a14 to your computer and use it in GitHub Desktop.
Save Palpatineli/b5b7733634c96224b766ae7116443a14 to your computer and use it in GitHub Desktop.
Common PCA: a stepwise algorithm to find the first n common principal components.
from typing import Tuple

import numpy as np
from scipy.linalg import eigh
def cpca(cov: np.ndarray, sample_n: np.ndarray, comp_n: int = 0, tol: float = 1E-6,
         max_iter: int = 1000) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Estimate common principal components with a stepwise algorithm.

    Components are found one at a time. Each starts from a leading eigenvector of
    the sample-size-weighted pooled covariance (largest first). It is refined with
    power-iteration-like updates of ``q <- P (sum_i n_i S_i / d_i) q``, where
    ``P`` projects out the components already found. Its outer product is then
    removed from ``P`` before the next component is computed.

    Args:
        cov: 3D array where the last 2 axes are covariance matrices, i.e. shape
            (k, p, p) for k groups of p variables.
        sample_n: for each covariance, how many samples were in there (length k).
        comp_n: number of components to extract; values <= 0 mean all p.
        tol: relative tolerance on the change of the cost
            ``sum_i sample_n[i] * log(d_i)`` between iterations.
        max_iter: maximum number of refinement iterations per component.

    Returns:
        A tuple ``(D, components, convergence)``:
        D: (comp_n, k) array, ``D[j, i] = q_j^T cov[i] q_j``.
        components: (p, comp_n) array whose columns are the unit-norm components.
        convergence: (comp_n,) bool array, True where ``tol`` was met before
            ``max_iter`` iterations.

    Raises:
        ValueError: if ``comp_n`` exceeds the dimension p.
    """
    cov = np.asarray(cov)
    sample_n = np.asarray(sample_n)
    p = cov.shape[-1]
    comp_n = comp_n if comp_n > 0 else p
    if comp_n > p:
        raise ValueError(f"comp_n ({comp_n}) cannot exceed the dimension p ({p})")

    # Pooled covariance, each group weighted by its share of the total samples.
    pooled = np.einsum('k,kij->ij', sample_n / sample_n.sum(), cov)
    # eigh sorts eigenvalues ascending; reverse so the search starts from the
    # largest pooled eigenvector, as the stepwise algorithm requires.
    _, vecs = eigh(pooled, subset_by_index=[p - comp_n, p - 1])
    starts = vecs[:, ::-1].T

    qw = np.eye(p)  # projector onto the complement of the components found so far
    D, components, convergence = [], [], []
    for q in starts:
        d = np.einsum('i,kij,j->k', q, cov, q)  # q^T cov[i] q for every group i
        prev_cost = np.inf
        converged = False
        for _ in range(max_iter):
            # Rebuilt from scratch every iteration: sum_i n_i * S_i / d_i.
            s = np.einsum('k,kij->ij', sample_n / d, cov)
            w = qw @ (s @ q)
            q = w / np.linalg.norm(w)
            d = np.einsum('i,kij,j->k', q, cov, q)
            cost = float(np.sum(sample_n * np.log(d)))
            # abs(cost): the cost is negative whenever the d_i are below 1.
            if abs(cost - prev_cost) <= tol * abs(cost):
                converged = True
                break
            prev_cost = cost
        D.append(d)
        components.append(q)
        convergence.append(converged)
        qw = qw - np.outer(q, q)  # deflate: remove the new direction
    return np.asarray(D), np.asarray(components).T, np.asarray(convergence)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment