Last active
November 1, 2017 22:52
{ | |
"cells": [ | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try running Random Forest."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"sys.version_info(major=3, minor=6, micro=2, releaselevel='final', serial=0)\n" | |
] | |
} | |
], | |
"source": [ | |
"import pandas as pd\n", | |
"import sys\n", | |
"\n", | |
"from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor\n", | |
"from sklearn.datasets import make_classification, make_regression\n", | |
"from sklearn.metrics import accuracy_score, mean_squared_error\n", | |
"\n", | |
"print(sys.version_info)" | |
] | |
}, | |
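{
"cell_type": "markdown",
"metadata": {},
"source": [
"Random Forest is an ensemble learning method that builds `n_estimators` decision trees, each trained on a bootstrap sample and each considering a random subset of `max_features` features at every split, and raises accuracy by taking the consensus of the trees.  \n",
"For `RandomForestClassifier`, `max_features` defaults to the square root of the total number of features.  \n",
"(For binary classification, for example, if every classifier is correct more than 50% of the time and the classifiers err independently, it can be shown mathematically that the accuracy of the majority vote increases as classifiers are added.)"
]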
}, | |
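The majority-vote claim above can be checked numerically. The following is a minimal sketch, assuming an odd number of classifiers that err independently and all share the same accuracy `p` (a simplification, since the trees in a real forest are correlated). The helper name `majority_vote_accuracy` is hypothetical and introduced only for illustration, and `math.comb` requires Python 3.8 or later.

```python
from math import comb


def majority_vote_accuracy(n_classifiers: int, p: float) -> float:
    """Probability that a majority of independent classifiers is correct.

    Assumes binary classification, an odd number of voters, and that each
    voter is correct with the same probability p, independently of the others.
    """
    if n_classifiers % 2 == 0:
        raise ValueError("use an odd number of classifiers to avoid ties")
    k_min = n_classifiers // 2 + 1  # smallest number of correct votes that wins
    return sum(
        comb(n_classifiers, k) * p**k * (1 - p) ** (n_classifiers - k)
        for k in range(k_min, n_classifiers + 1)
    )


# Each voter is only slightly better than chance (60% accurate).
for n in (1, 11, 101):
    print(n, round(majority_vote_accuracy(n, 0.6), 3))
```

With `p` above 0.5 the printed probability grows with the number of voters; with `p` exactly 0.5 it stays at 0.5, and with `p` below 0.5 voting makes things worse.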
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 1. RandomForestClassifier" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[RandomForestClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier) performs classification with Random Forest."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1.1 Preparing the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prepare toy data.  \n",
"The [make_classification](http://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html#sklearn.datasets.make_classification) function generates sample data for testing classifiers to your own specification.  \n",
"Here we use `n_samples=1000`, `n_features=10`, `n_classes=2`, plus `random_state=0` for reproducibility and `shuffle=False` so that the samples and features keep their generated order."
]
}, | |
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Generate a reproducible toy dataset for binary classification\n",
"X_clf, y_clf = make_classification(n_samples=1000, n_features=10, n_classes=2,\n",
"                                   random_state=0, shuffle=False)\n",
"X_clf = pd.DataFrame(X_clf)"
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>3</th>\n", | |
" <th>4</th>\n", | |
" <th>5</th>\n", | |
" <th>6</th>\n", | |
" <th>7</th>\n", | |
" <th>8</th>\n", | |
" <th>9</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>-1.668532</td>\n", | |
" <td>-1.299013</td>\n", | |
" <td>0.799353</td>\n", | |
" <td>-1.559985</td>\n", | |
" <td>-3.116857</td>\n", | |
" <td>0.644452</td>\n", | |
" <td>-1.913743</td>\n", | |
" <td>0.663562</td>\n", | |
" <td>-0.154072</td>\n", | |
" <td>1.193612</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>-2.972883</td>\n", | |
" <td>-1.088783</td>\n", | |
" <td>1.953804</td>\n", | |
" <td>-1.891656</td>\n", | |
" <td>-0.098161</td>\n", | |
" <td>-0.886614</td>\n", | |
" <td>-0.147354</td>\n", | |
" <td>1.059806</td>\n", | |
" <td>0.026247</td>\n", | |
" <td>-0.114335</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>-0.596141</td>\n", | |
" <td>-1.370070</td>\n", | |
" <td>-0.105818</td>\n", | |
" <td>-1.213570</td>\n", | |
" <td>0.743554</td>\n", | |
" <td>0.210359</td>\n", | |
" <td>-0.005927</td>\n", | |
" <td>1.366060</td>\n", | |
" <td>1.555114</td>\n", | |
" <td>0.613326</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>-1.068947</td>\n", | |
" <td>-1.175057</td>\n", | |
" <td>0.363982</td>\n", | |
" <td>-1.247739</td>\n", | |
" <td>-0.285959</td>\n", | |
" <td>1.496911</td>\n", | |
" <td>1.183120</td>\n", | |
" <td>0.718897</td>\n", | |
" <td>-1.216077</td>\n", | |
" <td>0.140672</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>-1.305269</td>\n", | |
" <td>-0.965926</td>\n", | |
" <td>0.647043</td>\n", | |
" <td>-1.183939</td>\n", | |
" <td>-0.743672</td>\n", | |
" <td>-0.159012</td>\n", | |
" <td>0.240057</td>\n", | |
" <td>0.100159</td>\n", | |
" <td>-0.475175</td>\n", | |
" <td>1.272954</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 3 4 5 6 \\\n", | |
"0 -1.668532 -1.299013 0.799353 -1.559985 -3.116857 0.644452 -1.913743 \n", | |
"1 -2.972883 -1.088783 1.953804 -1.891656 -0.098161 -0.886614 -0.147354 \n", | |
"2 -0.596141 -1.370070 -0.105818 -1.213570 0.743554 0.210359 -0.005927 \n", | |
"3 -1.068947 -1.175057 0.363982 -1.247739 -0.285959 1.496911 1.183120 \n", | |
"4 -1.305269 -0.965926 0.647043 -1.183939 -0.743672 -0.159012 0.240057 \n", | |
"\n", | |
" 7 8 9 \n", | |
"0 0.663562 -0.154072 1.193612 \n", | |
"1 1.059806 0.026247 -0.114335 \n", | |
"2 1.366060 1.555114 0.613326 \n", | |
"3 0.718897 -1.216077 0.140672 \n", | |
"4 0.100159 -0.475175 1.272954 " | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_clf.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0, 0, 0, 0, 0])" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_clf[:5]" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The number of samples and the number of features are available through the `shape` attribute."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(1000, 10)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_clf.shape" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1.2 Training the model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### 1.2.1 How to use RandomForestClassifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model with `n_estimators=10` decision trees, a maximum tree depth of `max_depth=2`,  \n",
"and `random_state=0` for reproducibility."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", | |
" max_depth=2, max_features='auto', max_leaf_nodes=None,\n", | |
" min_impurity_split=1e-07, min_samples_leaf=1,\n", | |
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n", | |
" n_estimators=10, n_jobs=1, oob_score=False, random_state=0,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"clf = RandomForestClassifier(n_estimators=10, max_depth=2, random_state=0)\n", | |
"clf.fit(X_clf, y_clf)" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`clf.feature_importances_` holds the importance of each feature.  \n",
"The higher the value, the more important the feature."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 0.05107855 0.44565534 0.09075854 0.36217254 0.00492064 0.\n", | |
" 0.00199749 0.04153638 0.00188052 0. ]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(clf.feature_importances_)" | |
] | |
}, | |
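The raw importance array is easier to read when each value is paired with its column index and sorted. The following is a small sketch under the same settings as the cells above; the variable name `importances` is introduced here for illustration only. Because `shuffle=False` is used, `make_classification` places the informative and redundant features in the leading columns, which is consistent with the large values seen for the first few columns.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Same toy data and model settings as in the notebook cells above.
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2,
                           random_state=0, shuffle=False)
clf = RandomForestClassifier(n_estimators=10, max_depth=2, random_state=0)
clf.fit(X, y)

# Pair each importance with its column index and sort in descending order.
importances = pd.Series(clf.feature_importances_, index=range(X.shape[1]))
importances = importances.sort_values(ascending=False)
print(importances)
```

The importances are normalized, so they sum to 1 across all features.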
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the `predict` method to make predictions.  \n",
"For example, to predict the class of a sample whose 10 features are all 0, do the following."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[1]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(clf.predict([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, check how well the model fits the training set."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The [accuracy_score](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html#sklearn.metrics.accuracy_score) function returns the fraction of correct predictions; here it is applied to the training set."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"95.399999999999991" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
],
"source": [
"y_predicted = clf.predict(X_clf)\n",
"accuracy_score(y_clf, y_predicted) * 100"
]
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy was about 95.4%."
]
}, | |
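The accuracy above is measured on the same data the model was trained on, so it says little about unseen samples. As a hedged illustration (not part of the original analysis), the sketch below holds out 25% of the samples with `train_test_split` and reports both figures; the split ratio and the variable names `train_acc` and `test_acc` are assumptions made for this example.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, n_features=10, n_classes=2,
                           random_state=0, shuffle=False)

# Hold out 25% of the samples; stratify keeps the class ratio in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0)

clf = RandomForestClassifier(n_estimators=10, max_depth=2, random_state=0)
clf.fit(X_train, y_train)

train_acc = accuracy_score(y_train, clf.predict(X_train))
test_acc = accuracy_score(y_test, clf.predict(X_test))
print(f"train accuracy: {train_acc:.3f}, test accuracy: {test_acc:.3f}")
```

A large gap between the two numbers would suggest overfitting; similar numbers suggest the training accuracy is a reasonable guide.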
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### 1.2.2 Increasing the number of trees"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Increase the number of decision trees, `n_estimators`, to 100 and repeat the analysis."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", | |
" max_depth=2, max_features='auto', max_leaf_nodes=None,\n", | |
" min_impurity_split=1e-07, min_samples_leaf=1,\n", | |
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n", | |
" n_estimators=100, n_jobs=1, oob_score=False, random_state=0,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"clf = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0)\n", | |
"clf.fit(X_clf, y_clf)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 0.02802468 0.46142492 0.12352124 0.34518939 0.00193351 0.0029199\n", | |
" 0.00777319 0.00947897 0.0108007 0.00893351]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(clf.feature_importances_)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"95.799999999999997" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
],
"source": [
"y_predicted = clf.predict(X_clf)\n",
"accuracy_score(y_clf, y_predicted) * 100"
]
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy was 95.8%. Increasing the number of trees, which makes the consensus more stable, raised the training accuracy slightly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### 1.2.3 Increasing the complexity of the trees"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Increase the maximum tree depth, `max_depth`, to 3 (with `n_estimators` back at 10) and repeat the analysis."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", | |
" max_depth=3, max_features='auto', max_leaf_nodes=None,\n", | |
" min_impurity_split=1e-07, min_samples_leaf=1,\n", | |
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n", | |
" n_estimators=10, n_jobs=1, oob_score=False, random_state=0,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"clf = RandomForestClassifier(n_estimators=10, max_depth=3, random_state=0)\n", | |
"clf.fit(X_clf, y_clf)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"96.099999999999994" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
],
"source": [
"y_predicted = clf.predict(X_clf)\n",
"accuracy_score(y_clf, y_predicted) * 100"
]
}, | |
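{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy was about 96.1%. Allowing the trees to become more complex raised the training accuracy slightly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, as noted at the beginning, Random Forest raises accuracy through the consensus of multiple classifiers,  \n",
"so each individual classifier can be a weak learner, and the maximum depth `max_depth` often does not need much tuning."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead, it is more common to tune `max_features`, the number of features each tree considers at a split, as in the next section."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### 1.2.4 Changing the number of features used by the trees"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set `max_features`, the number of features considered at each split, to 3.  \n",
"By default `max_features` is the square root of the total number of features; with 10 features that is `int(sqrt(10)) = 3`, so this value matches the default. Note that `max_depth` is left at its default (`None`) here, so the trees are grown fully and the accuracy is not directly comparable with the previous sections."
]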
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", | |
" max_depth=None, max_features=3, max_leaf_nodes=None,\n", | |
" min_impurity_split=1e-07, min_samples_leaf=1,\n", | |
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n", | |
" n_estimators=10, n_jobs=1, oob_score=False, random_state=0,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"clf = RandomForestClassifier(n_estimators=10, max_features=3, random_state=0)\n", | |
"clf.fit(X_clf, y_clf)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"99.799999999999997" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
],
"source": [
"y_predicted = clf.predict(X_clf)\n",
"accuracy_score(y_clf, y_predicted) * 100"
]
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, increase `max_features` to 8."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", | |
" max_depth=None, max_features=8, max_leaf_nodes=None,\n", | |
" min_impurity_split=1e-07, min_samples_leaf=1,\n", | |
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n", | |
" n_estimators=10, n_jobs=1, oob_score=False, random_state=0,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"clf = RandomForestClassifier(n_estimators=10, max_features=8, random_state=0)\n", | |
"clf.fit(X_clf, y_clf)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"99.700000000000003" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
],
"source": [
"y_predicted = clf.predict(X_clf)\n",
"accuracy_score(y_clf, y_predicted) * 100"
]
}, | |
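{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, using fewer features gave a slightly higher training accuracy (99.8% with 3 versus 99.7% with 8).  \n",
"When too many features are used, the trees become too similar to one another, so simply increasing the number is not necessarily better.  \n",
"The best approach is to try several values of `max_features` and compare them by cross-validation."
]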
}, | |
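The cross-validation suggested above could look like the following sketch. It uses `GridSearchCV` with 5-fold cross-validation; the candidate values in `param_grid` and the number of folds are arbitrary choices made for illustration, not settings taken from the original analysis.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=1000, n_features=10, n_classes=2,
                           random_state=0, shuffle=False)

# Candidate values for max_features (illustrative, not exhaustive).
param_grid = {"max_features": [1, 2, 3, 5, 8, 10]}

search = GridSearchCV(
    RandomForestClassifier(n_estimators=10, random_state=0),
    param_grid,
    cv=5,                # 5-fold cross-validation (stratified for classifiers)
    scoring="accuracy",
)
search.fit(X, y)

print("best max_features:", search.best_params_["max_features"])
print("mean CV accuracy: ", search.best_score_)
```

Unlike the training accuracy, `best_score_` is averaged over folds that were held out during fitting, so it is a fairer basis for choosing `max_features`.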
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"---" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 2. RandomForestRegressor" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[RandomForestRegressor](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html#sklearn.ensemble.RandomForestRegressor) performs regression with Random Forest."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2.1 Preparing the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Prepare toy data.  \n",
"The [make_regression](http://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html) function generates sample data for testing regression models to your own specification.  \n",
"Here we use `n_samples=1000`, `n_features=10`, plus `random_state=0` for reproducibility and `shuffle=False` so that the samples and features keep their generated order."
]
}, | |
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Generate a reproducible toy dataset for regression\n",
"X_regr, y_regr = make_regression(n_samples=1000, n_features=10,\n",
"                                 random_state=0, shuffle=False)\n",
"X_regr = pd.DataFrame(X_regr)"
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>0</th>\n", | |
" <th>1</th>\n", | |
" <th>2</th>\n", | |
" <th>3</th>\n", | |
" <th>4</th>\n", | |
" <th>5</th>\n", | |
" <th>6</th>\n", | |
" <th>7</th>\n", | |
" <th>8</th>\n", | |
" <th>9</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1.764052</td>\n", | |
" <td>0.400157</td>\n", | |
" <td>0.978738</td>\n", | |
" <td>2.240893</td>\n", | |
" <td>1.867558</td>\n", | |
" <td>-0.977278</td>\n", | |
" <td>0.950088</td>\n", | |
" <td>-0.151357</td>\n", | |
" <td>-0.103219</td>\n", | |
" <td>0.410599</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0.144044</td>\n", | |
" <td>1.454274</td>\n", | |
" <td>0.761038</td>\n", | |
" <td>0.121675</td>\n", | |
" <td>0.443863</td>\n", | |
" <td>0.333674</td>\n", | |
" <td>1.494079</td>\n", | |
" <td>-0.205158</td>\n", | |
" <td>0.313068</td>\n", | |
" <td>-0.854096</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>-2.552990</td>\n", | |
" <td>0.653619</td>\n", | |
" <td>0.864436</td>\n", | |
" <td>-0.742165</td>\n", | |
" <td>2.269755</td>\n", | |
" <td>-1.454366</td>\n", | |
" <td>0.045759</td>\n", | |
" <td>-0.187184</td>\n", | |
" <td>1.532779</td>\n", | |
" <td>1.469359</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0.154947</td>\n", | |
" <td>0.378163</td>\n", | |
" <td>-0.887786</td>\n", | |
" <td>-1.980796</td>\n", | |
" <td>-0.347912</td>\n", | |
" <td>0.156349</td>\n", | |
" <td>1.230291</td>\n", | |
" <td>1.202380</td>\n", | |
" <td>-0.387327</td>\n", | |
" <td>-0.302303</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>-1.048553</td>\n", | |
" <td>-1.420018</td>\n", | |
" <td>-1.706270</td>\n", | |
" <td>1.950775</td>\n", | |
" <td>-0.509652</td>\n", | |
" <td>-0.438074</td>\n", | |
" <td>-1.252795</td>\n", | |
" <td>0.777490</td>\n", | |
" <td>-1.613898</td>\n", | |
" <td>-0.212740</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" 0 1 2 3 4 5 6 \\\n", | |
"0 1.764052 0.400157 0.978738 2.240893 1.867558 -0.977278 0.950088 \n", | |
"1 0.144044 1.454274 0.761038 0.121675 0.443863 0.333674 1.494079 \n", | |
"2 -2.552990 0.653619 0.864436 -0.742165 2.269755 -1.454366 0.045759 \n", | |
"3 0.154947 0.378163 -0.887786 -1.980796 -0.347912 0.156349 1.230291 \n", | |
"4 -1.048553 -1.420018 -1.706270 1.950775 -0.509652 -0.438074 -1.252795 \n", | |
"\n", | |
" 7 8 9 \n", | |
"0 -0.151357 -0.103219 0.410599 \n", | |
"1 -0.205158 0.313068 -0.854096 \n", | |
"2 -0.187184 1.532779 1.469359 \n", | |
"3 1.202380 -0.387327 -0.302303 \n", | |
"4 0.777490 -1.613898 -0.212740 " | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_regr.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 300.2064792 , 194.27686437, 13.25754564, -24.6027897 ,\n", | |
" -98.76061926])" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_regr[:5]" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The number of samples and the number of features are available through the `shape` attribute."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(1000, 10)" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X_regr.shape" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2.2 Training the model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### 2.2.1 How to use RandomForestRegressor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model with `n_estimators=10` decision trees, a maximum tree depth of `max_depth=2`,  \n",
"and `random_state=0` for reproducibility."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=2,\n", | |
" max_features='auto', max_leaf_nodes=None,\n", | |
" min_impurity_split=1e-07, min_samples_leaf=1,\n", | |
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n", | |
" n_estimators=10, n_jobs=1, oob_score=False, random_state=0,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"regr = RandomForestRegressor(n_estimators=10, max_depth=2, random_state=0)\n", | |
"regr.fit(X_regr, y_regr)" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`regr.feature_importances_` holds the importance of each feature.  \n",
"The higher the value, the more important the feature."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 0. 0. 0.02265102 0.50486712 0. 0.\n", | |
" 0.31572423 0.15675763 0. 0. ]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(regr.feature_importances_)" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the `predict` method to make predictions.  \n",
"For example, to predict the target value of a sample whose 10 features are all 0, do the following."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[-32.17449828]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(regr.predict([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the [mean_squared_error](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html) function to compute the mean squared error (MSE) on the training set."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"11343.859988077465" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_predicted = regr.predict(X_regr)\n", | |
"mean_squared_error(y_regr, y_predicted)" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MSE was about 11343.9."
]
}, | |
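An MSE of this size is hard to judge without knowing the scale of the target. As a hedged aside (these metrics are not part of the original analysis), the sketch below also reports the root mean squared error, which is in the same units as `y`, and the coefficient of determination from `r2_score`, and prints the standard deviation of `y` for comparison.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score

X, y = make_regression(n_samples=1000, n_features=10,
                       random_state=0, shuffle=False)
regr = RandomForestRegressor(n_estimators=10, max_depth=2, random_state=0)
regr.fit(X, y)
y_pred = regr.predict(X)

mse = mean_squared_error(y, y_pred)
rmse = np.sqrt(mse)         # same units as the target
r2 = r2_score(y, y_pred)    # 1.0 would be a perfect fit

print(f"MSE:  {mse:.1f}")
print(f"RMSE: {rmse:.1f} (standard deviation of y: {np.std(y):.1f})")
print(f"R^2:  {r2:.3f}")
```

An RMSE close to the standard deviation of `y` means the model explains little of the variance, which is what one would expect from very shallow trees.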
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### 2.2.2 Increasing the number of trees"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Increase the number of decision trees, `n_estimators`, to 100 and repeat the analysis."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=2,\n", | |
" max_features='auto', max_leaf_nodes=None,\n", | |
" min_impurity_split=1e-07, min_samples_leaf=1,\n", | |
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n", | |
" n_estimators=100, n_jobs=1, oob_score=False, random_state=0,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"regr = RandomForestRegressor(n_estimators=100, max_depth=2, random_state=0)\n", | |
"regr.fit(X_regr, y_regr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 0. 0. 0.02318644 0.46182728 0. 0.00220828\n", | |
" 0.30260993 0.21016807 0. 0. ]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(regr.feature_importances_)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"11328.125407195668" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_predicted = regr.predict(X_regr)\n", | |
"mean_squared_error(y_regr, y_predicted)" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MSE was about 11328.1. Increasing the number of trees, which makes the consensus more stable, reduced the training MSE slightly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### 2.2.3 Increasing the complexity of the trees"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Increase the maximum tree depth, `max_depth`, to 3 (with `n_estimators` back at 10) and repeat the analysis."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=3,\n", | |
" max_features='auto', max_leaf_nodes=None,\n", | |
" min_impurity_split=1e-07, min_samples_leaf=1,\n", | |
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n", | |
" n_estimators=10, n_jobs=1, oob_score=False, random_state=0,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"regr = RandomForestRegressor(n_estimators=10, max_depth=3, random_state=0)\n", | |
"regr.fit(X_regr, y_regr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"8462.4899704699292" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_predicted = regr.predict(X_regr)\n", | |
"mean_squared_error(y_regr, y_predicted)" | |
] | |
}, | |
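{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MSE was about 8462.5. For this dataset, allowing the trees to become more complex reduced the training MSE substantially."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### 2.2.4 Changing the number of features used by the trees"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reduce `max_features`, the number of features considered at each split, to 3.  \n",
"For `RandomForestRegressor`, the default (`max_features='auto'` in the earlier outputs) uses all features, so 3 is a reduction from 10. As in section 1.2.4, `max_depth` is left at its default (`None`), so the trees are grown fully."
]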
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,\n", | |
" max_features=3, max_leaf_nodes=None, min_impurity_split=1e-07,\n", | |
" min_samples_leaf=1, min_samples_split=2,\n", | |
" min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n", | |
" oob_score=False, random_state=0, verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"regr = RandomForestRegressor(n_estimators=10, max_features=3, random_state=0)\n", | |
"regr.fit(X_regr, y_regr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"970.52943197845366" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_predicted = regr.predict(X_regr)\n", | |
"mean_squared_error(y_regr, y_predicted)" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, increase `max_features` to 8."
]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,\n", | |
" max_features=8, max_leaf_nodes=None, min_impurity_split=1e-07,\n", | |
" min_samples_leaf=1, min_samples_split=2,\n", | |
" min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n", | |
" oob_score=False, random_state=0, verbose=0, warm_start=False)" | |
] | |
}, | |
"execution_count": 34, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"regr = RandomForestRegressor(n_estimators=10, max_features=8, random_state=0)\n", | |
"regr.fit(X_regr, y_regr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"873.40663827214928" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_predicted = regr.predict(X_regr)\n", | |
"mean_squared_error(y_regr, y_predicted)" | |
] | |
}, | |
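{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, using more features reduced the training MSE (about 873.4 with 8 versus about 970.5 with 3).  \n",
"However, when too many features are used, the trees become too similar to one another, so simply increasing the number is not necessarily better.  \n",
"The best approach is to try several values of `max_features` and compare them by cross-validation."
]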
} | |
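For the regressor, the cross-validated comparison recommended above might be sketched as follows. It compares the two `max_features` values tried in this section using `cross_val_score`; the 5-fold shuffled `KFold` and the dictionary name `cv_mse` are choices made for this illustration only. scikit-learn reports the score as negative MSE, so the sign is flipped before printing.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold, cross_val_score

X, y = make_regression(n_samples=1000, n_features=10,
                       random_state=0, shuffle=False)

# Shuffled 5-fold split with a fixed seed so the folds are reproducible.
cv = KFold(n_splits=5, shuffle=True, random_state=0)

cv_mse = {}
for max_features in (3, 8):
    regr = RandomForestRegressor(n_estimators=10, max_features=max_features,
                                 random_state=0)
    scores = cross_val_score(regr, X, y, cv=cv,
                             scoring="neg_mean_squared_error")
    cv_mse[max_features] = -scores.mean()  # flip sign to get a positive MSE
    print(f"max_features={max_features}: mean CV MSE = {cv_mse[max_features]:.1f}")
```

The cross-validated MSE is usually noticeably larger than the training MSE of fully grown trees, which is why it is the better guide when choosing `max_features`.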
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |