Skip to content

Instantly share code, notes, and snippets.

@phyblas
Last active September 22, 2019 03:56
Show Gist options
  • Save phyblas/b4969ee3cf7e6a9b88c9ae571d1c2a5a to your computer and use it in GitHub Desktop.
Save phyblas/b4969ee3cf7e6a9b88c9ae571d1c2a5a to your computer and use it in GitHub Desktop.
MNIST手書き数字データをnumpyで書いたロジスティック回帰で学習して結果を分析する ref: https://qiita.com/phyblas/items/375ab130e53b0d04f784
[1276 0 4 2 3 9 8 3 6 1]
[ 0 1556 11 4 1 5 0 4 17 6]
[ 9 16 1225 21 11 3 15 16 26 6]
[ 4 12 33 1269 0 45 3 17 34 10]
[ 5 7 11 2 1245 3 16 5 12 56]
[ 12 6 16 44 10 1105 25 10 39 13]
[ 12 3 9 0 5 16 1343 2 4 3]
[ 6 4 20 5 14 1 1 1377 1 32]
[ 6 34 12 36 7 24 10 5 1242 14]
[ 10 6 6 20 35 12 1 46 10 1273]
gakushuuritsu = 0.24 # 学習率
n_batch = 100 # ミニバッチのサイズ
n = len(y) # 全てのデータの数
nf = 5 # 何個にする
nn = int(n/nf)+(np.arange(nf)<(n%nf)) # 各グループのデータの数
kurikaeshi = 30 # 繰り返す回数(今回は早期終了はしない)
kunren_seikaku = []
kenshou_seikaku = []
s = np.random.permutation(n)
mmk = Mikumikukaiki(gakushuuritsu)
for i in range(nf):
X_kunren = X[s[nn[i]:]]
y_kunren = y[s[nn[i]:]]
X_kenshou = X[s[:nn[i]]]
y_kenshou = y[s[:nn[i]]]
s = np.roll(s,nn[i],0) # 回すことでデータを分割するところは毎回変わる
mmk.gakushuu(X_kunren,y_kunren,kurikaeshi,n_batch,X_kenshou,y_kenshou)
kunren_seikaku.append(mmk.kunren_seikaku)
kenshou_seikaku.append(mmk.kenshou_seikaku)
kunren_seikaku = np.stack(kunren_seikaku)
kenshou_seikaku = np.stack(kenshou_seikaku)
plt.figure(figsize=[8,6])
plt.errorbar(np.arange(kurikaeshi),kunren_seikaku.mean(0),yerr=kunren_seikaku.std(0),color='#dd0000')
plt.errorbar(np.arange(kurikaeshi),kenshou_seikaku.mean(0),yerr=kenshou_seikaku.std(0),color='#00aa00')
plt.title(u'正確度 (%)',fontname='AppleGothic',size=18)
plt.legend([u'訓練',u'檢證'],prop={'family':'AppleGothic','size':17})
plt.show()
import numpy as np
from sklearn import datasets
mnist = datasets.fetch_openml('mnist_784')
X,y = mnist.data,mnist.target
X = X/255.
print(X.shape) # (70000, 784)
print(y.shape) # (70000,)
print(y) # [ 0. 0. 0. ..., 9. 9. 9.]
import matplotlib.pyplot as plt
for i in range(1,10):
plt.subplot(330+i)
plt.imshow(X[30+i*6500].reshape(28,28),cmap='gray_r')
plt.show()
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold
kf = KFold(n_splits=5,shuffle=True)
for kr,ks in kf.split(y):
X_kunren = X[kr]
y_kunren = y[kr]
X_kenshou = X[ks]
y_kenshou = y[ks]
# ...使う部分
from sklearn.model_selection import train_test_split
X_kunren,X_kenshou,y_kunren,y_kenshou = train_test_split(X,y,test_size=0.2)
s = np.random.permutation(len(y))
X = X[s[:700]]
y = y[s[:700]]
s = np.random.permutation(len(y))
X = X[s[:7000]]
y = y[s[:7000]]
n = len(X)
s = np.random.permutation(n)
nn = int(n/5)
X_kunren,X_kenshou = X[s[nn:]],X[s[:nn]]
y_kunren,y_kenshou = y[s[nn:]],y[s[:nn]]
class Mikumikukaiki:
def __init__(self,gakushuuritsu):
self.gakushuuritsu = gakushuuritsu # 学習率
def gakushuu(self,X,y,kurikaeshi,n_batch=0,X_kenshou=0,y_kenshou=0,patience=0):
n = len(y) # データの数
# もし検証データが渡されなければ、代わりに訓練データを検証データにも使う
if(type(X_kenshou)!=np.ndarray):
X_kenshou,y_kenshou = X,y
# バッチの数が指定されていないか、データの数より多い場合、ミニバッチをしないことにする
if(n_batch==0 or n<n_batch):
n_batch = n
self.n_group = int(y.max()+1) # 種類の数
y_1h = y[:,None]==range(self.n_group) # 正解ラベルの配列をone-hotにしておく
self.w = np.zeros([X.shape[1]+1,self.n_group])
# 毎回の損失と訓練データに対する正確度と検証データに対する正確度を記録するためのリスト
self.sonshitsu = []
self.kunren_seikaku = []
self.kenshou_seikaku = []
saikou = 0 # 今までの最高の正確度
agaranai = 0 # 正確度が何回上がっていない
for j in range(kurikaeshi):
s = np.random.permutation(n)
for i in range(0,n,n_batch):
Xn = X[s[i:i+n_batch]]
yn = y_1h[s[i:i+n_batch]]
phi = self.softmax(Xn)
eee = (yn-phi)/len(yn)*self.gakushuuritsu
self.w[1:] += np.dot(eee.T,Xn).T
self.w[0] += eee.sum(0)
seigo = self.yosoku(X)==y
kunren_seikaku = seigo.mean()*100 # 訓練データに対する正確度
seigo = self.yosoku(X_kenshou)==y_kenshou
kenshou_seikaku = seigo.mean()*100 # 検証データに対する正確度
if(kenshou_seikaku > saikou):
# 正確度が以前より高くなるとその値を取っておく
saikou = kenshou_seikaku
agaranai = 0
w = self.w.copy()
else:
agaranai += 1 # 上がらなければ、カウント
self.kunren_seikaku += [kunren_seikaku]
self.kenshou_seikaku += [kenshou_seikaku]
self.sonshitsu += [self.entropy(X,y_1h)]
print(u'%d回目、正確度%.3f%%、最高%.3f%%'%(j+1,self.kenshou_seikaku[-1],saikou))
if(patience!=0 and agaranai>=patience):
break # 正確度が何回たっても上がらなければ学習が終わる
self.w = w # 最後に取っておいた重みを採用する
def yosoku(self,X):
# 予測値を計算する
return (np.dot(X,self.w[1:])+self.w[0]).argmax(1)
def softmax(self,X):
# ソフトマックス関数で確率を計算する
h = np.dot(X,self.w[1:])+self.w[0]
exp_h = np.exp(h.T-h.max(1))
return (exp_h/exp_h.sum(0)).T
def entropy(self,X,y_1h):
# 交差エントロピーを計算する
return -(y_1h*np.log(self.softmax(X)+1e-7)).mean()
[1276 0 4 2 3 9 8 3 6 1]
[ 0 1556 11 4 1 5 0 4 17 6]
[ 9 16 1225 21 11 3 15 16 26 6]
[ 4 12 33 1269 0 45 3 17 34 10]
[ 5 7 11 2 1245 3 16 5 12 56]
[ 12 6 16 44 10 1105 25 10 39 13]
[ 12 3 9 0 5 16 1343 2 4 3]
[ 6 4 20 5 14 1 1 1377 1 32]
[ 6 34 12 36 7 24 10 5 1242 14]
[ 10 6 6 20 35 12 1 46 10 1273]
gakushuuritsu = 0.24 # 学習率
kurikaeshi = 1000 # 学習が終わらない場合の繰り返す回数
n_batch = 100 # ミニバッチのサイズ
patience = 10 # 正確度が何回上がらなければ学習が終わる
mmk = Mikumikukaiki(gakushuuritsu)
mmk.gakushuu(X_kunren,y_kunren,kurikaeshi,n_batch,X_kenshou,y_kenshou,patience)
# 学習進歩のグラフを描く
plt.figure(figsize=[8,8])
ax = plt.subplot(211)
plt.plot(mmk.sonshitsu,'#000077')
plt.legend([u'損失'],prop={'family':'AppleGothic','size':17})
plt.tick_params(labelbottom='off')
ax = plt.subplot(212)
ax.set_ylabel(u'正確度 (%)',fontname='AppleGothic',size=18)
plt.plot(mmk.kunren_seikaku,'#dd0000')
plt.plot(mmk.kenshou_seikaku,'#00aa00')
plt.legend([u'訓練',u'檢證'],prop={'family':'AppleGothic','size':17})
plt.show()
t = mmk.yosoku(X_kenshou)
conma = confusion_matrix(y_kenshou,t)
for c in conma: print(c)
import matplotlib as mpl
def plotconma(conma,log=0):
n = len(conma)
plt.figure(figsize=[9,8])
plt.gca(xticks=np.arange(n),xticklabels=np.arange(n),yticks=np.arange(n),yticklabels=np.arange(n))
plt.xlabel(u'予測',fontname='AppleGothic',size=16)
plt.ylabel(u'正解',fontname='AppleGothic',size=16)
for i in range(n):
for j in range(n):
plt.text(j,i,conma[i,j],ha='center',va='center',size=14)
if(log):
plt.imshow(conma,cmap='autumn_r',norm=mpl.colors.LogNorm())
else:
plt.imshow(conma,cmap='autumn_r')
plt.colorbar(pad=0.01)
plt.show()
plotconma(conma,log=1)
for i in range(1,10):
plt.subplot(330+i)
plt.imshow(mmk.w[1:,i].reshape(28,28),cmap='gray_r')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment