Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Created August 14, 2014 18:47
Show Gist options
  • Save mbillingr/f50a5f2b041d760dc293 to your computer and use it in GitHub Desktop.
Save mbillingr/f50a5f2b041d760dc293 to your computer and use it in GitHub Desktop.
Comparison of LDA transforms
import numpy as np
from sklearn.lda import LDA
from sklearn.datasets import make_blobs
from sklearn.utils import check_array
from sklearn.svm import LinearSVC
import matplotlib.pyplot as plt
def new_transform(lda, X):
    """Project ``X`` onto the leading discriminant directions of ``lda``.

    This is the "new" variant in this comparison. It centres ``X`` on
    ``lda.xbar_``, multiplies by ``lda.scalings_``, and keeps the first
    ``lda.n_components`` columns of the result. When ``lda.n_components``
    is None, it keeps every projected column.

    Parameters
    ----------
    lda : fitted LDA estimator
        Must expose ``xbar_``, ``scalings_`` and ``n_components``.
    X : array-like
        Samples to project; validated with ``check_array``.

    Returns
    -------
    numpy.ndarray
        The projected samples, with at most ``lda.n_components`` columns.
    """
    X = check_array(X)
    # center and scale data
    X_new = np.dot(X - lda.xbar_, lda.scalings_)
    # Fall back to the width of the projection, i.e. the number of
    # discriminant directions actually available. The input feature count
    # (X.shape[1]) is the wrong quantity here; it previously only worked
    # because NumPy silently clips out-of-range slice bounds.
    if lda.n_components is None:
        n_components = X_new.shape[1]
    else:
        n_components = lda.n_components
    return X_new[:, :n_components]
def old_transform(lda, X):
    """Reproduce the "old" LDA transform used as the baseline in this comparison.

    ``X`` is centred on ``lda.xbar_`` and multiplied by ``lda.scalings_``.
    The result is then multiplied by the transpose of the first ``n_comp``
    rows of ``lda.coef_``. Here ``n_comp`` is ``lda.n_components``, or the
    width of the scaled projection when that is None.

    NOTE(review): the rows of ``coef_`` are presumably per-class
    coefficients. This baseline is kept deliberately as-is so its output can
    be contrasted with ``new_transform``; do not "fix" it.

    Parameters
    ----------
    lda : fitted LDA estimator
        Must expose ``xbar_``, ``scalings_``, ``coef_`` and ``n_components``.
    X : array-like
        Samples to transform; validated with ``check_array``.

    Returns
    -------
    numpy.ndarray
        ``projected @ coef_[:n_comp].T``.
    """
    # centre and scale the validated input in one step
    projected = np.dot(check_array(X) - lda.xbar_, lda.scalings_)
    if lda.n_components is None:
        n_comp = projected.shape[1]
    else:
        n_comp = lda.n_components
    weights = lda.coef_[:n_comp]
    return np.dot(projected, weights.T)
def plot_transform(X, y, transform, title):
    """Scatter-plot a 2-component LDA embedding of (X, y) on the current axes.

    A two-component LDA is fitted on ``(X, y)`` and passed, with ``X``, to
    ``transform``. The first two output columns are plotted coloured by
    ``y``. The axes are titled with ``title`` followed by the accuracy of a
    LinearSVC fitted to the embedding.

    Parameters
    ----------
    X, y : training samples and labels
    transform : callable
        Called as ``transform(fitted_lda, X)``; must return a 2-D array with
        at least two columns.
    title : str
        Prefix for the axes title.
    """
    fitted = LDA(n_components=2).fit(X, y)
    embedded = transform(fitted, X)

    plt.scatter(embedded[:, 0], embedded[:, 1], c=y)
    plt.gca().set_aspect('equal')

    classifier = LinearSVC().fit(embedded, y)
    # NOTE: this is training accuracy; the SVC is scored on the very points
    # it was fitted to.
    training_score = classifier.score(embedded, y)
    label = title + ' LinearSVC score: ' + str(training_score)
    plt.title(label)
# Compare the two transforms on blobs with tightly packed centres (left
# column) and widely spread centres (right column). Top row: old transform;
# bottom row: new transform. The subplot call order (1, 3, 2, 4) and the final
# X, y values match the original unrolled version.
for column, center_box in enumerate([(-2, 2), (-10, 10)], start=1):
    X, y = make_blobs(1000, centers=3, center_box=center_box,
                      random_state=1999)
    plt.subplot(2, 2, column)
    plot_transform(X, y, old_transform, 'old')
    plt.subplot(2, 2, column + 2)
    plot_transform(X, y, new_transform, 'new')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment