Skip to content

Instantly share code, notes, and snippets.

@mw3i
Created June 17, 2019 16:04
Show Gist options
  • Save mw3i/daec09321caa620fa38d0d51d7e92dcd to your computer and use it in GitHub Desktop.
Save mw3i/daec09321caa620fa38d0d51d7e92dcd to your computer and use it in GitHub Desktop.
LDA in pure Numpy
'''
Implementation of Linear Discriminant Analysis (LDA) in NumPy (using the within-class covariance matrix and the between-class scatter matrix), based on this tutorial by Sebastian Raschka: https://sebastianraschka.com/Articles/2014_python_lda.html
'''
import numpy as np
def get_components(data: np.ndarray, labels: np.ndarray) -> np.ndarray: # <-- using covariance method
    '''
    Compute the LDA projection axes of `data` for the classes given in `labels`.

    Parameters
    ----------
    data : 2-D array with one row per sample and one column per feature.
    labels : 1-D array-like with one class label per row of `data`. Labels may be
        any values that `np.unique` can handle (they do not need to be 0..k-1).

    Returns
    -------
    Square array (features x features) whose columns are the eigenvectors of
    inv(within-class covariance) @ (between-class scatter), ordered from the
    largest eigenvalue to the smallest. Project with `data @ components[:, :k]`.

    NOTE(review): `np.linalg.eig` operates on a non-symmetric matrix here, so it may
    return a complex-valued result in degenerate cases -- confirm whether callers
    need to take the real part.
    '''
    data = np.asarray(data)
    labels = np.asarray(labels) # <-- plain lists would break the elementwise `labels == label` mask
    label_set = np.unique(labels)
    overall_mean = data.mean(axis = 0, keepdims = True)

    num_features = data.shape[1]
    within_class_cov = np.zeros((num_features, num_features))
    between_class_scatter = np.zeros((num_features, num_features))

    # Accumulate both matrices per class. Each class's mean is computed right where it
    # is used, so the label value itself is never used as an array index.
    for label in label_set:
        class_data = data[labels == label, :]

        ## Within Class Covariance Matrix (sum of the per-class covariance matrices)
        within_class_cov += np.cov(class_data.T)

        ## Between Class Scatter Matrix (class size * outer product of the mean offset)
        mean_offset = class_data.mean(axis = 0, keepdims = True) - overall_mean
        between_class_scatter += class_data.shape[0] * (mean_offset.T @ mean_offset)

    ## eigen-decomposition of inv(within-class covariance) @ between-class scatter
    eig_vals, eig_vecs = np.linalg.eig(np.linalg.inv(within_class_cov) @ between_class_scatter)

    # sort components, largest to smallest
    idx_sort = np.flip(eig_vals.argsort()) # <-- get ordering of eigenvectors: largest to smallest
    components = eig_vecs[:, idx_sort]
    return components
## run example: project three synthetic Gaussian clusters onto their top LDA axes
if __name__ == '__main__':

    ##__Generate Data
    # three classes of 50 samples each, drawn from unit-variance normals
    # centred at -2, 0 and 2 (in that order) in every feature
    num_features = 3
    samples_per_class = 50
    class_centers = (-2, 0, 2)

    data = np.concatenate(
        [
            np.random.normal(center, 1, [samples_per_class, num_features])
            for center in class_centers
        ],
        axis = 0,
    )
    # integer class ids 0, 1, 2, each repeated once per sample of that class
    labels = np.repeat(np.arange(len(class_centers)), samples_per_class)

    ##__Get Components
    components = get_components(data, labels)

    ##__Transform data using top 2 components (ie, matmul)
    num_components = 2
    top_components = components[:, :num_components]
    transformed_data = data @ top_components
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment