Skip to content

Instantly share code, notes, and snippets.

@matsuken92
Last active August 29, 2015 14:17
Show Gist options
  • Save matsuken92/3218c44d1723a36039b7 to your computer and use it in GitHub Desktop.
Save matsuken92/3218c44d1723a36039b7 to your computer and use it in GitHub Desktop.
Principal Components Analysis
%matplotlib inline
import numpy as np
import sklearn.decomposition as decomp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# function definitions
class DigitData:
    """One labelled sample taken from a raw row.

    Element 0 of the row is the label; elements 1..784 are kept as the
    pixel data (anything beyond index 784 is dropped).
    """

    def __init__(self, data):
        label, pixels = data[0], data[1:785]
        self.label = label
        self.data = pixels

    def getLabel(self):
        """Return the label (element 0 of the input row)."""
        return self.label

    def getData(self):
        """Return the pixel values (elements 1..784 of the input row)."""
        return self.data

    def __repr__(self):
        return "label: %s\ndata: %s\n" % (self.label, self.data)
class DigitDataSet:
    """Labelled samples, accessible row-wise and grouped by label.

    ``data`` is a 2-D array whose column 0 holds the label and whose
    columns 1..784 hold the pixel values (later columns are ignored).
    """

    def __init__(self, data):
        # label -> list of pixel rows, in the order they appear in `data`.
        self.dataset = {}
        # Row-wise views: pixel matrix and label vector.
        self.data = data[:, 1:785]
        self.label = data[:, 0]
        for row in data:
            self.dataset.setdefault(row[0], []).append(row[1:785])

    def getLabel(self):
        """Return the label column (one entry per row)."""
        return self.label

    def getData(self, index=-1):
        """Return the pixel matrix, or the row at ``index`` when one is given.

        Note: -1 is the "whole matrix" sentinel, so the last row cannot be
        fetched via index=-1.
        """
        if index == -1:
            return self.data
        else:
            return self.data[index]

    def getByLabel(self, label, num=None):
        """Return samples belonging to one digit class.

        Args:
            label: digit class, expected in 0..9.
            num: None  -> the first sample of that class (1-D array);
                 a number n (int/float, Python or numpy) -> the first
                 int(n) samples (2-D array);
                 'all' -> every sample of that class (2-D array).

        Raises:
            ValueError: label is outside 0..9.
            KeyError: no sample with that label exists.
            TypeError: num is not None, a number, or 'all'.
        """
        if label < 0 or 9 < label:
            raise ValueError('label should be from 0 to 9.')
        try:
            samples = self.dataset[label]
        except KeyError:
            raise KeyError('no samples with label %r' % (label,))
        if num is None:
            return np.array(samples[0])
        if isinstance(num, (int, float, np.integer, np.floating)):
            # Slice bounds must be integers; a float used to raise TypeError.
            return np.array(samples[0:int(num)])
        if num == 'all':
            return np.array(samples)
        raise TypeError("num should be None, int, float or 'all'.")

    def __repr__(self):
        # One "label, count" line per class, in first-seen order.
        return "".join("%s, %d\n" % (k, len(v)) for k, v in self.dataset.items())
def plot_digits(Z, size, size_x, size_y, counter, title, fontsize=10):
    """Draw one flattened size x size image into a subplot cell.

    Args:
        Z: array-like of size*size values in row-major image order.
        size: side length of the square image, in pixels.
        size_x, size_y, counter: subplot grid rows, columns and 1-based
            index, forwarded to plt.subplot.
        title: subplot title.
        fontsize: title font size.
    """
    # Flip vertically so image row 0 ends up at the top of the plot.
    image = np.asarray(Z).reshape(size, size)[::-1, :]
    plt.subplot(size_x, size_y, counter)
    plt.title(title, fontsize=fontsize)
    # Without X/Y, pcolor uses cell edges 0..size, so every pixel is drawn.
    # The original passed same-shape X/Y grids; depending on the matplotlib
    # version that dropped the last row/column or half-clipped border pixels.
    plt.pcolor(image, cmap='Spectral')
    plt.xlim(0, size)
    plt.ylim(0, size)
    # Booleans, not "on"/"off" strings (a non-empty string is truthy in
    # recent matplotlib, which would leave the labels visible).
    plt.tick_params(labelbottom=False, labelleft=False)
size = 28
# Load the labelled training samples, skipping the first (header) row.
raw_data = np.loadtxt('train_master.csv', skiprows=1, delimiter=',')
dataset = DigitDataSet(raw_data)
# data[d] holds every sample of digit d as a 2-D array.
data = [dataset.getByLabel(digit, 'all') for digit in range(10)]
# Investigate how many principal components are needed: for each digit,
# record the total explained-variance ratio at several component counts
# and print it as a Markdown table.
comp_items = [5, 10, 20, 30]
cumsum_explained = np.zeros((10, len(comp_items)))
for i, n_comp in enumerate(comp_items):
    for num in range(10):
        pca = decomp.PCA(n_components=n_comp)
        pca.fit(data[num])
        # Sum of the per-component ratios == cumulative ratio at n_comp.
        cumsum_explained[num, i] = pca.explained_variance_ratio_.sum()
print("| label |" + "".join("explained n_comp:%d|" % n for n in comp_items))
print("|:-----:|" + ":-----:|" * len(comp_items))
for i in range(10):
    # '%%' emits a literal percent sign; the original '%|' raised ValueError.
    cells = "".join("%.1f%%|" % (ratio * 100) for ratio in cumsum_explained[i])
    print("|%d|%s" % (i, cells))
# PCA ALL: project every sample onto the first two principal components
# and scatter-plot the projections coloured by digit label.
pca = decomp.PCA(n_components=2)
all_pixels = dataset.getData()
pca.fit(all_pixels)
transformed = pca.transform(all_pixels)
colors = [plt.cm.hsv(0.1 * i, 1) for i in range(10)]
plt.figure(figsize=(16, 11))
# Empty scatters serve as opaque legend proxies; the original drew these at
# (0, 0), leaving ten spurious markers at the origin.
for i in range(10):
    plt.scatter([], [], alpha=1, color=colors[i], label=str(i))
plt.legend()
# One vectorised call with a per-point RGBA array instead of one scatter call
# per sample; points are still drawn in dataset order.
point_colors = np.array([colors[int(lab)] for lab in dataset.getLabel()])
plt.scatter(transformed[:, 0], transformed[:, 1], c=point_colors, alpha=0.3)
plt.title("PCA(Principal Component Analysis)")
plt.show()
# Plot a representative point for each digit: the mean of its samples'
# 2-D PCA projection, with the marker area scaled by the variance.
transformed = [pca.transform(dataset.getByLabel(label=i, num='all')) for i in range(10)]
ave = [np.average(t, axis=0) for t in transformed]
var = [np.var(t, axis=0) for t in transformed]
# Hand-tuned vertical offsets for the text labels (default 30).
label_padding = {4: -50, 5: 10, 7: 10, 8: 60, 9: 60}
# A fresh figure; the original's preceding plt.clf() only cleared/created a
# stray extra figure.
plt.figure(figsize=(14, 10))
# Empty scatters as legend proxies (the original left markers at (100, 100)).
for j in range(10):
    plt.scatter([], [], alpha=1, color=colors[j], label=str(j))
plt.legend()
plt.xlim(-1500, 1500)
plt.ylim(-1500, 1500)
for i, (a, v) in enumerate(zip(ave, var)):
    print("%d %s %s" % (i, a[0], a[1]))
    # `s` must be a scalar for a single point; the original passed both
    # variances (length 2), which newer matplotlib rejects. Older versions
    # effectively used only the first size, so keep that: variance on PC1.
    plt.scatter(a[0], a[1], color=colors[i], alpha=0.6, s=v[0] / 4, linewidth=1)
    plt.scatter(a[0], a[1], c="k", s=10)
    padding = label_padding.get(i, 30)
    plt.text(a[0] + 30, a[1] + padding, "digit: %d" % i, fontsize=12)
plt.title("PCA Representative Vector for each digit.")
plt.savefig("PCA_RepVec.png")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment