Created September 21, 2018 07:41
Save jsun/cacd32f5a2417849d2f37d36b07c6d07 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Kernel PCA on MNIST: project standardized digit images onto a few RBF-kernel
# principal components, map them back to pixel space, and show one original
# image next to its low-dimensional reconstruction.
#
# NOTE: `fetch_mldata` was removed from scikit-learn (deprecated in 0.20,
# removed in 0.22) because mldata.org went offline; `fetch_openml` is the
# supported replacement.
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split  # currently unused
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import KernelPCA
import matplotlib.pyplot as plt

N_SAMPLES = 10000   # images used; KernelPCA builds an N_SAMPLES x N_SAMPLES kernel
N_COMPONENTS = 5    # kernel principal components kept
# RBF kernel width. None makes scikit-learn use 1 / n_features.
# The previous value of 15 was far too large for standardized 784-pixel
# vectors: squared distances between distinct images are typically in the
# hundreds or more, so exp(-15 * d^2) underflows to 0. That left the kernel
# matrix numerically equal to the identity and the components meaningless.
GAMMA = None
IMAGE_INDEX = 1     # which image to display

# download MNIST (70,000 images of 28x28 = 784 pixels).
# as_frame=False returns NumPy arrays so that img[i] selects a row; with a
# pandas DataFrame, img[i] would look up a column instead.
mnist = fetch_openml('mnist_784', version=1, as_frame=False,
                     data_home='./data/minist')
print(mnist.data.shape)
## (70000, 784)
print(mnist.target.shape)
## (70000,)

# only use the first N_SAMPLES images (to reduce runtime)
img = mnist.data[:N_SAMPLES]

# standardize every pixel column to zero mean / unit variance
scaler = StandardScaler()
img_scaled = scaler.fit_transform(img)

# kernel PCA; fit_inverse_transform=True additionally learns an approximate
# pre-image map so component-space points can be mapped back to input space
kpca = KernelPCA(n_components=N_COMPONENTS, kernel='rbf', gamma=GAMMA,
                 fit_inverse_transform=True)
img_scaled_kpca = kpca.fit_transform(img_scaled)

# reconstruct: component space -> standardized pixels -> raw pixel scale
img_lowdim = scaler.inverse_transform(kpca.inverse_transform(img_scaled_kpca))

# plot the original image next to its reconstruction, on the same 0..255
# grey scale (values outside that range are clipped in the display)
fig, (ax_orig, ax_rec) = plt.subplots(1, 2, figsize=(8, 4))
panels = (
    (ax_orig, img[IMAGE_INDEX], 'original image'),
    (ax_rec, img_lowdim[IMAGE_INDEX], f'{N_COMPONENTS} KPCA components'),
)
for ax, pixels, title in panels:
    ax.imshow(pixels.reshape(28, 28), cmap=plt.cm.gray,
              interpolation='nearest', vmin=0, vmax=255)
    ax.set_title(title, fontsize=20)
fig.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.