Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active September 19, 2019 00:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/07b7627c92e05ca1a08210d27e3e861b to your computer and use it in GitHub Desktop.
Save xmodar/07b7627c92e05ca1a08210d27e3e861b to your computer and use it in GitHub Desktop.
Principal Component Analysis (PCA) with PyTorch and NumPy
from collections import namedtuple
import numpy as np
class PCA:
    """Principal Component Analysis built on a singular value decomposition.

    Accepts NumPy arrays as well as objects exposing a PyTorch-style
    ``svd()`` method.

    NOTE: the features are decomposed exactly as given; no mean-centering
    happens in this class, so center the data beforehand if needed.

    Attributes:
        projection_matrix: Matrix whose columns are the principal directions
            (the right singular vectors of the fitted features).
        variance: Explained-variance ratio of each principal component
            (squared singular values normalized to sum to one).
    """

    # Result type of `fit`, created once instead of on every call.
    _FitResult = namedtuple('PCA', ['pca', 'projection'])

    def __init__(self, features, variance=None):
        """Build a PCA from raw features or from precomputed parts.

        Args:
            features: If ``variance`` is None, the data to fit. Otherwise the
                already-computed projection matrix.
            variance: Explained-variance ratios matching the columns of the
                projection matrix, or None to fit ``features``.
        """
        if variance is None:
            fitted = self.fit(features).pca
            features = fitted.projection_matrix
            variance = fitted.variance
        self.projection_matrix, self.variance = features, variance

    @classmethod
    def fit(cls, features, components=None, explained_variance=None):
        """Fit a PCA and optionally return the reduced features.

        Args:
            features: Data to decompose; expected to be a 2-D
                (samples, features) NumPy array or PyTorch tensor.
            components: Upper bound on the number of components kept in the
                returned projection.
            explained_variance: Variance ratio the kept components should
                reach; combined with ``components`` by taking the smaller
                count.

        Returns:
            A namedtuple ``(pca, projection)``. ``projection`` is None unless
            ``components`` or ``explained_variance`` is given.
        """
        # compute SVD
        inplace = True
        if isinstance(features, np.ndarray):
            u, s, v = np.linalg.svd(features, full_matrices=False)
            v = v.T  # numpy returns V^H; store directions as columns
        else:
            u, s, v = features.svd()  # pytorch
            # keep the normalization out-of-place while autograd tracks `s`
            inplace = not s.requires_grad

        # compute explained variance per principal component
        variance = s * s
        if inplace:
            variance /= variance.sum()
        else:
            variance = variance / variance.sum()

        # compute the projection (reduced features), if needed
        pca = cls(v, variance)
        if components is None and explained_variance is None:
            projection = None
        else:
            keep = features.shape[-1]
            if components is not None:
                keep = min(keep, components)
            if explained_variance is not None:
                keep = min(keep, int(pca.num_components(explained_variance)))
            projection = u[:, :keep] * s[:keep]
        return cls._FitResult(pca, projection)

    def explained_variance(self, num_components=None):
        """Return the variance ratio covered by the leading components.

        Args:
            num_components: How many leading components to sum; None sums
                all of them.
        """
        return self.variance[:num_components].sum()

    def num_components(self, explained_variance=1):
        """Return how many leading components reach ``explained_variance``.

        The result is one more than the number of leading cumulative ratios
        that are strictly below the threshold. The final cumulative entry is
        left out of the comparison: keeping every component is the most that
        can be done, so rounding error in the total (e.g. 0.9999999999999999
        instead of 1) must not push the count past the number of components.
        """
        cumulative = self.variance.cumsum(0)
        return (cumulative[:-1] < explained_variance).sum() + 1

    def project(self, features, keep=None):
        """Project ``features`` onto the first ``keep`` principal directions.

        Args:
            features: Data whose last dimension matches the rows of the
                projection matrix.
            keep: Number of leading components to use; None uses all.
        """
        return features @ self.projection_matrix[:, :keep]

    def restore(self, features):
        """Map reduced features back to the original feature space.

        The number of components is taken from the last dimension of
        ``features``.
        """
        keep = features.shape[-1]
        return features @ self.projection_matrix[:, :keep].T
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment