@amueller
Created December 15, 2012 21:26
Plotting PCAs of pairs of MNIST digit classes
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
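from sklearn.decomposition import PCA
from sklearn.datasets import fetch_openml
from sklearn.utils import shuffle
# RandomizedPCA and fetch_mldata were removed from scikit-learn; the lines below
# use their replacements (PCA with svd_solver="randomized" and fetch_openml).
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
# fetch_openml returns the labels as strings, so convert them to integers
X_train, y_train = mnist.data[:60000] / 255., mnist.target[:60000].astype(int)
X_train, y_train = shuffle(X_train, y_train)
X_train, y_train = X_train[:5000], y_train[:5000]  # let's subsample a bit for a first impression
pca = PCA(n_components=2, svd_solver="randomized")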
fig, plots = plt.subplots(10, 10)
fig.set_size_inches(50, 50)
plt.prism()
for i, j in product(range(10), repeat=2):
    if i > j:
        continue
    # keep only the samples that belong to digit class i or digit class j
    X_ = X_train[(y_train == i) | (y_train == j)]
    y_ = y_train[(y_train == i) | (y_train == j)]
    X_transformed = pca.fit_transform(X_)
    plots[i, j].scatter(X_transformed[:, 0], X_transformed[:, 1], c=y_)
    plots[i, j].set_xticks(())
    plots[i, j].set_yticks(())
    # the mirrored panel (j, i) shows the same projection
    plots[j, i].scatter(X_transformed[:, 0], X_transformed[:, 1], c=y_)
    plots[j, i].set_xticks(())
    plots[j, i].set_yticks(())
    if i == 0:
        # label the first row and the first column with the digit class
        plots[i, j].set_title(j)
        plots[j, i].set_ylabel(j)
    #plt.scatter(X_transformed[:, 0], X_transformed[:, 1], c=y_)
plt.tight_layout()
plt.savefig("mnist_pairs.png")
@haraldschilly

I think the (j, i) plot should not be the same as the (i, j) plot, but a transposed image of it (line 24 vs. 28).
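
One way to read this suggestion is to swap the two principal components when drawing the mirrored panel, so that the lower triangle of the grid is the transpose of the upper one. The snippet below is only a sketch of that idea, not the gist author's code: the helper name `scatter_pair` is hypothetical, and random 2-D points stand in for the PCA output so that it runs without downloading MNIST.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def scatter_pair(plots, i, j, X_transformed, y_):
    """Draw the (i, j) panel and a transposed copy in the (j, i) panel.

    Hypothetical helper; `plots` is the 2-D array of axes from plt.subplots.
    """
    # Upper triangle: first component on the x axis, second on the y axis.
    plots[i, j].scatter(X_transformed[:, 0], X_transformed[:, 1], c=y_)
    if i != j:
        # Lower triangle: swap the axes so this panel mirrors the one above.
        plots[j, i].scatter(X_transformed[:, 1], X_transformed[:, 0], c=y_)
    for ax in (plots[i, j], plots[j, i]):
        ax.set_xticks(())
        ax.set_yticks(())


# Stand-in data (an assumption for illustration): 20 random 2-D points, two labels.
rng = np.random.RandomState(0)
X_demo = rng.randn(20, 2)
y_demo = np.repeat([0, 1], 10)

fig, plots = plt.subplots(2, 2)
scatter_pair(plots, 0, 1, X_demo, y_demo)
```

In the gist's loop, the two `scatter` calls and the tick-clearing calls could be replaced by `scatter_pair(plots, i, j, X_transformed, y_)`. The `i != j` check also avoids drawing the diagonal panels twice.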
