@kiwamizamurai
Created February 6, 2019 11:04
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Understanding the usage of Cross-Validation\n",
"- https://kevinzakka.github.io/2016/07/13/k-nearest-neighbor/\n",
"- https://towardsdatascience.com/interactive-controls-for-jupyter-notebooks-f5c94829aee6\n",
"- https://jakevdp.github.io/PythonDataScienceHandbook/05.03-hyperparameters-and-model-validation.html\n",
"- https://towardsdatascience.com/train-test-split-and-cross-validation-in-python-80b61beca4b6"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal length (cm)</th>\n",
" <th>sepal width (cm)</th>\n",
" <th>petal length (cm)</th>\n",
" <th>petal width (cm)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n",
"0 5.1 3.5 1.4 0.2\n",
"1 4.9 3.0 1.4 0.2\n",
"2 4.7 3.2 1.3 0.2\n",
"3 4.6 3.1 1.5 0.2\n",
"4 5.0 3.6 1.4 0.2"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"import pandas\n",
"iris = load_iris()\n",
"X = iris.data\n",
"y = iris.target\n",
"pandas.DataFrame(X, columns=iris.feature_names).head()"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "86881cb5becb43e999f84aa75c691422",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(Dropdown(description='column', options=('sepal length (cm)', 'sepal width (cm)', 'petal …"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import ipywidgets as widgets\n",
"from ipywidgets import interact, interact_manual\n",
"\n",
"df = pandas.DataFrame(X, columns=iris.feature_names)\n",
"\n",
"@interact\n",
"def show_rows_more_than(column=list(df.columns),\n",
"                        x=(0, 4, 0.3)):\n",
"    # Show only the rows whose value in the chosen column exceeds the slider value x\n",
"    return df.loc[df[column] > x]"
]
},
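{
"cell_type": "markdown",
"metadata": {},
"source": [
"Stripped of the widget, the function above is just a boolean mask on one column. A small sketch with fixed values in place of the dropdown and slider (petal length above 1.5 cm, picked only for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas\n",
"from sklearn.datasets import load_iris\n",
"\n",
"iris = load_iris()\n",
"df = pandas.DataFrame(iris.data, columns=iris.feature_names)\n",
"\n",
"# One True/False per row: is the petal length above 1.5 cm?\n",
"mask = df['petal length (cm)'] > 1.5\n",
"print(mask.sum(), 'of', len(df), 'rows pass the filter')\n",
"df.loc[mask].head()"
]
},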
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> No module named 'cufflinks'\n",
"\n",
"This error came up, so install the package with\n",
"\n",
"> pip install cufflinks"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39293106d99c447eb4f5d03315519fbb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"interactive(children=(Dropdown(description='x', options=('sepal length (cm)', 'sepal width (cm)', 'petal lengt…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import cufflinks as cf\n",
"\n",
"# Render the plotly figures offline, inside the notebook\n",
"cf.go_offline()\n",
"\n",
"@interact\n",
"def scatter_plot(x=list(df.select_dtypes('number').columns), \n",
" y=list(df.select_dtypes('number').columns)[1:],\n",
" theme=list(cf.themes.THEMES.keys()), \n",
" colorscale=list(cf.colors._scales_names.keys())):\n",
" \n",
" # No hover-text column: the iris DataFrame has only the four numeric features\n",
" df.iplot(kind='scatter', x=x, y=y, mode='markers', \n",
" xTitle=x.title(), yTitle=y.title(), \n",
" title=f'{y.title()} vs {x.title()}',\n",
" theme=theme, colorscale=colorscale)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The error would not go away. I was looking forward to trying this library, but I'm giving up on it.\n",
"\n",
"Let's regroup and do some **cross-validation**!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example using k-nearest neighbors (k=3)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"model = KNeighborsClassifier(n_neighbors=3)"
]
},
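{
"cell_type": "markdown",
"metadata": {},
"source": [
"`KNeighborsClassifier(n_neighbors=3)` labels a sample by majority vote among its 3 closest training points (Euclidean distance by default). A from-scratch NumPy sketch of that rule for one query point; `knn_predict_one` is a hypothetical helper written for illustration, not a scikit-learn function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"\n",
"def knn_predict_one(X_train, y_train, x, k=3):\n",
"    # Euclidean distance from x to every training point\n",
"    dists = np.sqrt(((X_train - x) ** 2).sum(axis=1))\n",
"    nearest = np.argsort(dists)[:k]        # indices of the k closest points\n",
"    votes = np.bincount(y_train[nearest])  # number of votes per class label\n",
"    return votes.argmax()                  # on a tie, the smallest label wins\n",
"\n",
"# Hold sample 0 out of the training data and classify it\n",
"print(knn_predict_one(X[1:], y[1:], X[0]), 'vs true label', y[0])"
]
},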
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Split the dataset with `train_test_split`\n",
"2. Train on the training set and check the accuracy on the test set"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9777777777777777"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Hold out 30% of the data as a test set\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.7)\n",
"\n",
"model.fit(X_train, y_train)\n",
"pred = model.predict(X_test)\n",
"accuracy_score(y_test, pred)"
]
},
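{
"cell_type": "markdown",
"metadata": {},
"source": [
"`train_size=0.7` keeps 70% of the 150 samples (105) for training and leaves 45 for testing, and `accuracy_score` is simply the fraction of test predictions that match the true labels. A sketch checking both by hand, recreating the same split so the cell runs on its own:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.7)\n",
"print(len(X_train), len(X_test))  # sizes of the training and test sets\n",
"\n",
"pred = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train).predict(X_test)\n",
"# accuracy = correct predictions / number of test samples\n",
"print(np.mean(pred == y_test), accuracy_score(y_test, pred))"
]
},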
{
"cell_type": "markdown",
"metadata": {},
"source": [
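"What we did above is only one round of cross-validation, so next we run the full procedure.  \n",
"The example below splits the dataset into 5 equal folds; each fold serves once as the test set while the model is trained on the other 4."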
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.96666667, 0.96666667, 0.93333333, 0.96666667, 1. ])"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"cross_val_score(model, X, y, cv=5)"
]
},
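{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough sketch of what `cross_val_score(model, X, y, cv=5)` does internally. It relies on scikit-learn's documented behaviour that an integer `cv` with a classifier means `StratifiedKFold` (no shuffling) and that the default score for a classifier is accuracy; the loop itself is illustrative, not scikit-learn's actual code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"skf = StratifiedKFold(n_splits=5)  # splitter that cv=5 implies for a classifier\n",
"fold_scores = []\n",
"for train_idx, test_idx in skf.split(X, y):\n",
"    knn = KNeighborsClassifier(n_neighbors=3)  # fresh, unfitted model for every fold\n",
"    knn.fit(X[train_idx], y[train_idx])        # train on the other 4 folds\n",
"    fold_scores.append(accuracy_score(y[test_idx], knn.predict(X[test_idx])))\n",
"np.array(fold_scores)"
]
},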
{
"cell_type": "markdown",
"metadata": {},
"source": [
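"# Is k=3 really the best choice?\n",
"# This is exactly where cross-validation helps\n",
"\n",
"Below, the best k is chosen as follows:\n",
"1. Loop k from 1 to 29 (`range(1, 30)`)\n",
"2. For each k, run 10-fold cross-validation (`cv=10`) on the training set\n",
"3. Plot the misclassification error and look for where it is smallest\n"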
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"neighbors = list(range(1, 30))\n",
"cv_scores = []\n",
"for k in neighbors:\n",
" knn = KNeighborsClassifier(n_neighbors=k)\n",
" scores = cross_val_score(knn, X_train, y_train, cv=10, scoring='accuracy')\n",
" cv_scores.append(scores.mean())"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The optimal number of neighbors is 6\n"
]
},
{
"data": {
"text/plain": [
"Text(0,0.5,'Misclassification Error')"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"# Misclassification error = 1 - mean cross-validated accuracy\n",
"error = [1 - x for x in cv_scores]\n",
"\n",
"optimal_k = neighbors[error.index(min(error))]\n",
"print(\"The optimal number of neighbors is %d\" % optimal_k)\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"ax.plot(neighbors, error)\n",
"ax.set_xlabel('Number of Neighbors K')\n",
"ax.set_ylabel('Misclassification Error')"
]
},
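{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same search over k can be written with scikit-learn's `GridSearchCV`, which runs 10-fold cross-validation for every candidate `n_neighbors` and records the best one. A sketch of that alternative (it is not what the cells above use), recreating the 70/30 training split so the cell runs on its own. Tie-breaking and fold weighting can differ between scikit-learn versions, so the chosen k is not guaranteed to match the loop exactly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.7)\n",
"\n",
"# Try k = 1..29, scoring each with 10-fold cross-validated accuracy on the training set\n",
"search = GridSearchCV(KNeighborsClassifier(),\n",
"                      param_grid={'n_neighbors': list(range(1, 30))},\n",
"                      cv=10, scoring='accuracy')\n",
"search.fit(X_train, y_train)\n",
"print('best k:', search.best_params_['n_neighbors'])\n",
"print('misclassification error: %.3f' % (1 - search.best_score_))"
]
},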
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Let's try the neural network from last time\n",
"# This time we use Keras"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import Dense, Activation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ReLU for the hidden layer's activation, softmax for the output layer\n",
"# For optimization we use the Adam optimizer\n",
"- https://keras.io/ja/optimizers/"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"model = Sequential()\n",
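"model.add(Dense(6, input_dim=4))  # hidden layer: 4 input features -> 6 units\n",
"model.add(Activation('relu'))\n",
"model.add(Dense(3))  # output layer: one unit per iris class (input_dim only matters on the first layer)\n",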
"model.add(Activation('softmax'))\n",
"model.compile(optimizer='Adam', \n",
" loss='categorical_crossentropy',\n",
" metrics=['accuracy']\n",
" )"
]
},
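{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the architecture concrete, one forward pass of this network (4 inputs, 6 ReLU hidden units, 3 softmax outputs) can be written in a few lines of NumPy. A sketch with random, untrained weights; `W1`, `b1`, `W2`, `b2` are illustrative names, not Keras attributes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(0)\n",
"W1, b1 = rng.normal(size=(4, 6)), np.zeros(6)  # like Dense(6, input_dim=4)\n",
"W2, b2 = rng.normal(size=(6, 3)), np.zeros(3)  # like Dense(3)\n",
"\n",
"def relu(z):\n",
"    return np.maximum(z, 0)\n",
"\n",
"def softmax(z):\n",
"    # subtract each row's max for numerical stability; each row then sums to 1\n",
"    e = np.exp(z - z.max(axis=1, keepdims=True))\n",
"    return e / e.sum(axis=1, keepdims=True)\n",
"\n",
"x = np.array([[5.1, 3.5, 1.4, 0.2]])  # first iris sample\n",
"probs = softmax(relu(x @ W1 + b1) @ W2 + b2)\n",
"print(probs, probs.sum())  # three class probabilities and their total"
]
},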
{
"cell_type": "markdown",
"metadata": {},
"source": [
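"# Using the encoder from the preprocessing session last time\n",
"Watch out: without it you get an error. A softmax output trained with `categorical_crossentropy` needs categorical (one-hot) labels."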
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import LabelEncoder  # only needed for string labels; iris.target is already 0, 1, 2\n",
"from keras.utils import np_utils"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"categol_y = np_utils.to_categorical(y)"
]
},
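{
"cell_type": "markdown",
"metadata": {},
"source": [
"`to_categorical` turns each integer label into a one-hot row: with 3 classes, label 2 becomes `[0, 0, 1]`. A NumPy sketch of the same idea, picking rows of an identity matrix, assuming integer labels `0..n_classes-1` as in `iris.target`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"\n",
"labels = load_iris().target          # integer classes 0, 1, 2\n",
"n_classes = labels.max() + 1\n",
"one_hot = np.eye(n_classes)[labels]  # row i is the one-hot vector for labels[i]\n",
"print(one_hot.shape)\n",
"print(one_hot[[0, 50, 100]])         # one sample from each class"
]
},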
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, categol_y, random_state=0, train_size=0.7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# First try a single train/test run, then cross-validate"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/kiwamizamurai/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:1: UserWarning:\n",
"\n",
"The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"105/105 [==============================] - 1s 5ms/step - loss: 2.6313 - acc: 0.3714\n",
"Epoch 2/100\n",
"105/105 [==============================] - 0s 312us/step - loss: 2.0706 - acc: 0.3714\n",
"Epoch 3/100\n",
"105/105 [==============================] - 0s 293us/step - loss: 1.7114 - acc: 0.3619\n",
"Epoch 4/100\n",
"105/105 [==============================] - 0s 258us/step - loss: 1.4884 - acc: 0.2000\n",
"Epoch 5/100\n",
"105/105 [==============================] - 0s 228us/step - loss: 1.3478 - acc: 0.1619\n",
"Epoch 6/100\n",
"105/105 [==============================] - 0s 228us/step - loss: 1.2316 - acc: 0.1619\n",
"Epoch 7/100\n",
"105/105 [==============================] - 0s 245us/step - loss: 1.1312 - acc: 0.1333\n",
"Epoch 8/100\n",
"105/105 [==============================] - 0s 257us/step - loss: 1.0486 - acc: 0.1619\n",
"Epoch 9/100\n",
"105/105 [==============================] - 0s 229us/step - loss: 0.9687 - acc: 0.2476\n",
"Epoch 10/100\n",
"105/105 [==============================] - 0s 227us/step - loss: 0.9146 - acc: 0.4095\n",
"Epoch 11/100\n",
"105/105 [==============================] - 0s 227us/step - loss: 0.8829 - acc: 0.4476\n",
"Epoch 12/100\n",
"105/105 [==============================] - 0s 226us/step - loss: 0.8590 - acc: 0.4667\n",
"Epoch 13/100\n",
"105/105 [==============================] - 0s 246us/step - loss: 0.8328 - acc: 0.4571\n",
"Epoch 14/100\n",
"105/105 [==============================] - 0s 227us/step - loss: 0.8098 - acc: 0.4667\n",
"Epoch 15/100\n",
"105/105 [==============================] - 0s 235us/step - loss: 0.7952 - acc: 0.4762\n",
"Epoch 16/100\n",
"105/105 [==============================] - 0s 235us/step - loss: 0.7669 - acc: 0.4952\n",
"Epoch 17/100\n",
"105/105 [==============================] - 0s 247us/step - loss: 0.7486 - acc: 0.5238\n",
"Epoch 18/100\n",
"105/105 [==============================] - 0s 229us/step - loss: 0.7323 - acc: 0.5333\n",
"Epoch 19/100\n",
"105/105 [==============================] - 0s 232us/step - loss: 0.7151 - acc: 0.5048\n",
"Epoch 20/100\n",
"105/105 [==============================] - 0s 225us/step - loss: 0.6990 - acc: 0.5524\n",
"Epoch 21/100\n",
"105/105 [==============================] - 0s 243us/step - loss: 0.6869 - acc: 0.5333\n",
"Epoch 22/100\n",
"105/105 [==============================] - 0s 227us/step - loss: 0.6699 - acc: 0.5429\n",
"Epoch 23/100\n",
"105/105 [==============================] - 0s 232us/step - loss: 0.6560 - acc: 0.5905\n",
"Epoch 24/100\n",
"105/105 [==============================] - 0s 229us/step - loss: 0.6444 - acc: 0.6095\n",
"Epoch 25/100\n",
"105/105 [==============================] - 0s 230us/step - loss: 0.6384 - acc: 0.5524\n",
"Epoch 26/100\n",
"105/105 [==============================] - 0s 221us/step - loss: 0.6206 - acc: 0.6286\n",
"Epoch 27/100\n",
"105/105 [==============================] - 0s 229us/step - loss: 0.6107 - acc: 0.6667\n",
"Epoch 28/100\n",
"105/105 [==============================] - 0s 226us/step - loss: 0.6002 - acc: 0.6286\n",
"Epoch 29/100\n",
"105/105 [==============================] - 0s 227us/step - loss: 0.5915 - acc: 0.6190\n",
"Epoch 30/100\n",
"105/105 [==============================] - 0s 230us/step - loss: 0.5813 - acc: 0.6381\n",
"Epoch 31/100\n",
"105/105 [==============================] - 0s 227us/step - loss: 0.5714 - acc: 0.6571\n",
"Epoch 32/100\n",
"105/105 [==============================] - 0s 231us/step - loss: 0.5660 - acc: 0.6286\n",
"Epoch 33/100\n",
"105/105 [==============================] - 0s 232us/step - loss: 0.5626 - acc: 0.6762\n",
"Epoch 34/100\n",
"105/105 [==============================] - 0s 238us/step - loss: 0.5501 - acc: 0.6571\n",
"Epoch 35/100\n",
"105/105 [==============================] - 0s 233us/step - loss: 0.5405 - acc: 0.6667\n",
"Epoch 36/100\n",
"105/105 [==============================] - 0s 233us/step - loss: 0.5335 - acc: 0.6952\n",
"Epoch 37/100\n",
"105/105 [==============================] - 0s 231us/step - loss: 0.5262 - acc: 0.6857\n",
"Epoch 38/100\n",
"105/105 [==============================] - 0s 237us/step - loss: 0.5243 - acc: 0.6857\n",
"Epoch 39/100\n",
"105/105 [==============================] - 0s 242us/step - loss: 0.5121 - acc: 0.6857\n",
"Epoch 40/100\n",
"105/105 [==============================] - 0s 231us/step - loss: 0.5128 - acc: 0.6762\n",
"Epoch 41/100\n",
"105/105 [==============================] - 0s 227us/step - loss: 0.5009 - acc: 0.7238\n",
"Epoch 42/100\n",
"105/105 [==============================] - 0s 254us/step - loss: 0.4966 - acc: 0.7238\n",
"Epoch 43/100\n",
"105/105 [==============================] - 0s 273us/step - loss: 0.4872 - acc: 0.7238\n",
"Epoch 44/100\n",
"105/105 [==============================] - 0s 260us/step - loss: 0.4848 - acc: 0.7048\n",
"Epoch 45/100\n",
"105/105 [==============================] - 0s 253us/step - loss: 0.4780 - acc: 0.7524\n",
"Epoch 46/100\n",
"105/105 [==============================] - 0s 262us/step - loss: 0.4711 - acc: 0.7524\n",
"Epoch 47/100\n",
"105/105 [==============================] - 0s 239us/step - loss: 0.4715 - acc: 0.7429\n",
"Epoch 48/100\n",
"105/105 [==============================] - 0s 261us/step - loss: 0.4606 - acc: 0.7619\n",
"Epoch 49/100\n",
"105/105 [==============================] - 0s 244us/step - loss: 0.4564 - acc: 0.7619\n",
"Epoch 50/100\n",
"105/105 [==============================] - 0s 255us/step - loss: 0.4531 - acc: 0.7524\n",
"Epoch 51/100\n",
"105/105 [==============================] - 0s 268us/step - loss: 0.4487 - acc: 0.7524\n",
"Epoch 52/100\n",
"105/105 [==============================] - 0s 257us/step - loss: 0.4406 - acc: 0.7714\n",
"Epoch 53/100\n",
"105/105 [==============================] - 0s 248us/step - loss: 0.4386 - acc: 0.8000\n",
"Epoch 54/100\n",
"105/105 [==============================] - 0s 243us/step - loss: 0.4323 - acc: 0.8000\n",
"Epoch 55/100\n",
"105/105 [==============================] - 0s 239us/step - loss: 0.4298 - acc: 0.8190\n",
"Epoch 56/100\n",
"105/105 [==============================] - 0s 251us/step - loss: 0.4237 - acc: 0.8190\n",
"Epoch 57/100\n",
"105/105 [==============================] - 0s 252us/step - loss: 0.4256 - acc: 0.8000\n",
"Epoch 58/100\n",
"105/105 [==============================] - 0s 250us/step - loss: 0.4227 - acc: 0.7714\n",
"Epoch 59/100\n",
"105/105 [==============================] - 0s 270us/step - loss: 0.4136 - acc: 0.8000\n",
"Epoch 60/100\n",
"105/105 [==============================] - 0s 251us/step - loss: 0.4089 - acc: 0.8476\n",
"Epoch 61/100\n",
"105/105 [==============================] - 0s 247us/step - loss: 0.4015 - acc: 0.8476\n",
"Epoch 62/100\n",
"105/105 [==============================] - 0s 251us/step - loss: 0.4004 - acc: 0.8286\n",
"Epoch 63/100\n",
"105/105 [==============================] - 0s 254us/step - loss: 0.3967 - acc: 0.8667\n",
"Epoch 64/100\n",
"105/105 [==============================] - 0s 252us/step - loss: 0.3912 - acc: 0.8476\n",
"Epoch 65/100\n",
"105/105 [==============================] - 0s 275us/step - loss: 0.3853 - acc: 0.8667\n",
"Epoch 66/100\n",
"105/105 [==============================] - 0s 256us/step - loss: 0.3835 - acc: 0.8857\n",
"Epoch 67/100\n",
"105/105 [==============================] - 0s 256us/step - loss: 0.3799 - acc: 0.8857\n",
"Epoch 68/100\n",
"105/105 [==============================] - 0s 272us/step - loss: 0.3759 - acc: 0.8857\n",
"Epoch 69/100\n",
"105/105 [==============================] - 0s 246us/step - loss: 0.3760 - acc: 0.8667\n",
"Epoch 70/100\n",
"105/105 [==============================] - 0s 241us/step - loss: 0.3688 - acc: 0.9238\n",
"Epoch 71/100\n",
"105/105 [==============================] - 0s 252us/step - loss: 0.3638 - acc: 0.9143\n",
"Epoch 72/100\n",
"105/105 [==============================] - 0s 247us/step - loss: 0.3597 - acc: 0.9143\n",
"Epoch 73/100\n",
"105/105 [==============================] - 0s 261us/step - loss: 0.3589 - acc: 0.9048\n",
"Epoch 74/100\n",
"105/105 [==============================] - 0s 249us/step - loss: 0.3538 - acc: 0.9238\n",
"Epoch 75/100\n",
"105/105 [==============================] - 0s 252us/step - loss: 0.3497 - acc: 0.9238\n",
"Epoch 76/100\n",
"105/105 [==============================] - 0s 252us/step - loss: 0.3482 - acc: 0.9143\n",
"Epoch 77/100\n",
"105/105 [==============================] - 0s 245us/step - loss: 0.3436 - acc: 0.9333\n",
"Epoch 78/100\n",
"105/105 [==============================] - 0s 267us/step - loss: 0.3394 - acc: 0.9333\n",
"Epoch 79/100\n",
"105/105 [==============================] - 0s 258us/step - loss: 0.3370 - acc: 0.9238\n",
"Epoch 80/100\n",
"105/105 [==============================] - 0s 249us/step - loss: 0.3333 - acc: 0.9333\n",
"Epoch 81/100\n",
"105/105 [==============================] - 0s 251us/step - loss: 0.3330 - acc: 0.9333\n",
"Epoch 82/100\n",
"105/105 [==============================] - 0s 274us/step - loss: 0.3279 - acc: 0.9238\n",
"Epoch 83/100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"105/105 [==============================] - 0s 243us/step - loss: 0.3233 - acc: 0.9143\n",
"Epoch 84/100\n",
"105/105 [==============================] - 0s 261us/step - loss: 0.3224 - acc: 0.9333\n",
"Epoch 85/100\n",
"105/105 [==============================] - 0s 289us/step - loss: 0.3206 - acc: 0.9143\n",
"Epoch 86/100\n",
"105/105 [==============================] - 0s 247us/step - loss: 0.3143 - acc: 0.9333\n",
"Epoch 87/100\n",
"105/105 [==============================] - 0s 229us/step - loss: 0.3108 - acc: 0.9333\n",
"Epoch 88/100\n",
"105/105 [==============================] - 0s 215us/step - loss: 0.3093 - acc: 0.9238\n",
"Epoch 89/100\n",
"105/105 [==============================] - 0s 220us/step - loss: 0.3068 - acc: 0.9333\n",
"Epoch 90/100\n",
"105/105 [==============================] - 0s 224us/step - loss: 0.3020 - acc: 0.9333\n",
"Epoch 91/100\n",
"105/105 [==============================] - 0s 237us/step - loss: 0.3049 - acc: 0.9143\n",
"Epoch 92/100\n",
"105/105 [==============================] - 0s 232us/step - loss: 0.2995 - acc: 0.9333\n",
"Epoch 93/100\n",
"105/105 [==============================] - 0s 239us/step - loss: 0.2972 - acc: 0.9238\n",
"Epoch 94/100\n",
"105/105 [==============================] - 0s 221us/step - loss: 0.2923 - acc: 0.9238\n",
"Epoch 95/100\n",
"105/105 [==============================] - 0s 225us/step - loss: 0.2909 - acc: 0.9429\n",
"Epoch 96/100\n",
"105/105 [==============================] - 0s 223us/step - loss: 0.2853 - acc: 0.9238\n",
"Epoch 97/100\n",
"105/105 [==============================] - 0s 228us/step - loss: 0.2838 - acc: 0.9333\n",
"Epoch 98/100\n",
"105/105 [==============================] - 0s 227us/step - loss: 0.2800 - acc: 0.9238\n",
"Epoch 99/100\n",
"105/105 [==============================] - 0s 224us/step - loss: 0.2774 - acc: 0.9333\n",
"Epoch 100/100\n",
"105/105 [==============================] - 0s 235us/step - loss: 0.2749 - acc: 0.9333\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x1a37a03898>"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(X_train, y_train, epochs=100, batch_size=5)"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy = 0.93\n"
]
}
],
"source": [
"loss, accuracy = model.evaluate(X_test, y_test, verbose=0)\n",
"print(\"Accuracy = {:.2f}\".format(accuracy))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
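"# Let's check whether this 0.93 figure holds up!\n",
"By the way, passing the Keras model straight to `cross_val_score()` raised an error, so we use the scikit-learn wrapper that ships with Keras instead."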
]
},
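{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why one split is not enough: the score depends on which samples happen to land in the test set. A quick sketch that repeats the 70/30 split with ten different `random_state` values; it uses the k-NN model from earlier instead of the network only because it trains instantly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"accs = []\n",
"for split_seed in range(10):\n",
"    X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=split_seed, train_size=0.7)\n",
"    knn = KNeighborsClassifier(n_neighbors=3).fit(X_tr, y_tr)\n",
"    accs.append(knn.score(X_te, y_te))  # test-set accuracy for this split\n",
"print('min %.3f  max %.3f  mean %.3f' % (min(accs), max(accs), np.mean(accs)))"
]
},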
{
"cell_type": "markdown",
"metadata": {},
"source": [
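"- https://machinelearningmastery.com/multi-class-classification-tutorial-keras-deep-learning-library/\n",
"\n",
"`KerasClassifier` needs a function that builds and returns a compiled model (its `build_fn` argument), so below that function is called `my`."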
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"def my():\n",
" model = Sequential()\n",
" model.add(Dense(6, input_dim=4))\n",
" model.add(Activation('relu'))\n",
" model.add(Dense(3))  # output layer: one unit per iris class\n",
" model.add(Activation('softmax'))\n",
" model.compile(optimizer='adam', \n",
" loss='categorical_crossentropy',\n",
" metrics=['accuracy']\n",
" )\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
"from keras.wrappers.scikit_learn import KerasClassifier\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.model_selection import KFold"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
"estimator = KerasClassifier(build_fn=my, epochs=200, batch_size=5, verbose=0)"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
"import numpy\n",
"seed = 1\n",
"numpy.random.seed(seed)\n",
"\n",
"kfold = KFold(n_splits=10, shuffle=True, random_state=seed)"
]
},
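{
"cell_type": "markdown",
"metadata": {},
"source": [
"With 150 samples and `n_splits=10`, every fold trains on 135 samples and tests on 15. Plain `KFold` ignores the labels, so the class mix of each test fold is left to the shuffle. For comparison, the sketch below also prints `StratifiedKFold`, a swapped-in alternative (not used above) that keeps the class proportions in every fold; it stratifies on integer labels, so it would need `y` rather than the one-hot `categol_y`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import KFold, StratifiedKFold\n",
"\n",
"X, y = load_iris(return_X_y=True)\n",
"splitters = [('KFold', KFold(n_splits=10, shuffle=True, random_state=1)),\n",
"             ('StratifiedKFold', StratifiedKFold(n_splits=10, shuffle=True, random_state=1))]\n",
"for name, cv in splitters:\n",
"    print(name)\n",
"    for train_idx, test_idx in cv.split(X, y):\n",
"        # class counts (0, 1, 2) among the held-out samples of this fold\n",
"        counts = np.bincount(y[test_idx], minlength=3)\n",
"        print('  train=%d test=%d classes=%s' % (len(train_idx), len(test_idx), counts))"
]
},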
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline: 96.67% (4.47%)\n"
]
}
],
"source": [
"results = cross_val_score(estimator, X, categol_y, cv=kfold)\n",
"print(\"Baseline: %.2f%% (%.2f%%)\" % (results.mean()*100, results.std()*100))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Now you're a Cross-Validator!\n",
"\n",
"Next time we'll try cross-validation for regression. Stay tuned!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}