@joelouismarino
Last active October 18, 2023 18:08
import numpy as np

def whiten(X, method='zca'):
    """
    Whitens the input matrix X using specified whitening method.
    Inputs:
        X:      Input data matrix with data examples along the first dimension
        method: Whitening method. Must be one of 'zca', 'zca_cor', 'pca',
                'pca_cor', or 'cholesky'.
    """
    X = X.reshape((-1, np.prod(X.shape[1:])))
    X_centered = X - np.mean(X, axis=0)
    Sigma = np.dot(X_centered.T, X_centered) / X_centered.shape[0]
    W = None

    if method in ['zca', 'pca', 'cholesky']:
        U, Lambda, _ = np.linalg.svd(Sigma)
        if method == 'zca':
            W = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(Lambda + 1e-5)), U.T))
        elif method == 'pca':
            W = np.dot(np.diag(1.0 / np.sqrt(Lambda + 1e-5)), U.T)
        elif method == 'cholesky':
            W = np.linalg.cholesky(np.dot(U, np.dot(np.diag(1.0 / (Lambda + 1e-5)), U.T))).T
    elif method in ['zca_cor', 'pca_cor']:
        V_sqrt = np.diag(np.std(X, axis=0))
        P = np.dot(np.dot(np.linalg.inv(V_sqrt), Sigma), np.linalg.inv(V_sqrt))
        G, Theta, _ = np.linalg.svd(P)
        if method == 'zca_cor':
            W = np.dot(np.dot(G, np.dot(np.diag(1.0 / np.sqrt(Theta + 1e-5)), G.T)), np.linalg.inv(V_sqrt))
        elif method == 'pca_cor':
            W = np.dot(np.dot(np.diag(1.0 / np.sqrt(Theta + 1e-5)), G.T), np.linalg.inv(V_sqrt))
    else:
        raise Exception('Whitening method not found.')

    return np.dot(X_centered, W.T)
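
As a quick sanity check, the sketch below applies the same ZCA formula to synthetic correlated data and confirms that the whitened covariance is close to the identity. The helper name `zca_whiten`, the mixing matrix, and the seed are assumptions made for illustration; they are not part of the gist.

```python
import numpy as np

def zca_whiten(X, eps=1e-5):
    """Hypothetical helper mirroring the 'zca' formula for 2-D input X."""
    X_centered = X - X.mean(axis=0)
    Sigma = X_centered.T @ X_centered / X_centered.shape[0]
    U, Lambda, _ = np.linalg.svd(Sigma)
    W = U @ np.diag(1.0 / np.sqrt(Lambda + eps)) @ U.T
    return X_centered @ W.T

rng = np.random.default_rng(0)
# Synthetic data: independent Gaussians mixed into correlated features.
mix = np.array([[2.0, 0.5, 0.0],
                [0.5, 1.0, 0.3],
                [0.0, 0.3, 1.5]])
X = rng.normal(size=(5000, 3)) @ mix.T

X_white = zca_whiten(X)
cov_white = X_white.T @ X_white / X_white.shape[0]
print(np.round(cov_white, 3))  # expected to be close to the 3x3 identity
```

The small `eps` term keeps the division stable when an eigenvalue is near zero, at the cost of a covariance that is only approximately the identity.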
@JaeDukSeo

thanks for this!

@cwindolf

cwindolf commented Oct 23, 2019

Hi, based on https://rdrr.io/cran/whitening/src/R/whiteningMatrix.R and my tests, I think that in zca_cor/pca_cor you should be using 1/V_sqrt where you are using V_sqrt to finish the expressions for W. See my fork if I'm being unclear. Thanks a bunch for this nice gist!

@joelouismarino
Author

@cwindolf Thanks for the catch! I've updated the gist to use the inverse matrices.
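
To illustrate the point of this exchange, here is a minimal sketch of the `zca_cor` computation on synthetic data. It checks that finishing the expression for `W` with the inverse of `V_sqrt` gives an identity covariance, while finishing it with `V_sqrt` itself does not. The data and the names `W_good`, `W_bad`, and `cov_after` are assumptions for illustration and do not appear in the gist.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic features with some correlation and very different scales.
mix = np.array([[1.0, 0.0, 0.0],
                [0.6, 1.0, 0.0],
                [0.2, 0.4, 1.0]])
X = (rng.normal(size=(4000, 3)) @ mix.T) * np.array([1.0, 10.0, 0.1])

X_centered = X - X.mean(axis=0)
Sigma = X_centered.T @ X_centered / X_centered.shape[0]

V_sqrt = np.diag(np.std(X, axis=0))      # per-feature standard deviations
V_sqrt_inv = np.linalg.inv(V_sqrt)
P = V_sqrt_inv @ Sigma @ V_sqrt_inv      # correlation matrix
G, Theta, _ = np.linalg.svd(P)
P_inv_sqrt = G @ np.diag(1.0 / np.sqrt(Theta + 1e-5)) @ G.T

W_good = P_inv_sqrt @ V_sqrt_inv         # finishes with the inverse of V_sqrt
W_bad = P_inv_sqrt @ V_sqrt              # finishes with V_sqrt itself

def cov_after(W):
    """Covariance of the centered data after applying W."""
    Z = X_centered @ W.T
    return Z.T @ Z / Z.shape[0]

print(np.round(cov_after(W_good), 3))    # close to the identity
print(np.round(cov_after(W_bad), 3))     # far from the identity
```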

@RonenMelamed

Thanks, very useful!

@thistleknot

Thank you! Now, how do you invert the transform?
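
One possible answer, as a sketch: `whiten` returns only the whitened data, so undoing the transform requires keeping the mean and the whitening matrix `W` as well. The helpers `fit_zca` and `unwhiten` below are hypothetical and not part of the gist. Since `X_white = (X - mean) @ W.T`, solving against `W` and adding the mean back should recover `X` up to floating-point error, provided `W` is reasonably well conditioned.

```python
import numpy as np

def fit_zca(X, eps=1e-5):
    """Hypothetical helper: return the mean and ZCA matrix W for 2-D data X."""
    mean = X.mean(axis=0)
    X_centered = X - mean
    Sigma = X_centered.T @ X_centered / X_centered.shape[0]
    U, Lambda, _ = np.linalg.svd(Sigma)
    W = U @ np.diag(1.0 / np.sqrt(Lambda + eps)) @ U.T
    return mean, W

def unwhiten(X_white, mean, W):
    """Hypothetical helper: undo X_white = (X - mean) @ W.T."""
    # Solve W @ Y = X_white.T instead of forming inv(W) explicitly.
    return np.linalg.solve(W, X_white.T).T + mean

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 4)) * np.array([1.0, 2.0, 0.5, 3.0])
X[:, 1] += 0.5 * X[:, 0]                  # introduce some correlation
X += np.array([5.0, -1.0, 0.0, 2.0])      # non-zero mean

mean, W = fit_zca(X)
X_white = (X - mean) @ W.T
X_rec = unwhiten(X_white, mean, W)
print(np.allclose(X_rec, X))
```

The same idea applies to the other methods, as long as the corresponding `W` and the mean are stored at fit time.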

@foiv0s

foiv0s commented May 10, 2022

Thanks!!

@rraadd88

Alternative method: using whiten parameter of sklearn's PCA.

from sklearn.decomposition import PCA
X_white = PCA(n_components=X.shape[1], whiten=True, svd_solver='full').fit_transform(X)

Note that this is PCA whitening, so the transformed data may be rotated relative to the output of the other methods.
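
To make the remark about rotation concrete, here is a NumPy-only sketch on synthetic data (the data and variable names are assumptions for illustration). It shows that PCA-whitened and ZCA-whitened data differ only by the orthogonal eigenvector matrix `U`, that is, a rotation and possibly a reflection, so both end up with the same identity covariance.

```python
import numpy as np

rng = np.random.default_rng(0)
mix = np.array([[2.0, 0.5, 0.0],
                [0.5, 1.0, 0.3],
                [0.0, 0.3, 1.5]])
X = rng.normal(size=(2000, 3)) @ mix.T    # synthetic correlated data

X_centered = X - X.mean(axis=0)
Sigma = X_centered.T @ X_centered / X_centered.shape[0]
U, Lambda, _ = np.linalg.svd(Sigma)
scale = np.diag(1.0 / np.sqrt(Lambda + 1e-5))

X_pca = X_centered @ (scale @ U.T).T      # PCA whitening
X_zca = X_centered @ (U @ scale @ U.T).T  # ZCA whitening

# ZCA equals PCA whitening followed by the orthogonal map U.T.
print(np.allclose(X_zca, X_pca @ U.T))
# U is orthogonal, so applying it leaves the identity covariance unchanged.
print(np.allclose(U @ U.T, np.eye(3)))
```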
