Last active
May 18, 2017 12:56
-
-
Save nazoking/e761f1a29be10523ac503d5dd8d0e61e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Kaggle Digit Recognizer using Keras\n", | |
"\n", | |
"[Kaggle の Digit Recognizer(mnist)](https://www.kaggle.com/c/digit-recognizer) を [keras](https://keras.io/ja/) を使って解く\n", | |
"\n", | |
"\n", | |
"* (事前にデータをダウンロードして gzip し、 `mnist-train.csv.gz` , `mnist-test.csv.gz` という名前で保存、アップロードしておくこと)\n", | |
"* keras のドキュメントの日本語は v1 用で、ところどころ v2 と変更されているので注意(v2 のドキュメントは https://github.com/fchollet/keras-docs-ja で翻訳中。英語のドキュメントが一番正しい )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 諸々ライブラリ読み込み\n", | |
"import keras\n", | |
"\n", | |
"# 行列計算ライブラリ numpy\n", | |
"import numpy as np\n", | |
"# データ解析支援ライブラリ pandas\n", | |
"import pandas as pd\n", | |
"# グラフ表示 matplotlib\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# データを読み込む\n", | |
"train = pd.read_csv(\"mnist-train.csv.gz\")\n", | |
"train.head() # 最初の部分を表示してみる" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# ピクセル情報とラベルに分割\n", | |
"X_train = (train.values[:,1:]).astype('float32') # ピクセル情報\n", | |
"# 28x28 の2次元のピクセル画像が 784の1次元データになっているので2次元x1チャンネルに戻す\n", | |
"X_train = X_train.reshape(X_train.shape[0], 28, 28,1)\n", | |
"\n", | |
"# ラベルは数字で入る\n", | |
"y_train = train.values[:,0].astype('int32')\n", | |
"# それぞれの次元数を確認してみる\n", | |
"print(X_train.shape, y_train.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 画面に表示 画面を cols のグリッドに区切って表示\n", | |
"def protout(X, y, cols=10):\n", | |
" for i in range(cols):\n", | |
" plt.subplot(1,cols+1,1+i) # 画面を 1xcols に分割してi番目に描写\n", | |
" plt.imshow(X[i].reshape(28,28), 'gray') # グレイスケール=チャンネルなしにリシェイプして表示\n", | |
" plt.title(y[i]) # 文字を追加\n", | |
" plt.axis('off') # ルーラーを描写しない\n", | |
"\n", | |
"protout(X_train, y_train)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"print(y_train[0:3])\n", | |
"\n", | |
"# ラベルをワンホットベクトルに変更\n", | |
"# v2\n", | |
"# y_train_c = keras.utils.to_categorical(y_train,10)\n", | |
"\n", | |
"# v1\n", | |
"from keras.utils import np_utils\n", | |
"y_train_c = np_utils.to_categorical(y_train,10)\n", | |
"\n", | |
"print(y_train_c[0:3])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.callbacks import Callback\n", | |
"# 学習の途中で protout するクラス\n", | |
"class ProtoutCallback(Callback):\n", | |
" def __init__(self, *epochs):\n", | |
" self.epochs=epochs\n", | |
" def on_epoch_begin(self, epoch, logs):\n", | |
" if epoch == 0 and 0 in self.epochs:\n", | |
" self.draw(-1)\n", | |
" def on_epoch_end(self, epoch, logs):\n", | |
" if (epoch+1) in self.epochs:\n", | |
" self.draw(epoch)\n", | |
" def draw(self, epoch):\n", | |
" # 現時点での予測を描写\n", | |
" y = self.model.predict_classes(X_train[0:10], verbose=0)\n", | |
" # エポック数を左上に描写\n", | |
" protout(X_train, y)\n", | |
" plt.subplot(1,11,1)\n", | |
" plt.text(0,-20,\"epoch %d\" % (epoch+1))\n", | |
" plt.show() # 毎回 show する" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"keras では Sequential に [レイヤー](https://keras.io/ja/layers/about-keras-layers/) を追加していく\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# モデルの作成\n", | |
"\n", | |
"from keras.models import Sequential\n", | |
"from keras.layers.core import Dense, Flatten, Dropout\n", | |
"from keras.preprocessing.image import ImageDataGenerator\n", | |
"\n", | |
"# モデルを作る関数(初期化が簡単にできるように関数化)\n", | |
"def create_model():\n", | |
" # 逐次モデル\n", | |
" model= Sequential()\n", | |
" # Flattenは入力を1D配列に変換します。\n", | |
" model.add(Flatten(input_shape=(28,28,1)))\n", | |
" # Dense は全結合。100次元に変換 活性化関数は relu\n", | |
" model.add(Dense(100, activation='relu'))\n", | |
" # 同じ内容でもう一層\n", | |
" model.add(Dense(100, activation='relu'))\n", | |
" # 更に全結合層。結合した後 softmax で 10 のラベルに分類\n", | |
" model.add(Dense(10, activation='softmax'))\n", | |
" model.compile(\n", | |
" loss='categorical_crossentropy', # ロス関数を設定\n", | |
" optimizer='sgd',# オプティマイザを確率的勾配降下法 (SGD: Stochastic Gradient Descent)に設定\n", | |
" metrics=['accuracy']) # 途中経過の精度測定は「精度」を使う \n", | |
" return model\n", | |
"\n", | |
"model = create_model()\n", | |
"# モデルの概要を表示\n", | |
"model.summary()\n", | |
"print(\"input shape \",model.input_shape)\n", | |
"print(\"output shape \",model.output_shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 学習前はこんな予想結果(乱数なのでむちゃくちゃ)\n", | |
"y = model.predict_classes(X_train[0:10])\n", | |
"protout(X_train, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"model = create_model()\n", | |
"# 訓練する\n", | |
"history = model.fit(X_train, y_train_c,\n", | |
" batch_size=640, # バッチサイズ\n", | |
" epochs=10, # v2 epochs = 10, # 学習エポック数,\n", | |
" validation_split=0.1, # 検証用のデータの割合\n", | |
" callbacks = [ProtoutCallback(3,8)], # epoch 3と8 で途中経過の画像表示\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 学習後はこんな予想結果\n", | |
"y = model.predict_classes(X_train[0:10])\n", | |
"protout(X_train[0:10], y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 精度の履歴をプロット\n", | |
"plt.plot(history.history['acc'])\n", | |
"plt.plot(history.history['val_acc'])\n", | |
"plt.title('model accuracy')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.ylabel('accuracy')\n", | |
"plt.legend(['acc', 'val_acc'], loc='lower right')\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 損失の履歴をプロット\n", | |
"plt.plot(history.history['loss'])\n", | |
"plt.plot(history.history['val_loss'])\n", | |
"plt.title('model loss')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.ylabel('loss')\n", | |
"plt.legend(['loss', 'val_loss'], loc='upper right')\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"y = model.predict_classes(X_train[0:10])\n", | |
"protout(X_train, y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## モデルで十分な精度が出せたらテスト用データで実践してみる" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"test = pd.read_csv(\"mnist-test.csv.gz\")\n", | |
"# test には label がない\n", | |
"test.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 同じように変換\n", | |
"X_test = test.values.astype('float32')\n", | |
"X_test = X_test.reshape(X_test.shape[0], 28, 28,1)\n", | |
"# 文字を予想する\n", | |
"y_test = model.predict_classes(X_test)\n", | |
"protout(X_test,y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# 提出用にフォーマットして出力( ImageID, Label )\n", | |
"import csv\n", | |
"with open(\"results.csv\", \"w\") as f:\n", | |
" f.write(('\"ImageId\",\"Label\"\\n'))\n", | |
" for i, result in enumerate(y_test):\n", | |
" f.write(('%d,\"%d\"\\n'%(i+1,result)))\n", | |
"# results.csv をダウンロードすること\n", | |
"print(\"ok download `results.csv` and upload to kaggle\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"collapsed": true | |
}, | |
"source": [ | |
"## 課題\n", | |
"\n", | |
"* バッチ数やエポック数を変更してみよう\n", | |
" * 精度や学習時間がどのように変化するか観察してみよう\n", | |
" * 同じモデルでも実行タイミングによって結果が変わるのを注意(乱数シード)\n", | |
"* 多層にしてみよう https://keras.io/ja/layers/core/\n", | |
" * Dropout 付け足したり、活性化関数を変更してみよう\n", | |
"* [最適化関数](https://keras.io/ja/optimizers/)を変更してみよう\n", | |
"* [ImageDataGenerator](https://keras.io/ja/preprocessing/image/) を使ってデータを増やしてみよう\n", | |
" * ノイズを加えてみよう\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python3 (anaconda3-4.1.0)", | |
"language": "python", | |
"name": "anaconda3-4.1.0" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
graphviz==0.7 | |
jupyter==1.0.0 | |
Keras==2.0.4 | |
matplotlib==2.0.2 | |
numpy==1.12.1 | |
pandas==0.20.1 | |
pydot==1.2.3 | |
scikit-learn==0.18.1 | |
tensorflow==1.1.0 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment