{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"K-fold cross-validationやLeave-one-out cross-validationを行ってみる。"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sys.version_info(major=3, minor=6, micro=2, releaselevel='final', serial=0)\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import sys\n",
"\n",
"from sklearn import svm\n",
"from sklearn.datasets import load_iris, load_digits\n",
"from sklearn.model_selection import KFold, ShuffleSplit, GroupKFold, \\\n",
" StratifiedKFold, StratifiedShuffleSplit, \\\n",
" LeaveOneOut, LeavePOut, \\\n",
" cross_val_predict, cross_val_score, GridSearchCV\n",
"\n",
"\n",
"print(sys.version_info)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learnではcross validationのための様々なクラスや関数が用意されているが、 \n",
"最も多用することになるのが4で紹介している`cross_val_predict`, `cross_val_score`, `GridSearchCV`である。 \n",
"これらをまずデフォルトで使い、慣れてきたら1や2で紹介しているcross validationを引数として与えてよりよいcross validationを行うと良い。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. K-fold cross validationのための様々なクラスや関数"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1-1.kFold"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[kFold](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold)を使うことでfoldを作成し、Training setとValidation setに分割することができる。<br>\n",
"再現性を担保せずに分割したい場合は`random_state=None`, `shuffle=True`とする。 \n",
"このルールは他のクラスでも適用される。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"KFold(n_splits=3, random_state=None, shuffle=False)\n",
"TRAIN: [2 3 4 5] VALIDATION: [0 1]\n",
"TRAIN: [0 1 4 5] VALIDATION: [2 3]\n",
"TRAIN: [0 1 2 3] VALIDATION: [4 5]\n",
"KFold(n_splits=3, random_state=None, shuffle=True)\n",
"TRAIN: [1 2 3 4] VALIDATION: [0 5]\n",
"TRAIN: [0 2 3 5] VALIDATION: [1 4]\n",
"TRAIN: [0 1 4 5] VALIDATION: [2 3]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])\n",
"y = np.array([1, 2, 1, 2, 1, 2])\n",
"k_fold = KFold(n_splits=3)\n",
"print(k_fold)\n",
"for train_index, validation_index in k_fold.split(X):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)\n",
"k_fold = KFold(n_splits=3, random_state=None, shuffle=True)\n",
"print(k_fold)\n",
"for train_index, validation_index in k_fold.split(X):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index) \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`random_state=123`などのように、intを指定するとseed値を固定することも可能である。 \n",
"以下の例では2回3-foldを行っているが`random_state=123`としているので2回のTraining set, Validation setは一致している。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"KFold(n_splits=3, random_state=123, shuffle=True)\n",
"1st\n",
"TRAIN: [0 2 4 5] VALIDATION: [1 3]\n",
"TRAIN: [1 2 3 5] VALIDATION: [0 4]\n",
"TRAIN: [0 1 3 4] VALIDATION: [2 5]\n",
"2nd\n",
"TRAIN: [0 2 4 5] VALIDATION: [1 3]\n",
"TRAIN: [1 2 3 5] VALIDATION: [0 4]\n",
"TRAIN: [0 1 3 4] VALIDATION: [2 5]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])\n",
"y = np.array([1, 2, 1, 2, 1, 2])\n",
"k_fold = KFold(n_splits=3)\n",
"k_fold = KFold(n_splits=3, random_state=123, shuffle=True)\n",
"print(k_fold)\n",
"print('1st')\n",
"for train_index, validation_index in k_fold.split(X):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index) \n",
"print('2nd')\n",
"for train_index, validation_index in k_fold.split(X):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1-2. ShuffleSplit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[ShuffleSplit](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.ShuffleSplit.html#sklearn.model_selection.ShuffleSplit)を使ってもランダムにfoldを作成し、Training setとValidation setに分割することができる。<br>\n",
"再現性を担保せずに分割したい場合は`random_state=None`とする。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ShuffleSplit(n_splits=3, random_state=None, test_size=0.5, train_size=0.5)\n",
"TRAIN: [2 1 0] VALIDATION: [3 4 5]\n",
"TRAIN: [2 1 3] VALIDATION: [4 5 0]\n",
"TRAIN: [1 0 4] VALIDATION: [2 5 3]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])\n",
"y = np.array([1, 2, 1, 2, 1, 2])\n",
"rs = ShuffleSplit(n_splits=3, random_state=None, test_size=0.5, train_size=0.5)\n",
"print(rs)\n",
"for train_index, validation_index in rs.split(X):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1-3. GroupKFold"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[GroupKFold](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GroupKFold.html#sklearn.model_selection.GroupKFold)を使うと必ず同じfoldに入るためのグループを作ることができる。<br>\n",
"例えば以下の例では0, 1番目のデータ、2, 3番目のデータはそれぞれ必ず同じfoldに分類される。"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GroupKFold(n_splits=2)\n",
"TRAIN: [0 1 2] VALIDATION: [3 4 5]\n",
"[[1 2]\n",
" [3 4]\n",
" [5 6]] [[ 7 8]\n",
" [ 9 10]\n",
" [11 12]] [1 2 1] [2 1 2]\n",
"TRAIN: [3 4 5] VALIDATION: [0 1 2]\n",
"[[ 7 8]\n",
" [ 9 10]\n",
" [11 12]] [[1 2]\n",
" [3 4]\n",
" [5 6]] [2 1 2] [1 2 1]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])\n",
"y = np.array([1, 2, 1, 2, 1, 2])\n",
"groups = np.array([0, 0, 0, 2, 2, 2])\n",
"skf = group_kfold = GroupKFold(n_splits=2)\n",
"group_kfold.get_n_splits(X, y, groups)\n",
"print(skf)\n",
"for train_index, validation_index in group_kfold.split(X, y, groups):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)\n",
" X_train, X_test = X[train_index], X[validation_index]\n",
" y_train, y_test = y[train_index], y[validation_index]\n",
" print(X_train, X_test, y_train, y_test)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1-4. StratifiedKFold"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[StratifiedKFold](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html#sklearn.model_selection.StratifiedKFold)を使うと目的変数のクラスごとの割合を考慮してfoldを作成することができる。<br>以下の例では<br>y=1であるインデックス番号0, 1, 2, 3、<br>y=2であるインデックス番号4, 5, 6, 7<br>から半分ずつ取り出しTraining setとし、残りをValidation setとしfoldを作成している。<br>\n",
"再現性を担保せずに分割したい場合は`shuffle=True`とし、seed値を固定しない場合は`random_state=None`とする。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"StratifiedKFold(n_splits=2, random_state=None, shuffle=False)\n",
"TRAIN: [2 3 6 7] VALIDATION: [0 1 4 5]\n",
"TRAIN: [0 1 4 5] VALIDATION: [2 3 6 7]\n",
"StratifiedKFold(n_splits=4, random_state=None, shuffle=True)\n",
"TRAIN: [0 1 2 4 6 7] VALIDATION: [3 5]\n",
"TRAIN: [1 2 3 4 5 6] VALIDATION: [0 7]\n",
"TRAIN: [0 1 3 4 5 7] VALIDATION: [2 6]\n",
"TRAIN: [0 2 3 5 6 7] VALIDATION: [1 4]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])\n",
"y = np.array([1, 1, 1, 1, 2, 2, 2, 2])\n",
"skf = StratifiedKFold(n_splits=2)\n",
"print(skf)\n",
"for train_index, validation_index in skf.split(X, y):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)\n",
" X_train, X_test = X[train_index], X[validation_index]\n",
" y_train, y_test = y[train_index], y[validation_index]\n",
"skf = StratifiedKFold(n_splits=4, random_state=None, shuffle=True)\n",
"print(skf)\n",
"for train_index, validation_index in skf.split(X, y):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)\n",
" X_train, X_test = X[train_index], X[validation_index]\n",
" y_train, y_test = y[train_index], y[validation_index] "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1-5. StratifiedShuffleSplit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[StratifiedShuffleSplit](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html#sklearn.model_selection.StratifiedShuffleSplit)を使ってもランダムにクラスごとの割合を考慮しつつfoldを作成し、<br>Training setとValidation setに分割することができる。<br>\n",
"乱数を固定したくない場合はrandom_state=Noneとする。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"StratifiedShuffleSplit(n_splits=2, random_state=0, test_size=0.5,\n",
" train_size=None)\n",
"TRAIN: [6 2 4 3] VALIDATION: [0 1 5 7]\n",
"TRAIN: [2 1 4 7] VALIDATION: [3 6 5 0]\n",
"StratifiedShuffleSplit(n_splits=2, random_state=None, test_size=0.25,\n",
" train_size=None)\n",
"TRAIN: [3 7 2 1 5 6] VALIDATION: [4 0]\n",
"TRAIN: [7 5 4 0 3 2] VALIDATION: [1 6]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])\n",
"y = np.array([1, 1, 1, 1, 2, 2, 2, 2])\n",
"sss = StratifiedShuffleSplit(n_splits=2, test_size=0.5, random_state=0)\n",
"print(sss)\n",
"for train_index, validation_index in sss.split(X, y):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)\n",
"sss = StratifiedShuffleSplit(n_splits=2, test_size=0.25, random_state=None)\n",
"print(sss)\n",
"for train_index, validation_index in sss.split(X, y):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Leave-one-out cross validationのための様々なクラスや関数"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2-1. LeaveOneOut"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[LeaveOneOut](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.LeaveOneOut.html#sklearn.model_selection.LeaveOneOut)を使うことでデータセットのサンプル一つずつをValidation setとすることができる。"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LeaveOneOut()\n",
"TRAIN: [1 2 3 4 5 6 7] VALIDATION: [0]\n",
"TRAIN: [0 2 3 4 5 6 7] VALIDATION: [1]\n",
"TRAIN: [0 1 3 4 5 6 7] VALIDATION: [2]\n",
"TRAIN: [0 1 2 4 5 6 7] VALIDATION: [3]\n",
"TRAIN: [0 1 2 3 5 6 7] VALIDATION: [4]\n",
"TRAIN: [0 1 2 3 4 6 7] VALIDATION: [5]\n",
"TRAIN: [0 1 2 3 4 5 7] VALIDATION: [6]\n",
"TRAIN: [0 1 2 3 4 5 6] VALIDATION: [7]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])\n",
"y = np.array([1, 1, 1, 1, 2, 2, 2, 2])\n",
"loo = LeaveOneOut()\n",
"print(loo)\n",
"for train_index, validation_index in loo.split(X):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2-2. LeavePOut"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1個ではなく2個や3個取り出したいときはLeavePOutを使う(P=2, 3など)。<br>\n",
"データセットのサンプルP個ずつをValidation setとすることができる。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LeavePOut(p=2)\n",
"TRAIN: [2 3 4 5 6 7] VALIDATION: [0 1]\n",
"TRAIN: [1 3 4 5 6 7] VALIDATION: [0 2]\n",
"TRAIN: [1 2 4 5 6 7] VALIDATION: [0 3]\n",
"TRAIN: [1 2 3 5 6 7] VALIDATION: [0 4]\n",
"TRAIN: [1 2 3 4 6 7] VALIDATION: [0 5]\n",
"TRAIN: [1 2 3 4 5 7] VALIDATION: [0 6]\n",
"TRAIN: [1 2 3 4 5 6] VALIDATION: [0 7]\n",
"TRAIN: [0 3 4 5 6 7] VALIDATION: [1 2]\n",
"TRAIN: [0 2 4 5 6 7] VALIDATION: [1 3]\n",
"TRAIN: [0 2 3 5 6 7] VALIDATION: [1 4]\n",
"TRAIN: [0 2 3 4 6 7] VALIDATION: [1 5]\n",
"TRAIN: [0 2 3 4 5 7] VALIDATION: [1 6]\n",
"TRAIN: [0 2 3 4 5 6] VALIDATION: [1 7]\n",
"TRAIN: [0 1 4 5 6 7] VALIDATION: [2 3]\n",
"TRAIN: [0 1 3 5 6 7] VALIDATION: [2 4]\n",
"TRAIN: [0 1 3 4 6 7] VALIDATION: [2 5]\n",
"TRAIN: [0 1 3 4 5 7] VALIDATION: [2 6]\n",
"TRAIN: [0 1 3 4 5 6] VALIDATION: [2 7]\n",
"TRAIN: [0 1 2 5 6 7] VALIDATION: [3 4]\n",
"TRAIN: [0 1 2 4 6 7] VALIDATION: [3 5]\n",
"TRAIN: [0 1 2 4 5 7] VALIDATION: [3 6]\n",
"TRAIN: [0 1 2 4 5 6] VALIDATION: [3 7]\n",
"TRAIN: [0 1 2 3 6 7] VALIDATION: [4 5]\n",
"TRAIN: [0 1 2 3 5 7] VALIDATION: [4 6]\n",
"TRAIN: [0 1 2 3 5 6] VALIDATION: [4 7]\n",
"TRAIN: [0 1 2 3 4 7] VALIDATION: [5 6]\n",
"TRAIN: [0 1 2 3 4 6] VALIDATION: [5 7]\n",
"TRAIN: [0 1 2 3 4 5] VALIDATION: [6 7]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])\n",
"y = np.array([1, 1, 1, 1, 2, 2, 2, 2])\n",
"lpo = LeavePOut(2)\n",
"print(lpo)\n",
"for train_index, validation_index in lpo.split(X):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. k-Fold や Leave-One-Outの適用例"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"上記のk-FoldやLeave-One-Outのクラスは引数やメソッドに共通点があるため<br>\n",
"ほぼ同じようにモデル構築に適用することができる。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"iris (機械学習でよく使われるあやめのデータセット) を読み込む。"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"iris = load_iris()\n",
"X_iris = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
"y_iris = iris.target"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"サンプルサイズは150である。<br>\n",
"そして説明変数としてとしてsepal length (がくの長さ), sepal width (がくの幅), petal length (花びらの長さ), petal width (花びらの幅)がある。"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal length (cm)</th>\n",
" <th>sepal width (cm)</th>\n",
" <th>petal length (cm)</th>\n",
" <th>petal width (cm)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n",
"0 5.1 3.5 1.4 0.2\n",
"1 4.9 3.0 1.4 0.2\n",
"2 4.7 3.2 1.3 0.2\n",
"3 4.6 3.1 1.5 0.2\n",
"4 5.0 3.6 1.4 0.2"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_iris.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"目的変数はあやめの種類である。0: setosa, 1: versicolor, 2: virginica"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_iris"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"サンプルサイズと説明変数の数は`shape`に保存されている。"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(150, 4)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_iris.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"続けて、例としてStratifiedShuffleSplitを用いてSVMによるクラス分類を行ってみる。"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"StratifiedShuffleSplit(n_splits=4, random_state=0, test_size=0.5,\n",
" train_size=None)\n",
"TRAIN: [ 16 69 15 4 78 138 111 10 93 45 74 58 106 22 56 28 107 27\n",
" 94 72 66 33 143 87 96 115 73 84 26 126 11 91 128 105 79 48\n",
" 7 148 31 119 59 124 38 57 95 101 83 137 112 52 92 30 63 42\n",
" 14 108 125 122 141 32 140 35 76 41 2 18 146 135 127 116 80 29\n",
" 104 82 34] VALIDATION: [139 65 145 6 129 25 85 23 118 64 17 121 71 39 67 36 131 149\n",
" 24 0 89 8 136 110 132 147 117 9 130 75 134 144 97 114 19 43\n",
" 49 21 50 86 37 20 61 81 5 123 44 99 77 102 98 3 142 40\n",
" 88 60 12 103 53 109 90 133 70 100 13 47 54 1 51 68 113 62\n",
" 120 46 55]\n",
"TRAIN: [ 7 10 141 6 94 31 113 140 108 11 128 96 149 110 98 4 101 44\n",
" 5 2 144 102 112 86 41 20 59 118 148 115 99 132 88 57 105 103\n",
" 83 45 138 62 74 81 52 13 114 67 40 47 82 33 106 38 18 135\n",
" 63 75 79 37 55 72 70 111 95 142 15 64 121 19 91 42 26 126\n",
" 12 1 69] VALIDATION: [120 78 29 46 58 134 125 25 53 48 51 104 146 123 54 131 9 68\n",
" 35 139 50 43 147 145 73 130 32 3 77 127 24 109 16 87 71 56\n",
" 93 0 124 14 23 66 100 27 89 137 76 17 133 116 34 136 8 22\n",
" 36 30 39 60 122 65 129 28 97 85 119 21 92 117 80 90 49 143\n",
" 84 107 61]\n",
"TRAIN: [131 76 116 145 114 18 4 95 52 61 94 87 29 103 142 9 0 65\n",
" 45 46 81 121 88 44 24 32 28 56 122 89 71 90 77 72 115 60\n",
" 78 85 49 58 41 129 91 117 127 69 107 99 113 11 33 74 119 34\n",
" 105 147 102 3 101 30 111 100 106 109 2 19 23 51 40 143 15 97\n",
" 20 38 123] VALIDATION: [130 17 125 124 53 43 62 86 79 80 31 137 55 36 42 64 141 8\n",
" 5 82 98 84 10 134 83 7 135 27 50 48 12 132 35 16 133 139\n",
" 59 140 136 63 93 73 148 108 1 138 110 66 68 21 128 96 70 13\n",
" 26 92 126 39 54 22 75 37 149 67 6 118 25 57 104 144 120 47\n",
" 112 146 14]\n",
"TRAIN: [138 137 136 139 82 64 27 21 3 99 9 33 30 149 100 72 47 145\n",
" 50 40 96 79 97 140 109 134 7 142 92 133 112 58 45 42 131 32\n",
" 4 66 60 118 110 95 10 13 75 12 38 57 44 77 41 20 83 51\n",
" 81 130 73 101 117 18 28 91 89 62 125 68 106 148 143 78 36 34\n",
" 120 108 2] VALIDATION: [ 23 123 26 1 37 102 15 113 25 90 31 124 147 22 111 105 71 86\n",
" 146 56 128 24 69 88 126 121 122 61 0 43 87 141 144 119 19 127\n",
" 49 17 65 5 46 85 135 76 16 67 52 14 103 11 6 55 93 94\n",
" 116 53 74 70 84 104 115 39 132 129 54 63 35 48 114 80 29 8\n",
" 98 107 59]\n"
]
}
],
"source": [
"sss = StratifiedShuffleSplit(n_splits=4, test_size=0.5, random_state=0)\n",
"print(sss)\n",
"for train_index, validation_index in sss.split(X_iris, y_iris):\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"pandasのデータフレームおよびnumpyのarrayはどちらもインデックスを指定することで特定の行や要素だけ取り出すことができる。<br>試しにわざとbreakを入れて一つ目のfoldのTraining set setだけ取り出してみる。"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"StratifiedShuffleSplit(n_splits=4, random_state=0, test_size=0.5,\n",
" train_size=None)\n",
"TRAIN: [ 16 69 15 4 78 138 111 10 93 45 74 58 106 22 56 28 107 27\n",
" 94 72 66 33 143 87 96 115 73 84 26 126 11 91 128 105 79 48\n",
" 7 148 31 119 59 124 38 57 95 101 83 137 112 52 92 30 63 42\n",
" 14 108 125 122 141 32 140 35 76 41 2 18 146 135 127 116 80 29\n",
" 104 82 34]\n",
" sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)\n",
"16 5.4 3.9 1.3 0.4\n",
"69 5.6 2.5 3.9 1.1\n",
"15 5.7 4.4 1.5 0.4\n",
"4 5.0 3.6 1.4 0.2\n",
"78 6.0 2.9 4.5 1.5\n",
"138 6.0 3.0 4.8 1.8\n",
"111 6.4 2.7 5.3 1.9\n",
"10 5.4 3.7 1.5 0.2\n",
"93 5.0 2.3 3.3 1.0\n",
"45 4.8 3.0 1.4 0.3\n",
"74 6.4 2.9 4.3 1.3\n",
"58 6.6 2.9 4.6 1.3\n",
"106 4.9 2.5 4.5 1.7\n",
"22 4.6 3.6 1.0 0.2\n",
"56 6.3 3.3 4.7 1.6\n",
"28 5.2 3.4 1.4 0.2\n",
"107 7.3 2.9 6.3 1.8\n",
"27 5.2 3.5 1.5 0.2\n",
"94 5.6 2.7 4.2 1.3\n",
"72 6.3 2.5 4.9 1.5\n",
"66 5.6 3.0 4.5 1.5\n",
"33 5.5 4.2 1.4 0.2\n",
"143 6.8 3.2 5.9 2.3\n",
"87 6.3 2.3 4.4 1.3\n",
"96 5.7 2.9 4.2 1.3\n",
"115 6.4 3.2 5.3 2.3\n",
"73 6.1 2.8 4.7 1.2\n",
"84 5.4 3.0 4.5 1.5\n",
"26 5.0 3.4 1.6 0.4\n",
"126 6.2 2.8 4.8 1.8\n",
".. ... ... ... ...\n",
"101 5.8 2.7 5.1 1.9\n",
"83 6.0 2.7 5.1 1.6\n",
"137 6.4 3.1 5.5 1.8\n",
"112 6.8 3.0 5.5 2.1\n",
"52 6.9 3.1 4.9 1.5\n",
"92 5.8 2.6 4.0 1.2\n",
"30 4.8 3.1 1.6 0.2\n",
"63 6.1 2.9 4.7 1.4\n",
"42 4.4 3.2 1.3 0.2\n",
"14 5.8 4.0 1.2 0.2\n",
"108 6.7 2.5 5.8 1.8\n",
"125 7.2 3.2 6.0 1.8\n",
"122 7.7 2.8 6.7 2.0\n",
"141 6.9 3.1 5.1 2.3\n",
"32 5.2 4.1 1.5 0.1\n",
"140 6.7 3.1 5.6 2.4\n",
"35 5.0 3.2 1.2 0.2\n",
"76 6.8 2.8 4.8 1.4\n",
"41 4.5 2.3 1.3 0.3\n",
"2 4.7 3.2 1.3 0.2\n",
"18 5.7 3.8 1.7 0.3\n",
"146 6.3 2.5 5.0 1.9\n",
"135 7.7 3.0 6.1 2.3\n",
"127 6.1 3.0 4.9 1.8\n",
"116 6.5 3.0 5.5 1.8\n",
"80 5.5 2.4 3.8 1.1\n",
"29 4.7 3.2 1.6 0.2\n",
"104 6.5 3.0 5.8 2.2\n",
"82 5.8 2.7 3.9 1.2\n",
"34 4.9 3.1 1.5 0.1\n",
"\n",
"[75 rows x 4 columns]\n",
"[0 1 0 0 1 2 2 0 1 0 1 1 2 0 1 0 2 0 1 1 1 0 2 1 1 2 1 1 0 2 0 1 2 2 1 0 0\n",
" 2 0 2 1 2 0 1 1 2 1 2 2 1 1 0 1 0 0 2 2 2 2 0 2 0 1 0 0 0 2 2 2 2 1 0 2 1\n",
" 0]\n"
]
}
],
"source": [
"sss = StratifiedShuffleSplit(n_splits=4, test_size=0.5, random_state=0)\n",
"print(sss)\n",
"for train_index, validation_index in sss.split(X_iris, y_iris):\n",
" print(\"TRAIN:\", train_index)\n",
" print(X_iris.loc[train_index, :])\n",
" print(y_iris[train_index])\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"この性質を利用してfoldごとにモデルを構築し予測するスクリプトは以下のようになる。SVCはSVMによるクラス分類を行うクラスである。"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 0 0 0 1 1 1 0 1 2 2 2 1 2 1 0 0 2 0 1 2 1 1 0 2 0 0 1 2 1 0 1 2 2 0 1 2\n",
" 2]\n",
"[0 0 2 2 1 0 0 2 2 2 1 0 0 1 1 1 2 1 2 1 2 1 1 0 1 0 2 1 2 0 1 0 1 0 0 0 2\n",
" 2]\n",
"[1 2 0 2 0 0 1 2 0 1 1 2 1 1 0 2 2 2 1 0 2 1 2 1 0 0 0 2 0 2 2 0 1 0 1 2 1\n",
" 1]\n",
"[0 0 1 0 2 0 2 2 1 2 1 0 0 1 2 1 2 0 0 1 0 1 2 1 2 2 2 0 2 1 0 2 1 1 2 1 2\n",
" 0]\n"
]
}
],
"source": [
"sss = StratifiedShuffleSplit(n_splits=4, test_size=0.25, random_state=0)\n",
"for train_index, validation_index in sss.split(X_iris, y_iris):\n",
" X_train, y_train = X_iris.loc[train_index, :], y_iris[train_index]\n",
" X_test, y_test = X_iris.loc[validation_index, :], y_iris[validation_index]\n",
" svc = svm.SVC()\n",
" svc.fit(X_train, y_train)\n",
" print(svc.predict(X_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"さらにSVCクラスのscoreメソッドを用いると実測値(観測値)と予測値を比較し正答率を計算することができる。"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n",
"0.933333333333\n",
"1.0\n",
"1.0\n",
"1.0\n",
"0.933333333333\n",
"0.933333333333\n",
"1.0\n",
"1.0\n",
"1.0\n"
]
}
],
"source": [
"skf = StratifiedKFold(n_splits=10)\n",
"for train_index, validation_index in skf.split(X_iris, y_iris):\n",
" X_train, y_train = X_iris.loc[train_index, :], y_iris[train_index]\n",
" X_test, y_test = X_iris.loc[validation_index, :], y_iris[validation_index]\n",
" svc = svm.SVC()\n",
" svc.fit(X_train, y_train)\n",
" print(svc.score(X_test, y_test))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"これらの一連の工程は4に参照するcross_val_predictやcross_val_scoreを用いることで簡単に行うことができる。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Cross validtionおよびモデル作成をし予測まで一括で行う関数"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 4.1 cross_val_predict"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[cross_val_predict](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html#sklearn.model_selection.cross_val_predict)を用いるとkFoldあるいはLeave-one-out cross validationを行いValidation setとして予測した値を求めることができる。<br>デフォルトでは3-fold cross validationを行うが上記のLeaveOneOutなどの関数を引数に与えるとそちらを使ってcross validationをすることができる。"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svc = svm.SVC()\n",
"y_pred = cross_val_predict(svc, X_iris, y_iris)\n",
"y_pred"
]
},
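{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hedged sketch (added for illustration), `cross_val_predict` also accepts a splitter through its `cv` argument. Below, `LeaveOneOut()` is used so that each value is predicted while that sample is held out, and the agreement with `y_iris` is summarized with `np.mean`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged example: pass a splitter (here LeaveOneOut) via the cv argument\n",
"svc = svm.SVC()\n",
"y_pred_loo = cross_val_predict(svc, X_iris, y_iris, cv=LeaveOneOut())\n",
"# Fraction of samples predicted correctly under leave-one-out\n",
"np.mean(y_pred_loo == y_iris)"
]
},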
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 4.2 cross_val_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[cross_val_score](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_score.html#sklearn.model_selection.cross_val_score)を用いるとkFoldあるいはLeave-one-out cross validationを行いValidation setとした予測値の正答率を求めることができる。<br>\n",
"デフォルトでは3-fold cross validationを行うが上記のLeaveOneOutなどの関数を引数に与えるとそちらを使ってcross validationをすることができる。"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.98039216, 0.96078431, 0.97916667])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svc = svm.SVC()\n",
"scores = cross_val_score(svc, X_iris, y_iris)\n",
"scores"
]
},
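{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hedged sketch (added for illustration), passing `LeaveOneOut()` to the `cv` argument of `cross_val_score` returns one score per held-out sample; their mean is the leave-one-out accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged example: leave-one-out cross validation with cross_val_score\n",
"svc = svm.SVC()\n",
"loo_scores = cross_val_score(svc, X_iris, y_iris, cv=LeaveOneOut())\n",
"# One score (0 or 1) per sample; the mean is the leave-one-out accuracy\n",
"loo_scores.mean()"
]
},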
{
"cell_type": "markdown",
"metadata": {},
"source": [
"cross_val_predict, cross_val_score, GridSearchCVの引数cvにintを指定するとその値のK-Fold cross validationを行う。 \n",
"さらにクラス分類の場合はK-Fold StratifiedKFoldでデータを分割しcross validationを行う。 \n",
"例えば以下の例では`cv=10`としているので10-fold cross validationを行う。"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1. , 0.93333333, 1. , 1. , 1. ,\n",
" 0.93333333, 0.93333333, 1. , 1. , 1. ])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"svc = svm.SVC()\n",
"scores = cross_val_score(svc, X_iris, y_iris, cv=10)\n",
"scores"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"求められたscoresは3の最後に示したスクリプト( In[17] ) で得られた正答率と同じになっている。 \n",
"したがってcross_val_scoreを用いることで簡単に同じ工程を行うことができるといえる。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"引数cvにintではなくStratifiedKFoldクラスなどのFoldのクラスをあたえることで目的変数の各クラスの割合を考慮した \n",
"Training set set, Validation set setを作成したk-fold cross validationを行い、その予測値の正答率を求めてくれる。"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1. , 0.93333333, 1. , 1. , 1. ,\n",
" 0.93333333, 0.93333333, 1. , 1. , 1. ])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"skf = StratifiedKFold(n_splits=10)\n",
"svc = svm.SVC()\n",
"scores = cross_val_score(svc, X_iris, y_iris, cv=skf)\n",
"scores"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"求められたscoresは3の最後に示したスクリプト( In[17], In[25] ) で得られた正答率と同じになっている。<br>\n",
"以上のことよりcross_val_scoreの引数cvをうまく用いることで自分が望むcross validationを簡単に行うことができるといえる。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 4.3 GridSearchCV"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[GridSearchCV](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV)はSVMのパラメータの最適な値を求めるグリッドサーチのような網羅的な探索が必要な時に使う関数である。<br>デフォルトでは3-fold cross validationを行うが上記のLeaveOneOutなどの関数を引数に与えるとそちらを使ってcross validationをすることができる。"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best Estimator:\n",
"SVC(C=1, cache_size=200, class_weight=None, coef0=0.0,\n",
" decision_function_shape=None, degree=3, gamma='auto', kernel='linear',\n",
" max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
" tol=0.001, verbose=False)\n",
"\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
" .dataframe thead tr:only-child th {\n",
" text-align: right;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: left;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean_fit_time</th>\n",
" <th>mean_score_time</th>\n",
" <th>mean_test_score</th>\n",
" <th>mean_train_score</th>\n",
" <th>param_C</th>\n",
" <th>param_kernel</th>\n",
" <th>params</th>\n",
" <th>rank_test_score</th>\n",
" <th>split0_test_score</th>\n",
" <th>split0_train_score</th>\n",
" <th>split1_test_score</th>\n",
" <th>split1_train_score</th>\n",
" <th>split2_test_score</th>\n",
" <th>split2_train_score</th>\n",
" <th>std_fit_time</th>\n",
" <th>std_score_time</th>\n",
" <th>std_test_score</th>\n",
" <th>std_train_score</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.002394</td>\n",
" <td>0.000743</td>\n",
" <td>0.980000</td>\n",
" <td>0.989998</td>\n",
" <td>1</td>\n",
" <td>linear</td>\n",
" <td>{'C': 1, 'kernel': 'linear'}</td>\n",
" <td>1</td>\n",
" <td>1.000000</td>\n",
" <td>0.979798</td>\n",
" <td>0.960784</td>\n",
" <td>1.0</td>\n",
" <td>0.979167</td>\n",
" <td>0.990196</td>\n",
" <td>0.001029</td>\n",
" <td>0.000261</td>\n",
" <td>0.016179</td>\n",
" <td>0.008249</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.004125</td>\n",
" <td>0.001150</td>\n",
" <td>0.973333</td>\n",
" <td>0.983363</td>\n",
" <td>1</td>\n",
" <td>rbf</td>\n",
" <td>{'C': 1, 'kernel': 'rbf'}</td>\n",
" <td>4</td>\n",
" <td>0.980392</td>\n",
" <td>0.969697</td>\n",
" <td>0.960784</td>\n",
" <td>1.0</td>\n",
" <td>0.979167</td>\n",
" <td>0.980392</td>\n",
" <td>0.000370</td>\n",
" <td>0.000160</td>\n",
" <td>0.009021</td>\n",
" <td>0.012548</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.001152</td>\n",
" <td>0.000363</td>\n",
" <td>0.966667</td>\n",
" <td>0.983363</td>\n",
" <td>5</td>\n",
" <td>linear</td>\n",
" <td>{'C': 5, 'kernel': 'linear'}</td>\n",
" <td>6</td>\n",
" <td>1.000000</td>\n",
" <td>0.969697</td>\n",
" <td>0.901961</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>0.980392</td>\n",
" <td>0.000215</td>\n",
" <td>0.000099</td>\n",
" <td>0.046442</td>\n",
" <td>0.012548</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.002128</td>\n",
" <td>0.000806</td>\n",
" <td>0.980000</td>\n",
" <td>0.983363</td>\n",
" <td>5</td>\n",
" <td>rbf</td>\n",
" <td>{'C': 5, 'kernel': 'rbf'}</td>\n",
" <td>1</td>\n",
" <td>0.980392</td>\n",
" <td>0.969697</td>\n",
" <td>0.960784</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>0.980392</td>\n",
" <td>0.000817</td>\n",
" <td>0.000309</td>\n",
" <td>0.015925</td>\n",
" <td>0.012548</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.002020</td>\n",
" <td>0.000460</td>\n",
" <td>0.973333</td>\n",
" <td>0.979996</td>\n",
" <td>10</td>\n",
" <td>linear</td>\n",
" <td>{'C': 10, 'kernel': 'linear'}</td>\n",
" <td>4</td>\n",
" <td>1.000000</td>\n",
" <td>0.959596</td>\n",
" <td>0.921569</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>0.980392</td>\n",
" <td>0.000762</td>\n",
" <td>0.000159</td>\n",
" <td>0.037154</td>\n",
" <td>0.016497</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.001791</td>\n",
" <td>0.000729</td>\n",
" <td>0.980000</td>\n",
" <td>0.979996</td>\n",
" <td>10</td>\n",
" <td>rbf</td>\n",
" <td>{'C': 10, 'kernel': 'rbf'}</td>\n",
" <td>1</td>\n",
" <td>0.980392</td>\n",
" <td>0.959596</td>\n",
" <td>0.960784</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>0.980392</td>\n",
" <td>0.000688</td>\n",
" <td>0.000512</td>\n",
" <td>0.015925</td>\n",
" <td>0.016497</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean_fit_time mean_score_time mean_test_score mean_train_score param_C \\\n",
"0 0.002394 0.000743 0.980000 0.989998 1 \n",
"1 0.004125 0.001150 0.973333 0.983363 1 \n",
"2 0.001152 0.000363 0.966667 0.983363 5 \n",
"3 0.002128 0.000806 0.980000 0.983363 5 \n",
"4 0.002020 0.000460 0.973333 0.979996 10 \n",
"5 0.001791 0.000729 0.980000 0.979996 10 \n",
"\n",
" param_kernel params rank_test_score \\\n",
"0 linear {'C': 1, 'kernel': 'linear'} 1 \n",
"1 rbf {'C': 1, 'kernel': 'rbf'} 4 \n",
"2 linear {'C': 5, 'kernel': 'linear'} 6 \n",
"3 rbf {'C': 5, 'kernel': 'rbf'} 1 \n",
"4 linear {'C': 10, 'kernel': 'linear'} 4 \n",
"5 rbf {'C': 10, 'kernel': 'rbf'} 1 \n",
"\n",
" split0_test_score split0_train_score split1_test_score \\\n",
"0 1.000000 0.979798 0.960784 \n",
"1 0.980392 0.969697 0.960784 \n",
"2 1.000000 0.969697 0.901961 \n",
"3 0.980392 0.969697 0.960784 \n",
"4 1.000000 0.959596 0.921569 \n",
"5 0.980392 0.959596 0.960784 \n",
"\n",
" split1_train_score split2_test_score split2_train_score std_fit_time \\\n",
"0 1.0 0.979167 0.990196 0.001029 \n",
"1 1.0 0.979167 0.980392 0.000370 \n",
"2 1.0 1.000000 0.980392 0.000215 \n",
"3 1.0 1.000000 0.980392 0.000817 \n",
"4 1.0 1.000000 0.980392 0.000762 \n",
"5 1.0 1.000000 0.980392 0.000688 \n",
"\n",
" std_score_time std_test_score std_train_score \n",
"0 0.000261 0.016179 0.008249 \n",
"1 0.000160 0.009021 0.012548 \n",
"2 0.000099 0.046442 0.012548 \n",
"3 0.000309 0.015925 0.012548 \n",
"4 0.000159 0.037154 0.016497 \n",
"5 0.000512 0.015925 0.016497 "
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parameters = {'kernel':('linear', 'rbf'), 'C':[1, 5, 10]}\n",
"svc = svm.SVC()\n",
"clf = GridSearchCV(svc, parameters)\n",
"clf.fit(X_iris, y_iris)\n",
"print(f'Best Estimator:\\n{clf.best_estimator_}\\n')\n",
"pd.DataFrame(clf.cv_results_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"best_estimator_フィールドを確認することで最も良いモデルが得られた時のパラメータの値を確認することができる。<br>\n",
"今回はkernel='linear', C=1でありデータフレームを見ると確かにmean_test_scoreが最も良い。"
]
},
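{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hedged sketch (added for illustration), the cv argument of GridSearchCV likewise accepts a splitter such as StratifiedKFold; best_params_ and best_score_ then summarize the winning setting without printing the full results table."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged example: grid search with an explicit StratifiedKFold splitter\n",
"parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 5, 10]}\n",
"svc = svm.SVC()\n",
"clf = GridSearchCV(svc, parameters, cv=StratifiedKFold(n_splits=5))\n",
"clf.fit(X_iris, y_iris)\n",
"# Best parameter combination and its mean validation accuracy\n",
"print(clf.best_params_)\n",
"print(clf.best_score_)"
]
},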
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. cross_validationモジュールを使用してcross validationを行う"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn version2.0で廃止される予定だがcross_validationモジュールを使用してもcross validationを行うことができる。<br>\n",
"例えばcross_validationモジュールの[KFold](http://scikit-learn.org/0.16/modules/generated/sklearn.cross_validation.KFold.html#sklearn.cross_validation.KFold)クラスを用いることでKFold cross validationを行うことができる。\n",
"<br>shuffle=True, random_state=Noneとすることでランダムにすることができる。"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.6/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
" \"This module will be removed in 0.20.\", DeprecationWarning)\n"
]
}
],
"source": [
"from sklearn import cross_validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"上記の赤枠ででている英文はversion 0.2になるとcross_validationモジュールは廃止されることを説明している。<br>\n",
"したがって新たに学ぶなら1〜4章で説明したモジュールの関数やクラスを学んだ方が良い。"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sklearn.cross_validation.KFold(n=8, n_folds=2, shuffle=False, random_state=None)\n",
"TRAIN: [4 5 6 7] VALIDATION: [0 1 2 3]\n",
"TRAIN: [0 1 2 3] VALIDATION: [4 5 6 7]\n",
"sklearn.cross_validation.KFold(n=8, n_folds=2, shuffle=True, random_state=None)\n",
"TRAIN: [0 2 6 7] VALIDATION: [1 3 4 5]\n",
"TRAIN: [1 3 4 5] VALIDATION: [0 2 6 7]\n"
]
}
],
"source": [
"X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])\n",
"y = np.array([1, 1, 1, 1, 2, 2, 2, 2])\n",
"kf = cross_validation.KFold(8, n_folds=2)\n",
"print(kf) \n",
"for train_index, validation_index in kf:\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)\n",
"kf = cross_validation.KFold(8, n_folds=2, shuffle=True, random_state=None)\n",
"print(kf) \n",
"for train_index, validation_index in kf:\n",
" print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"他にもLeaveOneOut, ShuffleSplit, StratifiedKFoldなどがmodel_selectionモジュールと同様に存在する。<br>\n",
"詳しくは[こちら](http://scikit-learn.org/0.16/modules/classes.html#module-sklearn.cross_validation)を参照してほしい。"
]
}
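,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hedged sketch of the old-style interface (added for illustration), the deprecated module's StratifiedKFold takes y in its constructor, uses n_folds rather than n_splits, and is iterated over directly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged example of the deprecated API: y is passed to the constructor\n",
"y = np.array([1, 1, 1, 1, 2, 2, 2, 2])\n",
"skf_old = cross_validation.StratifiedKFold(y, n_folds=2)\n",
"print(skf_old)\n",
"for train_index, validation_index in skf_old:\n",
"    print(\"TRAIN:\", train_index, \"VALIDATION:\", validation_index)"
]
}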
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}