Last active
August 29, 2015 14:17
-
-
Save matsuken92/3218c44d1723a36039b7 to your computer and use it in GitHub Desktop.
Principal Components Analysis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%matplotlib inline | |
import numpy as np | |
import sklearn.decomposition as decomp | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
# function definitions | |
class DigitData:
    """A single labelled sample split out of one flat row.

    The first element of the row is kept as the label and elements 1..784
    (at most 784 values) are kept as the sample data.
    """

    def __init__(self, data):
        """Split ``data`` into its label (index 0) and payload (indices 1-784)."""
        self.label, self.data = data[0], data[1:785]

    def getLabel(self):
        """Return the label stored for this sample."""
        return self.label

    def getData(self):
        """Return the payload stored for this sample."""
        return self.data

    def __repr__(self):
        """Return a two-line, human-readable dump of label and data."""
        pieces = ["label: ", str(self.label), "\ndata: ", str(self.data), "\n"]
        return "".join(pieces)
class DigitDataSet:
    """A labelled data set built from a 2-D array of rows.

    Column 0 of every row is the label; columns 1..784 are the sample data.
    Samples are additionally grouped per label in ``self.dataset``.
    """

    def __init__(self, data):
        """Split ``data`` into labels and samples and group samples by label.

        Args:
            data: 2-D array (it is indexed as ``data[:, 0]`` and
                ``data[:, 1:785]``).
        """
        self.data = data[:, 1:785]
        self.label = data[:, 0]
        # Maps each label value to the list of sample rows carrying it,
        # in their original order.
        self.dataset = {}
        for label, sample in zip(self.label, self.data):
            self.dataset.setdefault(label, []).append(sample)

    def getLabel(self):
        """Return the label column for all rows."""
        return self.label

    def getData(self, index=-1):
        """Return all samples, or the sample at ``index``.

        ``-1`` is the "return everything" sentinel, so it cannot be used to
        address the last row.
        """
        if index == -1:
            return self.data
        return self.data[index]

    def getByLabel(self, label, num=None):
        """Return samples that carry ``label``.

        Args:
            label: Digit label, expected in the range 0..9.
            num: ``None`` returns only the first sample (1-D array); an int
                or float returns the first ``int(num)`` samples (floats are
                truncated); ``'all'`` returns every sample of that label.

        Raises:
            ValueError: if ``label`` is outside 0..9.
            TypeError: if ``num`` is not one of the accepted forms.
            KeyError: if no sample with ``label`` exists in the data set.
        """
        if label < 0 or 9 < label:
            raise ValueError('label should be from 0 to 9.')
        samples = self.dataset[label]
        if num is None:
            return np.array(samples[0])
        if isinstance(num, (int, float)):
            # Slice bounds must be integers; a float bound raises TypeError.
            return np.array(samples[0:int(num)])
        if num == 'all':
            return np.array(samples)
        raise TypeError("num should be None, an int or float, or 'all'.")

    def __repr__(self):
        """Return one ``"<label>, <sample count>"`` line per label."""
        return "".join(
            str(label) + ", " + str(len(samples)) + "\n"
            for label, samples in self.dataset.items()
        )
def plot_digits(Z, size, size_x, size_y, counter, title, fontsize=10):
    """Draw one square image into a cell of a subplot grid.

    Args:
        Z: Array with ``size * size`` values; it is reshaped to
            ``(size, size)``.
        size: Side length of the square image.
        size_x, size_y, counter: Passed to ``plt.subplot`` in this order
            (grid dimensions and the index of the cell to draw into).
        title: Title shown above the cell.
        fontsize: Font size of the title.
    """
    X, Y = np.meshgrid(range(size), range(size))
    # Reverse the row order so the first row of the image is drawn at the
    # top: the Y coordinates grow with the row index.
    Z = Z.reshape(size, size)[::-1, :]
    plt.subplot(size_x, size_y, counter)
    plt.title(title, fontsize=fontsize)
    plt.xlim(0, size - 1)
    plt.ylim(0, size - 1)
    plt.pcolor(X, Y, Z, cmap=plt.get_cmap('Spectral'))
    # Hide the tick labels. These flags are booleans; the string "off" is
    # truthy and would leave the labels visible.
    plt.tick_params(labelbottom=False, labelleft=False)
# Side length used for square digit images (presumably 28x28 pixels,
# matching the 784 data columns -- confirm against the CSV).
size = 28

# Load the CSV, skipping its first line, and wrap it in a DigitDataSet
# (column 0 is the label, the remaining columns are the sample data).
raw_data = np.loadtxt('train_master.csv', delimiter=',', skiprows=1)
dataset = DigitDataSet(raw_data)

# data[d] holds every sample whose label is the digit d.
data = [dataset.getByLabel(digit, 'all') for digit in range(10)]
# Investigate how many components to keep: for every digit, fit a PCA with
# each candidate component count and record the total explained variance
# ratio. Rows of cumsum_explained are digits, columns follow comp_items.
comp_items = [5, 10, 20, 30]
cumsum_explained = np.zeros((10, len(comp_items)))
for i, n_comp in enumerate(comp_items):
    for num in range(10):
        pca = decomp.PCA(n_components=n_comp)
        pca.fit(data[num])
        E = pca.explained_variance_ratio_
        # The last entry of the cumulative sum is simply the total.
        cumsum_explained[num, i] = np.sum(E)

# Print the result as a Markdown table. The header and row formats are
# derived from comp_items so they stay in sync with the computed columns.
# Single-argument print(...) is valid on both Python 2 and Python 3.
print("| label |" + "".join("explained n_comp:%d|" % n for n in comp_items))
print("|:-----:|" + ":-----:|" * len(comp_items))
# A literal percent sign must be written as "%%" inside a %-format string.
row_format = "|%d|" + "%.1f%%|" * len(comp_items)
for i in range(10):
    print(row_format % ((i,) + tuple(cumsum_explained[i] * 100)))
# PCA ALL
# Fit a 2-component PCA on the whole data set and project every sample.
pca = decomp.PCA(n_components=2)
pca.fit(dataset.getData())
transformed = pca.transform(dataset.getData())

# One colour per digit, taken at evenly spaced positions of the hsv colormap.
colors = [plt.cm.hsv(0.1 * i, 1) for i in range(10)]

plt.figure(figsize=(16, 11))
# Empty scatters act as legend entries only, so no stray marker is drawn
# inside the plot.
for i in range(10):
    plt.scatter([], [], alpha=1, color=colors[i], label=str(i))
plt.legend()

# Draw all samples with a single call (one colour per point) instead of one
# scatter call per sample.
point_colors = [colors[int(l)] for l in dataset.getLabel()]
plt.scatter(transformed[:, 0], transformed[:, 1], c=point_colors, alpha=0.3)
plt.title("PCA(Principal Component Analysis)")
plt.show()
# Plot a representative value for each digit: the mean of its samples in the
# 2-D PCA space, drawn as a disc whose area scales with the variance.
transformed = [pca.transform(dataset.getByLabel(label=i, num='all')) for i in range(10)]
ave = [np.average(t, axis=0) for t in transformed]
var = [np.var(t, axis=0) for t in transformed]

# plt.figure() already starts a fresh figure; a preceding plt.clf() would
# only create an extra empty one.
plt.figure(figsize=(14, 10))
# Empty scatters act as legend entries only, so no stray marker is drawn
# inside the axes.
for j in range(10):
    plt.scatter([], [], alpha=1, color=colors[j], label=str(j))
plt.legend()
plt.xlim(-1500, 1500)
plt.ylim(-1500, 1500)

# Vertical offset of the text label per digit; digits not listed use 30.
label_padding = {4: -50, 8: 60, 9: 60, 7: 10, 5: 10}
for i, (a, v) in enumerate(zip(ave, var)):
    # Single-argument print(...) is valid on both Python 2 and Python 3.
    print("%d %s %s" % (i, a[0], a[1]))
    # scatter needs a scalar size (or one size per point). v holds one
    # variance per component, so the first component's variance is used.
    # NOTE(review): older matplotlib appears to have used the first size
    # when given v/4 for a single point -- confirm this keeps the look.
    plt.scatter(a[0], a[1], color=colors[i], alpha=0.6, s=v[0] / 4, linewidth=1)
    plt.scatter(a[0], a[1], c="k", s=10)
    padding = label_padding.get(i, 30)
    plt.text(a[0] + 30, a[1] + padding, "digit: %d" % i, fontsize=12)
plt.title("PCA Representative Vector for each digit.")
plt.savefig("PCA_RepVec.png")
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment