# ZCA whitening transformer for scikit-learn.
# Source: GitHub gist dmaniry/5170087 by @dmaniry (last active August 17, 2021).
import numpy as np
from scipy import linalg
from sklearn.utils import array2d, as_float_array
from sklearn.base import TransformerMixin, BaseEstimator
class ZCA(BaseEstimator, TransformerMixin):
    """ZCA whitening transformer.

    ``fit`` learns the per-feature mean of the training data and a symmetric
    whitening matrix ``W = U diag(1 / sqrt(S + regularization)) U^T``, where
    ``U diag(S) U^T`` is the SVD of the sample covariance of the centred data.
    ``transform`` centres new data with the learned mean and multiplies it by
    ``W`` (``W`` is symmetric, so multiplying by ``W.T`` is the same thing).

    Parameters
    ----------
    regularization : float
        Added to every singular value of the covariance before the inverse
        square root, so near-zero-variance directions do not blow up.
    copy : bool
        Forwarded to ``as_float_array`` in ``fit``. With the default
        ``False``, a float input array may be centred in place, i.e. the
        caller's array is modified.

    Attributes
    ----------
    mean_ : ndarray of shape (n_features,)
        Per-feature mean of the training data.
    components_ : ndarray of shape (n_features, n_features)
        The whitening matrix ``W``.
    """

    def __init__(self, regularization=10**-5, copy=False):
        self.regularization = regularization
        self.copy = copy

    def fit(self, X, y=None):
        """Learn the mean and the whitening matrix from ``X``.

        Parameters
        ----------
        X : dense array-like of shape (n_samples, n_features)
            Training data. A 1-D input is treated as a single row.
        y : ignored
            Present only for scikit-learn API compatibility.

        Returns
        -------
        self
        """
        # np.atleast_2d(np.asarray(...)) stands in for sklearn.utils.array2d,
        # which has been removed from scikit-learn.
        X = np.atleast_2d(np.asarray(X))
        X = as_float_array(X, copy=self.copy)
        self.mean_ = np.mean(X, axis=0)
        X -= self.mean_
        # Sample covariance: normalise by the number of samples (rows).
        # Dividing by X.shape[1] (features) mis-scales the whitened output.
        sigma = np.dot(X.T, X) / X.shape[0]
        U, S, _ = linalg.svd(sigma)
        # Broadcasting scales column j of U by scale[j], which equals
        # U @ diag(scale) without building the n x n diagonal matrix.
        scale = 1.0 / np.sqrt(S + self.regularization)
        self.components_ = np.dot(U * scale, U.T)
        return self

    def transform(self, X):
        """Whiten ``X`` using the mean and matrix learned by ``fit``.

        Parameters
        ----------
        X : dense array-like of shape (n_samples, n_features)
            Data to transform. A 1-D input is treated as a single row.

        Returns
        -------
        ndarray of shape (n_samples, n_features)
            ``(X - mean_) @ components_.T``. A new array; ``X`` is not modified.
        """
        X = np.atleast_2d(np.asarray(X))
        X_centered = X - self.mean_
        return np.dot(X_centered, self.components_.T)
# Gist comment by @JaeDukSeo (comment text was not captured in this copy).