Last active
September 22, 2019 03:56
-
-
Save phyblas/b4969ee3cf7e6a9b88c9ae571d1c2a5a to your computer and use it in GitHub Desktop.
MNIST手書き数字データをnumpyで書いたロジスティック回帰で学習して結果を分析する ref: https://qiita.com/phyblas/items/375ab130e53b0d04f784
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[1276 0 4 2 3 9 8 3 6 1] | |
[ 0 1556 11 4 1 5 0 4 17 6] | |
[ 9 16 1225 21 11 3 15 16 26 6] | |
[ 4 12 33 1269 0 45 3 17 34 10] | |
[ 5 7 11 2 1245 3 16 5 12 56] | |
[ 12 6 16 44 10 1105 25 10 39 13] | |
[ 12 3 9 0 5 16 1343 2 4 3] | |
[ 6 4 20 5 14 1 1 1377 1 32] | |
[ 6 34 12 36 7 24 10 5 1242 14] | |
[ 10 6 6 20 35 12 1 46 10 1273] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 5-fold cross-validation of the softmax regression, followed by a plot of the
# mean +/- standard deviation of the per-epoch accuracy over the folds.
gakushuuritsu = 0.24  # learning rate
n_batch = 100  # mini-batch size
n = len(y)  # total number of samples
nf = 5  # number of folds
nn = n//nf + (np.arange(nf) < (n % nf))  # number of samples in each fold
kurikaeshi = 30  # number of epochs (no early stopping this time)
kunren_seikaku = []
kenshou_seikaku = []
s = np.random.permutation(n)
# Fold i is s[kugiri[i]:kugiri[i+1]].  Using explicit boundaries keeps the
# validation folds disjoint even when n is not divisible by nf; rotating s by
# unequal fold sizes would let the first and last folds share a sample.
kugiri = np.concatenate([[0], np.cumsum(nn)])
mmk = Mikumikukaiki(gakushuuritsu)
for i in range(nf):
    hajime, owari = kugiri[i], kugiri[i + 1]
    kenshou_idx = s[hajime:owari]
    kunren_idx = np.concatenate([s[:hajime], s[owari:]])
    X_kunren = X[kunren_idx]
    y_kunren = y[kunren_idx]
    X_kenshou = X[kenshou_idx]
    y_kenshou = y[kenshou_idx]
    # gakushuu() re-initialises the weights, so the same object can be reused
    mmk.gakushuu(X_kunren, y_kunren, kurikaeshi, n_batch, X_kenshou, y_kenshou)
    kunren_seikaku.append(mmk.kunren_seikaku)
    kenshou_seikaku.append(mmk.kenshou_seikaku)
# one row per fold, one column per epoch
kunren_seikaku = np.stack(kunren_seikaku)
kenshou_seikaku = np.stack(kenshou_seikaku)
plt.figure(figsize=[8, 6])
plt.errorbar(np.arange(kurikaeshi), kunren_seikaku.mean(0), yerr=kunren_seikaku.std(0), color='#dd0000')
plt.errorbar(np.arange(kurikaeshi), kenshou_seikaku.mean(0), yerr=kenshou_seikaku.std(0), color='#00aa00')
plt.title(u'正確度 (%)', fontname='AppleGothic', size=18)
plt.legend([u'訓練', u'檢證'], prop={'family': 'AppleGothic', 'size': 17})
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn import datasets | |
# Download MNIST (70000 images of 28x28 = 784 pixels) from OpenML.
mnist = datasets.fetch_openml('mnist_784')
# Depending on the scikit-learn version, fetch_openml may return pandas
# objects and string class labels.  The rest of this script indexes X and y
# with integer index arrays and computes y.max()+1, so force plain NumPy
# arrays with integer labels here.
X = np.asarray(mnist.data, dtype=float)
y = np.asarray(mnist.target).astype(int)
X = X/255.  # scale pixel values from 0..255 to 0..1
print(X.shape)  # (70000, 784)
print(y.shape)  # (70000,)
print(y)  # integer class labels 0..9
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
# Show nine sample images in a 3x3 grid, picking one row of X every 6500 rows.
for i in range(1, 10):
    gazou = X[30 + i*6500].reshape(28, 28)  # flat 784-vector -> 28x28 image
    plt.subplot(3, 3, i)
    plt.imshow(gazou, cmap='gray_r')
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.metrics import confusion_matrix |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.model_selection import KFold | |
# Same kind of 5-fold split, but delegated to scikit-learn's KFold.
kf = KFold(n_splits=5, shuffle=True)
for kr, ks in kf.split(y):
    # kr: row indices of the training part, ks: row indices of the validation part
    X_kunren, y_kunren = X[kr], y[kr]
    X_kenshou, y_kenshou = X[ks], y[ks]
    # ... use the fold here
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.model_selection import train_test_split | |
# Hold out 20% of the samples as validation data using scikit-learn.
X_kunren, X_kenshou, y_kunren, y_kenshou = train_test_split(
    X, y, test_size=0.2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Keep only a random subset of 700 samples (X and y stay aligned).
s = np.random.permutation(len(y))
X, y = X[s[:700]], y[s[:700]]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Keep only a random subset of 7000 samples (X and y stay aligned).
s = np.random.permutation(len(y))
X, y = X[s[:7000]], y[s[:7000]]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Random 80/20 split done by hand: the first fifth of a random permutation
# becomes the validation set, the remainder the training set.
n = len(X)
s = np.random.permutation(n)
nn = n // 5  # size of the validation set
X_kunren, y_kunren = X[s[nn:]], y[s[nn:]]
X_kenshou, y_kenshou = X[s[:nn]], y[s[:nn]]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Mikumikukaiki:
    """Multinomial (softmax) logistic regression trained with mini-batch gradient steps.

    Attributes set by ``gakushuu``:
        n_group: number of classes, computed as ``int(y.max() + 1)``.
        w: weights of shape ``(n_features + 1, n_group)``; row 0 is the bias.
        sonshitsu: cross-entropy on the training data after each epoch.
        kunren_seikaku: training accuracy (%) after each epoch.
        kenshou_seikaku: validation accuracy (%) after each epoch.
    """

    def __init__(self, gakushuuritsu):
        self.gakushuuritsu = gakushuuritsu  # learning rate

    def gakushuu(self, X, y, kurikaeshi, n_batch=0, X_kenshou=0, y_kenshou=0, patience=0):
        """Fit the model and keep the weights with the best validation accuracy.

        Args:
            X: training features, 2-D array of shape (n_samples, n_features).
            y: training labels, 1-D NumPy array of class indices 0..n_group-1.
            kurikaeshi: maximum number of epochs.
            n_batch: mini-batch size; 0 or a value larger than the number of
                samples means one batch containing all samples.
            X_kenshou, y_kenshou: validation data; when X_kenshou is not a
                NumPy array the training data is used for validation too.
            patience: stop after this many consecutive epochs without a new
                best validation accuracy; 0 disables early stopping.
        """
        n = len(y)  # number of training samples
        # Without validation data, evaluate on the training data instead.
        if not isinstance(X_kenshou, np.ndarray):
            X_kenshou, y_kenshou = X, y
        # No batch size given, or batch larger than the data: use a single batch.
        if n_batch == 0 or n < n_batch:
            n_batch = n
        self.n_group = int(y.max() + 1)  # number of classes
        y_1h = y[:, None] == np.arange(self.n_group)  # one-hot labels (boolean)
        self.w = np.zeros([X.shape[1] + 1, self.n_group])
        # Per-epoch history of loss, training accuracy and validation accuracy.
        self.sonshitsu = []
        self.kunren_seikaku = []
        self.kenshou_seikaku = []
        saikou = 0  # best validation accuracy so far
        agaranai = 0  # consecutive epochs without improvement
        # Start from the initial weights so that `w` is always defined, even
        # with zero epochs or when the validation accuracy never exceeds 0.
        w = self.w.copy()
        for j in range(kurikaeshi):
            s = np.random.permutation(n)  # new sample order every epoch
            for i in range(0, n, n_batch):
                batch = s[i:i + n_batch]
                Xn = X[batch]
                yn = y_1h[batch]
                phi = self.softmax(Xn)
                # (target - prediction), averaged over the batch and scaled by
                # the learning rate
                eee = (yn - phi)/len(yn)*self.gakushuuritsu
                self.w[1:] += np.dot(Xn.T, eee)
                self.w[0] += eee.sum(0)
            seigo = self.yosoku(X) == y
            kunren_seikaku = seigo.mean()*100  # training accuracy (%)
            seigo = self.yosoku(X_kenshou) == y_kenshou
            kenshou_seikaku = seigo.mean()*100  # validation accuracy (%)
            if kenshou_seikaku > saikou:
                # New best accuracy: remember it together with the weights.
                saikou = kenshou_seikaku
                agaranai = 0
                w = self.w.copy()
            else:
                agaranai += 1  # count epochs without improvement
            self.kunren_seikaku += [kunren_seikaku]
            self.kenshou_seikaku += [kenshou_seikaku]
            self.sonshitsu += [self.entropy(X, y_1h)]
            print(u'%d回目、正確度%.3f%%、最高%.3f%%' % (j + 1, self.kenshou_seikaku[-1], saikou))
            if patience != 0 and agaranai >= patience:
                break  # accuracy has not improved for `patience` epochs
        self.w = w  # adopt the best weights that were kept

    def yosoku(self, X):
        """Return the predicted class index for every row of X."""
        return (np.dot(X, self.w[1:]) + self.w[0]).argmax(1)

    def softmax(self, X):
        """Return class probabilities, shape (n_samples, n_group), rows summing to 1."""
        h = np.dot(X, self.w[1:]) + self.w[0]
        # Subtract each row's maximum before exponentiating to avoid overflow.
        exp_h = np.exp(h - h.max(1, keepdims=True))
        return exp_h/exp_h.sum(1, keepdims=True)

    def entropy(self, X, y_1h):
        """Return the cross-entropy, averaged over all entries of the one-hot matrix."""
        # 1e-7 guards against log(0)
        return -(y_1h*np.log(self.softmax(X) + 1e-7)).mean()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[1276 0 4 2 3 9 8 3 6 1] | |
[ 0 1556 11 4 1 5 0 4 17 6] | |
[ 9 16 1225 21 11 3 15 16 26 6] | |
[ 4 12 33 1269 0 45 3 17 34 10] | |
[ 5 7 11 2 1245 3 16 5 12 56] | |
[ 12 6 16 44 10 1105 25 10 39 13] | |
[ 12 3 9 0 5 16 1343 2 4 3] | |
[ 6 4 20 5 14 1 1 1377 1 32] | |
[ 6 34 12 36 7 24 10 5 1242 14] | |
[ 10 6 6 20 35 12 1 46 10 1273] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train with early stopping, then plot the loss (top) and accuracy (bottom).
gakushuuritsu = 0.24  # learning rate
kurikaeshi = 1000  # maximum number of epochs if early stopping never triggers
n_batch = 100  # mini-batch size
patience = 10  # stop after this many epochs without a better validation accuracy
mmk = Mikumikukaiki(gakushuuritsu)
mmk.gakushuu(X_kunren, y_kunren, kurikaeshi, n_batch, X_kenshou, y_kenshou, patience)
# Draw the learning curves.
plt.figure(figsize=[8, 8])
ax = plt.subplot(211)
plt.plot(mmk.sonshitsu, '#000077')
plt.legend([u'損失'], prop={'family': 'AppleGothic', 'size': 17})
# Hide the x tick labels of the upper panel.  This must be the boolean False:
# the string 'off' is truthy and leaves the labels visible.
plt.tick_params(labelbottom=False)
ax = plt.subplot(212)
ax.set_ylabel(u'正確度 (%)', fontname='AppleGothic', size=18)
plt.plot(mmk.kunren_seikaku, '#dd0000')
plt.plot(mmk.kenshou_seikaku, '#00aa00')
plt.legend([u'訓練', u'檢證'], prop={'family': 'AppleGothic', 'size': 17})
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Predict the validation set and print the confusion matrix one row per line.
t = mmk.yosoku(X_kenshou)
conma = confusion_matrix(y_kenshou, t)
for c in conma:
    print(c)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib as mpl | |
def plotconma(conma, log=0):
    """Draw a confusion matrix as a heat map with the count written in every cell.

    Args:
        conma: square 2-D array indexed as ``conma[i, j]``; rows are labelled
            as the correct class and columns as the predicted class.
        log: when truthy, colour the cells on a logarithmic scale.
    """
    n = len(conma)
    ticks = np.arange(n)
    plt.figure(figsize=[9, 8])
    # plt.gca() no longer accepts Axes properties as keyword arguments in
    # current matplotlib, so configure the ticks on the returned Axes.
    ax = plt.gca()
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)
    plt.xlabel(u'予測', fontname='AppleGothic', size=16)
    plt.ylabel(u'正解', fontname='AppleGothic', size=16)
    for i in range(n):
        for j in range(n):
            # x is the column (prediction), y is the row (correct class)
            plt.text(j, i, conma[i, j], ha='center', va='center', size=14)
    if log:
        plt.imshow(conma, cmap='autumn_r', norm=mpl.colors.LogNorm())
    else:
        plt.imshow(conma, cmap='autumn_r')
    plt.colorbar(pad=0.01)
    plt.show()


plotconma(conma, log=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Show the learned pixel weights of classes 1..9 as 28x28 images in a 3x3 grid
# (row 0 of mmk.w is the bias and is skipped).
for i in range(1, 10):
    omomi = mmk.w[1:, i].reshape(28, 28)
    plt.subplot(3, 3, i)
    plt.imshow(omomi, cmap='gray_r')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment