Skip to content

Instantly share code, notes, and snippets.

@saboyutaka
Created December 12, 2016 08:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save saboyutaka/6dbd68b6a489e93025391a8c70ec1bd4 to your computer and use it in GitHub Desktop.
Save saboyutaka/6dbd68b6a489e93025391a8c70ec1bd4 to your computer and use it in GitHub Desktop.
Titanic challenge at Koza Machine Learning Bootcamp
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Kaggle - Titanic: Machine Learning from Disaster"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false,
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 変数\n",
"\n",
"## 独立変数(説明変数)\n",
"\n",
"- PassengerId: 乗客ID\n",
"- Pclass: 客室の等級(1st, 2nd, 3rd)\n",
"- Name: 名前\n",
"- Sex: 性別\n",
"- Age: 年齢\n",
"- SibSp: 共に乗船していた兄弟(siblings)や配偶者(spouses)の数\n",
"- Parch: 共に乗船していた親(parents)や子供(children)の数\n",
"- Ticket: チケット番号\n",
"- Fare: チケットの料金\n",
"- Cabin: 客室番号\n",
"- Embarked: 乗船港(**Q**ueenstown, **C**herbourg, **S**outhampton)\n",
"\n",
"## 従属変数(目的変数)\n",
"- Survived: 生存者かどうか(1: 助かった、0: 助からなかった)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pandasで下ごしらえ"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"# Load the data (Kaggle already provides separate training and test sets)\n",
"df_train = pd.read_csv('../input/train.csv')  # training data\n",
"df_test = pd.read_csv('../input/test.csv')  # test data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# SexId: numeric encoding of Sex (male -> 1, female -> 0)\n",
"sex_map = {'male': 1, 'female': 0}\n",
"df_train['SexId'] = df_train['Sex'].map(sex_map)\n",
"df_test['SexId'] = df_test['Sex'].map(sex_map)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# FamilySize = SibSp + Parch (number of relatives on board)\n",
"df_train['FamilySize'] = df_train['SibSp'] + df_train['Parch']\n",
"df_test['FamilySize'] = df_test['SibSp'] + df_test['Parch']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Impute missing Age values with the training-set median.\n",
"# Flag the missing rows first, in both sets, so the AgeNull feature is consistent.\n",
"# Assign the result back instead of chained `fillna(inplace=True)`, which does not\n",
"# modify the frame under pandas Copy-on-Write.\n",
"df_train['AgeNull'] = df_train['Age'].isnull()\n",
"df_test['AgeNull'] = df_test['Age'].isnull()\n",
"age_median = df_train['Age'].median()\n",
"df_train['Age'] = df_train['Age'].fillna(age_median)\n",
"df_test['Age'] = df_test['Age'].fillna(age_median)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Embarked: fill missing ports with the most frequent training value, then encode as integers\n",
"common_embarked = df_train['Embarked'].value_counts().index[0]\n",
"embarked_map = {'S': 0, 'C': 1, 'Q': 2}\n",
"df_train['Embarked'] = df_train['Embarked'].fillna(common_embarked)\n",
"df_test['Embarked'] = df_test['Embarked'].fillna(common_embarked)\n",
"df_train['EmbarkedNum'] = df_train['Embarked'].map(embarked_map)\n",
"df_test['EmbarkedNum'] = df_test['Embarked'].map(embarked_map)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Features fed to the models ('EmbarkedNum' is prepared above but currently not used)\n",
"inputs = ['FamilySize', 'SexId', 'Age']"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Feature matrices as float32 arrays\n",
"X_train = df_train[inputs].values.astype('float32')\n",
"X_test = df_test[inputs].values.astype('float32')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Target vector (1 = survived, 0 = did not survive)\n",
"y_train = df_train['Survived'].values"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Use PassengerId as the index (guarded, so re-running the cell is a no-op)\n",
"if 'PassengerId' in df_train.columns:\n",
"    df_train.index = df_train.pop('PassengerId')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Same for the test set; the index is written as the PassengerId column of the submission CSVs\n",
"if 'PassengerId' in df_test.columns:\n",
"    df_test.index = df_test.pop('PassengerId')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 機械学習"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### モデル選択\n"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Load the candidate classification models\n",
"\n",
"# Logistic regression\n",
"from sklearn.linear_model import LogisticRegression\n",
"# k-nearest neighbours\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"# Support vector machine\n",
"from sklearn.svm import SVC\n",
"# Decision tree\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"# Random forest\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"# Gradient boosting\n",
"from sklearn.ensemble import GradientBoostingClassifier"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Cross-validation utilities (sklearn.cross_validation was removed in 0.20; use model_selection)\n",
"\n",
"from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Candidate classifiers; the randomised models get a fixed random_state so CV scores are reproducible\n",
"classifiers = [\n",
"    ('lr', LogisticRegression()),\n",
"    ('knn', KNeighborsClassifier()),\n",
"    ('linear svc', SVC(kernel=\"linear\")),\n",
"    ('rbf svc', SVC(gamma=2)),\n",
"    ('dt', DecisionTreeClassifier(random_state=42)),\n",
"    ('rf', RandomForestClassifier(random_state=42)),\n",
"    ('gbc', GradientBoostingClassifier(random_state=42))\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.791 (0.018): time 0.12s, lr\n",
"0.765 (0.012): time 0.01s, knn\n",
"0.787 (0.019): time 0.18s, linear svc\n",
"0.773 (0.015): time 0.08s, rbf svc\n",
"0.779 (0.016): time 0.01s, dt\n",
"0.781 (0.020): time 0.14s, rf\n",
"0.823 (0.016): time 0.32s, gbc\n"
]
}
],
"source": [
"# Cross-validate every model (5 folds, accuracy) and record the scores and wall-clock time\n",
"import time\n",
"results = {}\n",
"exec_times = {}\n",
"\n",
"for name, clf in classifiers:\n",
"    tic = time.time()\n",
"    result = cross_val_score(clf, X_train, y_train, cv=5, scoring='accuracy')\n",
"    exec_time = time.time() - tic\n",
"    exec_times[name] = exec_time\n",
"    results[name] = result\n",
"\n",
"    print(\"{0:.3f} ({1:.3f}): time {2:.2f}s, {3}\".format(result.mean(), result.std(), exec_time, name))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Collect the per-fold CV scores into a DataFrame (one column per model)\n",
"df_results = pd.DataFrame(results)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZwAAAD7CAYAAABexyJvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGLJJREFUeJzt3X+QJHd53/H3R0II9Ivdw5XYFvEtxDGE2KdF4kc5gFkB\nicFCOE6wxY8KGpxKqEBZyJFckEpUN0elQoEjJFeEo6SwtWcDsmQJjBSMDZY1nA0xErobTshIdqKc\nwKicMtydkGJbkeQnf0zv3dxqd7Sz29Pd32c/r6otTfd093w/M6N+dr5P36wiAjMzs1k7qe0BmJnZ\n9uCCY2ZmjXDBMTOzRrjgmJlZI1xwzMysES44ZmbWiKe1PYAmSfI14GZmmxAR2uoxtt0nnIhI+7N7\n9+7Wx+B8zrcd82XOFlHf7+nbruBkdujQobaHMFPOV7bM+TJnq5MLjpmZNcIFJ5Fer9f2EGbK+cqW\nOV/mbHVSnfNzXScptlNeM7M6SCJ80YCNGwwGbQ9hppyvbJnzZc5WJxccMzNrhKfUzLaRHTvgyJG2\nRwGBELP/f3F+Hg4fnvnDpFfXlJoLjtk2IkEn/hdoaCCdyVs493DsSbLPIztf2TLny5ytTsUVHEm3\nSzq37XGYmdl0iptSk3Q7cFlE7N/Evp5Ss6JVUxtb2L8jU0yeUtuUrb7+W3zc3FNqkq6QdK+kfZI+\nIemy6q63Szog6aCkl1Tbni7pV6t1Q0k/1eLQzcxslc4WHEkvBn4K+BHgJ4AXw7HLWp4ZES8C3g38\narXuCuBoROyKiEXg9xsecuuyzyM7X9ky58ucrU5d/vMELwc+HRGPAY9JugUQo6JzPUBE/IGkMyU9\nC3gtcNHKzhHxUAtjNjOzdXT2E84axucPV09i/s1GD9Lr9ej3+/T7fa6++uoTfjMZDAZFL6+s68p4\nnK/+fDDqS4x+BvR6a2/f74/uH/0c3x66na/uZTgx/8rz0e+vvX2vt7ntzz9/aabHP/76nfhczur5\nGz1279j5si6dvWigmlK7ltEnnVOAu4D/ClwIfD0i3iXpFcBHIuIcSR8ATo2If1PtPxcRR1cd0xcN\nWNF80UAnH6YxvmhgRiLiK8AtwFeBzwAHgYcYfbr5a0n7gV8Gfrba5T8AOyTdLekAsNT4oFu2+rfJ\nbJyvbJnzZc5Wpy73cACujIj3S3omsA+4KyJ+Za0NI+L/Ar0mB2dmZhvX2Sk1AEkfB14InAosR8SH\ntng8T6lZ0Tyl1smHaUzpU2qdLjh1c8Gx7a4zJ2AXnKKk7+HY9LLPIztf2TLny5ytTi44ZmbWCE+p\nmW0j2vKkSD3893DKUteUWtevUjOzGnXn961ooNxY13hKLZHs88jOV7bM+TJnq5MLjpmZNcI9HDMz\nm8iXRZuZWVFccBLJPo/sfGXLnC9ztjq54JiZWSPcwzEzs4ncwzEzs6K44CSSfR7Z+cqWOV/mbHVy\nwTEzs0a4h2NmZhO5h2NmZkVxwUkk+zyy85Utc77M2erkb4u2bWfHDjhypLnHm9VX8fur96007uHY\nttP4nx2e0QP6zydbU9zDMTOzorjgJJJ9Htn5ypY5X+ZsdUpRcCQ93PYYrCQd+TvLLVNX/t60bRsp\nejiSvhsRZ61ad3JEPLFqXaoezmAwYGlpqe1hFKeaj27yATvZw2n8ebBiuYezBkmvkrRP0qeBe9oe\nz6z5Y7yZlSRVwam8CPi5iHhB2wNpWvYC5Hxly5wvc7Y6Zfx3OHdExDfWu7PX67GwsADA3Nwci4uL\nx6alVt40XV4eDoccPXoUgD179nDo0CEWFhZYWlpiOBy2Pr5ZLteVb0Vj49/g402bDwaM2jDHlwF2\n716i33/y9r3egL17x7c/cVq2lNfPy7NfHgwGLC8vAxw7X9YhVQ9H0quAyyLijetsl6qH0+/36ff7\nbQ+jOO7hrOzvHo5tjHs4J/LlNmZmHZel4GzL
X9OOT62MrJ4yysb5ypY5X+ZsdUrRw1m5JDoivgB8\noeXhNGZ1wTEz67IUPZyNytbDsc3xd6mZTcc9HDMzK4oLTiLZ55Gdr2yZ82XOVqcUPRyzaTX5NWIx\no8ebn6//mGaz5B6OmZlN5B6OmZkVxQUnkezzyM5Xtsz5MmerkwuOmZk1wj0cMzObyD0cMzMrigtO\nItnnkZ2vbJnzZc5WJxccMzNrhHs4ZmY2kXs4ZmZWFBecRLLPIztf2TLny5ytTi44ZmbWCPdwzMxs\nIvdwzMysKC44iWSfR3a+smXOlzlbnfz3cGxb2LEDjhwZ3Q6EKH9qdX4eDh9uexRmG+cejm0LEhx7\n6U9YKFeSGFYA93DMzKwoLjiJZJ9Hdr6yZc6XOVudnrLgSHq4+u/3Sbpx9kMym4UtzwZ0UMZMltlT\n9nAkfTcizmpkMNLJEfHEDI/fqR7OYDBgaWmp7WFsC9Uc9MpCiubHCZnMZqjxHo6knZLurm5fLOlm\nSZ+VdJ+kD45t948kfUnSVyTdIOm0av0Vkr4s6aCka8e2v13SVZLuAC5Z9Zg/JumApP2S7pJ0uqTr\nJb1+bJvrJP1TSSdJ+k+S7pY0lPTuLTwvjfDHcDPbTqbt4Yz/OnUO8NPALuAiSWdLejbw74HXRMSL\ngbuAy6rt/3NEvCwidgGnSbpg7FinRMRLI+KqVY93OfCuiDgXeCXw18ANwEUAkk4BXg18Bngn8APA\nrohYBD4+ZbbiZS9gzle2zPkyZ6vTVv4dzm0R8QiApHuAncA88ELgi5IEnAL8j2r710j6BeC0aruv\nMSoUMCoia/kicJWkjwOfjIhvSfoscHVVbF4P7IuIRyW9BvgvK3NmEXF0rQP2ej0WFhYAmJubY3Fx\n8di01sqbZpbLw+GQo0dHQ9uzZw+HDh1iYWHhhKm1zR5/OBzOfPxtLteWDzqRZ6v5VtZ1Zfzb/f2Z\naXkwGLC8vAxw7HxZhw33cCTtBG6NiF2SLgbOi4hLqm1uBX4ROAt4S0S8bdUxTgUeAM6NiAcl7QYi\nIt4v6XbgsojYv87j/wPgAuBdwD+OiD+RtAzcDLwZuD4i/rukmxgVnNsmZOlUD6ff79Pv99sexrbg\nHo7Z5jXZw5nmQf4IeLmkvwsg6TRJfw94BqPpuO9IOgN400YOJul5EXFPRHwIuBN4QXXXjcA7gFcA\nv1Ot+zzwTkknV/vOTzFuMzObsY0UnI38CrUyjfVtoAdcL+mrwJeA50fEQ8BHgXuAzwJ3bPD4l65c\nBAD8v2pfgM8BPwZ8PiIer9Z9FPgmcFDSAeAtGxh3q8anRuqw8pE4K+crW+Z8mbPV6Sl7OCuXREfE\nA4wuECAi9gJ7x7Z549jtAfDSNY5zBXDFGutfPeGxL1ln/ePA96xa9wSjCxQuW2ufLqq74JiZdZm/\nS822BX+Xmtnm+bvUzMysKC44iWSfR3a+smXOlzlbnfz3cGzbUDUhEGO3Szbv6zCtMO7hmJnZRO7h\nmJlZUVxwEsk+j+x8ZcucL3O2OrngmJlZI9zDMTOzidzDMTOzorjgJJJ9Htn5ypY5X+ZsdXLBMTOz\nRriHY2ZmE7mHY2ZmRXHBSST7PLLzlS1zvszZ6uSCY2ZmjXAPx8zMJnIPx8zMiuKCk0j2eWTnK1vm\nfJmz1cl/D8dat2MHHDky3T6BEKPp0fl5OHx4BgMzs1q5h2Otk2Dql2Vsp03tb2Yb5h6OmZkVxQUn\nkezzyM5Xtsz5MmerU5oejqTdwCPAt4HfjYg/b3lItmECtjInttX9zawJaXo4YwXnDcDlEXHXGtt0\nqoczGAxYWlpqexitq+aHp91prIezif0L4veJtc09HEDSv5N0n6R9wPOr1S8GPiZpv6RTWxzeU/LH\ncNsIv08si2ILjqRzgZ8BdgEXAC9hNK9yJ/C2iDg3Ih5tcYiNy35icr6yZc6XOVudSu7hvBL4VFVU\nHpX0aUaT
+RP1ej0WFhYAmJubY3Fx8dh0xcqbZpbLw+GQo0ePArBnzx4OHTrEwsLCCVMmmz3+cDic\n+fhnsbxis/k2un/Xl8fzDQYDlpeXAdi7dy8Ahw4dYnFxkUsvvbQT490u78/tuDz+/ls5X9ah2B6O\npPcA8xHRr5avBB5k9Gnn8ojYv8Y+nerh9Pt9+v1+28NonXs4k/l9Ym1zDwf2Af9E0qmSzgQuZDSl\n9jBwVqsjMzOzJym24ETEAeAG4CDwGeCO6q5l4NoSLhpY+Shbl9VTTNls13x1v0/akvn1y5ytTiX3\ncIiIDwAfWOOuTzU9ls3IciKx2fL7xLIotoezGV3r4diIv0vNrNvcwzEzs6K44CSSfR7Z+cqWOV/m\nbHUquodjeWjKD+sxts/8fO3DMbMZcA/HzMwmcg/HzMyK4oKTSPZ5ZOcrW+Z8mbPVyQXHzMwa4R6O\nmZlN5B6OmZkVxQUnkezzyM5Xtsz5MmerkwuOmZk1wj0cMzObyD0cMzMrigtOItnnkZ2vbJnzZc5W\nJxccMzNrhHs4ZmY2kXs4ZmZWFBecRLLPIztf2TLny5ytTv57ONaKHTvgyJGNbx+I87l9dgPapPl5\nOHy47VGYlcE9HGuFBFO9FFPv0IyODsusVu7hmJlZUVxwEsk+j+x8ZcucL3O2OqUsOJLeJOmPJd3W\n8OM2+XBmU/N71NqUroej0f9RvwPsiYgvrbpvpj2cap5zZsfPxD2cdvg9apvhHs4YSTsl3StpL/AE\n8FrgVyR9sOWhmZlZJUXBqfwgcE1EnAR8AXhrRLy35TE1Kvs8svOVLXO+zNnqlKngPBARd1a3Vf08\nSa/Xo9/v0+/3ufrqq094owwGgy0tA0gDJOj3196+1xvdP/oZ1Lr9cDic6fHr3H4zz+9wOKz19apj\nGep7/tbKV/frM67p56uLr5+X114eDAb0er1j58u6pOjhSNoJ3BoRu6rl24HLImL/qu3cw+kI93Da\n4feobYZ7OE/my2/MzDosU8GJdW5vG+MfjzNyvrJlzpc5W51SfJdaRDwA7BpbfnWLwzEzszWk6OFs\nlL9LrTvcwzErh3s4ZmZWFBecRLLPIztf2TLny5ytTil6OFamab7WK4Dzz5/ZUDZtfr7tEZiVwz0c\nMzObyD0cMzMrigtOItnnkZ2vbJnzZc5WJxccMzNrhHs4ZmY2kXs4ZmZWFBecRLLPIztf2TLny5yt\nTi44ZmbWCPdwzMxsIvdwzMysKC44iWSfR3a+smXOlzlbnVxwzMysEe7hmJnZRO7hmJlZUVxwEsk+\nj+x8ZcucL3O2Ovnv4RRsxw44cqSdxw6EyDE9OT8Phw+3PQqz/NzDKZgErcVp9cHrlSiK2Uy4h2Nm\nZkVxwUkk+zyy85Utc77M2erUaMGR9PA6658v6YCkuyQ9t8kxlUTa8idaq5lfE7ONa6yHo9H/mQ9F\nxFlr3Pde4OSI+I8zHkPRPZxqHnVs2T2cOmwlyurXxCyjzvdwJO2UdK+kvZLuBv7OaLU+LOlrkj4v\n6dmSXg9cCvxrSbetOsZJkq6TdFDSVyW9p/o09OVVj3Owuv0SSV+UNJT0R5JOn1U+MzObzqyn1H4Q\nuCYifiQivgGcDtwRET8M7AN2R8RngWuBqyLiNav2XwTOjohdEXEOcF1E3AecImlntc1FwPWSTgF+\nA/i5iFgEXgv81YzzdUr2eWTnK1vmfJmz1WnW/w7ngYi4c2z5CeDG6vbHgJufYv/7gedK+iXgt4HP\nVetvZFRoPlT992eA5wMPRsR+gIh4ZK0D9no9FhYWAJibm2NxcZGlpSXg+Jumq8sA0gBYWR4wGBy/\nfzgcdmq8dS/PKh8sMWrFHF8GuPjiAb3ek7cfDJbYs2d8ezqdryvL2fNlWh4MBiwvLwMcO1/WYWY9\nnOoTyK0RsWts3WPAqRHxN9XFATdFxHmSdgMPR8SH1zjOacCPA/8cOBIR/0
LS84DfBN4MfCIiXiLp\nh4FrI+IVE8bkHk59g3EPB/dwbHvofA+nsnqAJwNvqm6/DfjDiTtLz2Z0McGngCuAFwFExP2MPi1d\nAdxQbX4f8L2Szqv2PUOSL/s2M+uIWZ+QV//q9wjw0uoigiXg/U+x/9nAQNIB4NeB943ddwOjonUj\nQEQ8xmh67RpJQ0bTb8/YaoCSHJ8iysn5ypY5X+ZsdZpZDyciHgB2rVq3ckn05avW71nnGAeB89a5\n70rgylXr7gJ+dJNDNjOzGfJ3qRXMPZx6JIpiNhOl9HDMzMwAF5xUss8jO1/ZMufLnK1O/ns4hWvr\nq7yixceu2/x82yMw2x7cwzEzs4ncwzEzs6K44CSSfR7Z+cqWOV/mbHVywTEzs0a4h2NmZhO5h2Nm\nZkVxwUkk+zyy85Utc77M2erkgmNmZo1wD8fMzCZyD8fMzIrigpNI9nlk5ytb5nyZs9XJBcfMzBrh\nHo6ZmU3kHo6ZmRXFBSeR7PPIzle2zPkyZ6uT/x5OIhdeCI88svHtAyHan2Kcn4fDh9sehZnNmns4\niUgwVbypd5iNjgzDzNbhHo6ZmRWlswVH0k5Jd7c9jrIM2h7ATGWfJ3e+cmXOVqfOFpxK5ydapC1/\nyrQN8PNsVr7O9nAk7QRujYhdkp4H3AR8AvhR4DTgecBvRcR7q+0fBn4JeAPwl8BPRsRfrDpm7T2c\nam6z1mNuVuYeTpeeZ7PtZtv0cCT9EKNi83bgL4BzgJ8GdgEXSTq72vR04EsRsQj8AfAvWxiumZmt\no+sF528BvwW8NSK+Vq27LSIeiYhHgT8GdlbrH42I365u3wUsNDrSThi0PYCZyj5P7nzlypytTl3/\ndzgPAd8AXgncW617dOz+Jzie4bF11p+g1+uxsLAAwNzcHIuLiywtLQHH3zTTLq9YWR4MltizB44X\ngNH2F188oNd78v51bQ9DBoMpxj9aueX8W3/+lhi1aCbnHw6HrYyvqWXn83JXlgeDAcvLywDHzpd1\n6HwPB3gZ8Dngl4GnA+dFxCXVNrcCvxgR+yQ9HBFnVuv/GXBBRPzsqmO6h7OlHWbDPRyzbts2PZyI\n+CtGFwJcCpy5+u51bpuZWcd0tuBExAMRsau6/VBEvCwirln5dFOtf2NE7KtunzW2/ubVn262h0Hb\nA5ip41NwOTlfuTJnq1NnC46ZmeXS2R7OLPi71La6w2x0ZBhmto5t08MxM7McXHBSGbQ9gJnKPk/u\nfOXKnK1OXf93ODalab5yLKbcflbm59segZk1wT0cMzObyD0cMzMrigtOItnnkZ2vbJnzZc5WJxec\nRFa+qyor5ytb5nyZs9XJBSeRo0ePtj2EmXK+smXOlzlbnVxwzMysES44iRw6dKjtIcyU85Utc77M\n2eq07S6LbnsMZmYlquOy6G1VcMzMrD2eUjMzs0a44JiZWSPSFBxJr5N0r6Q/kfTeNe6/XNIBSfsl\n3S3pcUlzY/efVN13S7Mj35it5JP0LEm/Kenrku6R9LLmE6xvi9l+XtLXJB2U9HFJT28+wWQbyHeW\npFskDat8vY3u2wWbzSfpOZJ+v3pP3i3pkicdvAO28vpV95d+bpn0/pzu3BIRxf8wKpz/E9gJnAIM\ngRdM2P4NwO+tWvfzwMeAW9rOU3c+YBl4R3X7acBZbWeqIxvw/cD9wNOr5RuAt7edadp8wL8FPlDd\n/h7gO9XrNNVzU2C+7wUWq/VnAPdlyjd2f9Hnlkn5pj23ZPmE81LgT2P0Z6kfA34D+MkJ278FuH5l\nQdJzgJ8APjrTUW7epvNJOgt4ZURcBxARj0fEd2c94Cls6bUDTgZOl/Q04DTgwZmNdHM2ki+AM6vb\nZwLfiYjHN7hv2zadLyL+PCKGABHxCPB14OyGxr1RW3n9spxb1sy3mXNLloJzNvDNseU/Y503rqRn\nAq8Dbh5bfRXwC4ye2C7aSr7nAt+WdF
31sf6/Vdt0xaazRcSDwJXAN4BvAUcj4vdmOtrpbSTfNcAL\nJT0IfBV4zxT7tm0r+Y6RtAAsAl+eySg3b6v5Mpxb1ss39bklS8GZxoXAH0bEUQBJFwD/p/pNS9VP\nyU7Ix+hj7rnARyLiXOAvgfe1NbgtWv3azTH6bWwno+m1MyS9tcXxbdaPAwci4vuBFwEfkXRGy2Oq\n08R81e2bgPdUn3RKs2a+ROeW9V6/qc8tWQrOt4AfGFt+TrVuLW/mxCmZlwNvlHR/tf58Sb82k1Fu\n3lby/RnwzYj4SrV8E6M3SVdsJdtrgfsj4nBEPAF8EviHMxnl5m0k3zsYjZ2I+F/A/wZesMF927aV\nfFRToTcBvx4Rn575aKe3lXxZzi3r5Zv+3NJ206qmxtfJHG98PZ1R4+vvr7Hdsxg1vJ65znFeRTcb\ne1vKB3wB+KHq9m7gg21nqiMbo/nnu4FnMPrtcRl4d9uZps0HfATYXd3+24ymOHZs9LkpNV+1/GvA\nh9vOMat8Y9sUe255itdvqnNL64FrfOJex+gqlz8F3leteyfwr8a2uRj4xIRjdPJNsdV8wDnAndWb\n6ZPAs9rOU2O23YyazQeBvcApbeeZNh/wfcDvVhkOAm+ZtG/Xfjabj9EngCeq9+UBYD/wurbz1Pn6\njR2j2HPLU7w/pzq3+KttzMysEVl6OGZm1nEuOGZm1ggXHDMza4QLjpmZNcIFx8zMGuGCY2ZmjXDB\nMTOzRrjgmJlZI/4/vY2uCFf3FKAAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x117898e50>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Box plot of the CV accuracy per model, ordered by median score.\n",
"# Warnings are silenced only for this plot, not globally for the rest of the session.\n",
"with warnings.catch_warnings():\n",
"    warnings.simplefilter(\"ignore\")\n",
"    order = df_results.median().sort_values(ascending=True).index\n",
"    ax = df_results[order].boxplot(vert=False, return_type='axes')\n",
"    ax.set_title('5-fold CV accuracy by model')\n",
"    ax.set_xlabel('accuracy')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GradientBoostingClassifierで学習"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Fit the model that scored best in the CV comparison above (gradient boosting).\n",
"# Any other candidate can be swapped in here, e.g. LogisticRegression(), SVC(),\n",
"# KNeighborsClassifier() or RandomForestClassifier(n_estimators=5, max_depth=10).\n",
"model = GradientBoostingClassifier(random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"GradientBoostingClassifier(init=None, learning_rate=0.1, loss='deviance',\n",
" max_depth=3, max_features=None, max_leaf_nodes=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=100,\n",
" presort='auto', random_state=None, subsample=1.0, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train on the full training set\n",
"model.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Predictions for the training set\n",
"y_train_pred = model.predict(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Predictions for the test set\n",
"y_test_pred = model.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Evaluation metrics module\n",
"from sklearn import metrics"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.845117845118\n"
]
}
],
"source": [
"# Accuracy on the training data. This is optimistic because the model was fit on these rows;\n",
"# the cross-validated score above is the better estimate of test performance.\n",
"print(metrics.accuracy_score(y_train, y_train_pred))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Attach the test-set predictions; the index is PassengerId\n",
"df_test['Survived'] = y_test_pred"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Submission file with columns PassengerId (index) and Survived\n",
"df_test[['Survived']].to_csv('output02.csv')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.axes._subplots.AxesSubplot at 0x11af0cd90>"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZ4AAAD7CAYAAABaMvJSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADxFJREFUeJzt3X2wbQVdxvHvIxcbxURJPRrJnSzQGCRAEAzFG74MOhoM\nWCqjGFaak9akNf2hBWjpNDWNhcEMjeJYGY6SjGio2NyjOQnyjoqaNOL4diXClwApXn79cRayud5z\nz97n3vNbZ5/7/czsuWuvvfbaz16z7n7uerlrpaqQJKnLg8YOIEnas1g8kqRWFo8kqZXFI0lqZfFI\nklpZPJKkVpvGDrAeJPGccklaharKrO9xi2dQVXP7OOOMM0bPsCdmN//4D/OP+1gti0eS1MrikSS1\nsng2gC1btowdYdXmOTuYf2zmn0/Zlf10G0WScjlI0mySUJ5cIEla7yweSVIri0eS1MrikSS1sngk\nSa0sHklSK4tHktTK4pEktbJ4JEmtLB5JUiuLR5LUyuKRJLXyDqSDZObr3EnSurCwsJlt224aO8bU\nvDo199362uUgaV5ll+4IuupP9erUkqR5YPFIklpZPJKkVhaPJKmVxSNJamXxSJJaWTySpFYWjySp\nlcUjSWpl8UiSWlk8kqRWFo8kqdVcFU+Sk5Lcm+SgsbNIklZnrooHeAnwb8BLxw4iSVqduSmeJPsA\nxwK/wVA8WXJOkhuSfCzJR5KcPLx2RJLFJFckuSTJwojxJUmDuSke4ETgo1V1I3BLksOBk4EDqupg\n4DTgaQBJNgFnA6dU1VHA+cBbx4ktSZo0T3cgfSnw9mH4fcCpLOV/P0BVfSfJ1uH1JwKHAJdm6dai\nDwK+1RtXkrQjc1E8SR4JHA8csnS3UPZi6ZahH1zuLcDnq+rY6T/lzInhLcNDknSfxcVFFhcXd3k+\nc3Hr6ySvAg6vqtdMjNsKbAWOZGk33GOAG4DfAi4GvgCcVlWXDbveDqqqG5aZv7e+ljTHvPX1Wngx\nP751cyGwAHyDpZJ5D3AV8P2qugt4EfDnSa4FrmE4/iNJGtdcbPHsTJJ9qur2JPsBlwPHVtXNM87D\nLR5Jc2y+tnjm4hjPCj6c5BHA3sCbZy0dSVKvud/i2R3c4pE03+Zri2dejvFIkjYIi0eS1MrikSS1\nsngkSa0sHklSK4tHktTK4pEktbJ4JEmtLB5JUiuLR5LUyuKRJLWyeCRJrTbC1al3k5mvcydJ68LC\nwuaxI8zE4hl4lW5J6uGuNklSK4tHktTK4pEktbJ4JEmtLB5JUiuLR5LUyuKRJLWyeCRJrSweSVIr\ni0eS1MrikSS1sngkSa0sHklSK4tHktTK4pEktbJ4JEmtLB5JUiuLR5LUyuKRJLWyeCRJrSweSVIr\ni0eS1MrikSS1sngkSa0sHklSK4tHktTK4pEktbJ4JEmtLB5JUiuLR5LUyuKRJLWyeCRJrSweSVIr\ni0eS1MrikSS1sngkSa02jR1gvUgydoQNZ2FhM9u23TR2DEnrTKpq7AyjS1Lgctj9guuXtHEloapm\n/le7u9okSa0sHklSK4tHktTK4pEktbJ4JEmtLB5JUiuLR5LUyuKRJLWyeCRJrSweSVIri0eS1Mri\nkSS1WlfFk+SNST6f5LokVyc5ahXz2JrkiB2Mf0WSs3dPUknSaq2b2yIkOQZ4PnBYVd2dZD/gwbv5\nY7xUsiSNbD1t8TwOuKWq7gaoqluraluSI5IsJrkiySVJFpLsleSzSY4DSPK2JG/ZfoZJTk/y5SSX\nAcf2fh1J0o6sp+L5OHBAki8l+dskxyXZBJwNnFJVRwHnA2+tqnuAXwfOTfIs4LnAmZMzS/LYYdzT\ngKcDB3d9EUnS8tbNrraqun04NvMM4HjgAuDPgEOAS7N0i9AHAd8epr8hyT8AHwaOHspo0tHA1qq6\nFSDJ+4ADl09w5sTwluEhSbrP4uIii4uLuzyf
dXsH0iSnAL8D/ERV7XA3WZL3stQQr6iqS4dxW4E3\nAI8HTq6qVwzjXwccWFW/u4P5eAfSNeEdSKWNbO7vQJrkoCQ/PzHqMOAG4NHDiQck2ZTk4GH4ZOCR\nwHHAO5I8fLtZXg4cl+SRSfYGfnXNv4QkaUXrZlcb8DDg7CT7AncDNwKvAs6bGL8X8PYk3wHeChxf\nVd8aTpP+a+B0hk2X4cSEM4HLgO8C1zZ/H0nSDqzbXW2d3NW2VtzVJm1kc7+rTZK0Z7B4JEmtLB5J\nUiuLR5LUyuKRJLWyeCRJrSweSVIri0eS1MrikSS1sngkSa0sHklSK4tHktTK4pEktbJ4JEmt1tP9\neEY285W9tYKFhc1jR5C0Dlk8A+8bI0k93NUmSWpl8UiSWlk8kqRWFo8kqZXFI0lqZfFIklpZPJKk\nVhaPJKmVxSNJamXxSJJaWTySpFYWjySplcUjSWpl8UiSWlk8kqRWFo8kqZXFI0lqZfFIklpZPJKk\nVhaPJKmVxSNJamXxSJJaWTySpFYWjySplcUjSWpl8UiSWlk8kqRWFo8kqZXFI0lqZfFIklpZPJKk\nVhaPJKmVxSNJamXxSJJaWTySpFabxg6wXiQZO8KGs7CwmW3bbho7hqR1JlU1dobRJSlwOex+wfVL\n2riSUFUz/6vdXW2SpFYWjySplcUjSWpl8UiSWlk8kqRWFo8kqZXFI0lqZfFIklpZPJKkVhaPJKmV\nxSNJamXxSJJarVg8Se5JcnWSa4Y/D9jVD03y6iQvG4bPT3LyCtO/Msn1Sa4b/nzhMP6sJMfvah5J\nUp8Vr06d5AdV9fA1C5CcD1xcVf+8zOv7A58EDquq25I8FHh0VX1tN2bw6tRrwqtTSxvZWl6d+sdm\nmmRzkk8luXJ4HDOMf2aSxSQXJbkxyduSnJrk8mFr5WeH6c5I8vrt5vnLST448fzZSS4EHgP8ALgD\noKruuK907ttaSvKUiS2y65PcM7z+hCSXJLkiySeTHDTrApIk7V7TFM9DJna1XTiM+w7w7Ko6EngJ\ncPbE9IcCrwIOBl4OHFhVRwPvBF633IdU1VbgiUl+ahh1+vCe64Cbga8meVeSF+zgvVdV1eFVdQTw\nUeAvhpfOA15bVUcBfwicO8X3lSStoWnuQHrH8IM+6cHAO5IcBtwDHDjx2hVVdTNAkv8EPj6M/xyw\nZYXP+nvgZUneDRwDvLyq7gVOSHIk8Czgr5IcUVVv3v7NSV4MHA48N8k+wC8B78/9txfde4rvK0la\nQ6u99fXvA9uq6tAkewE/nHjtfyeG7514fu8Un/du4OLhPe8fSgeAqroSuDLJJ4B3AQ8oniSHAH8C\nPKOqKsmDgO/uoDSXcebE8BZW7khJ2rMsLi6yuLi4y/OZpnh2dOBoX+Drw/BpwF67nASoqm8n+Rbw\nRuDZAEkeBzy2qq4ZJjsceMCJBUn2Bd4LnFZVtw7z+p8kX03yoqr6wDDdoVV1/Y4//czd8RUkacPa\nsmULW7Zs+dHzs846a1XzmaZ4dnRa0jnAhUlOY+mYyu0zvHelaf4ReFRVfXl4vjfwl0MB3Qn8F/Db\n2733ROAA4O+G3Wo1bOm8DDg3yZtY+q4XAMsUjySpw4qnU3dLcjZwdVWd3/iZnk69JjydWtrIVns6\n9boqniRXArcBz6mquxo/1+JZExaPtJFtiOIZi8WzViweaSNby/9AKknSbmPxSJJaWTySpFYWjySp\nlcUjSWpl8UiSWlk8kqRWFo8kqZXFI0lqZfFIklpZPJKkVhaPJKnVau9AugHNfJ07rWBhYfPYESSt\nQxbPwKsoS1IPd7VJklpZPJKkVhaPJKmVxSNJamXxSJJaWTySpFYWjySplcUjSWpl8UiSWlk8kqRW\nFo8kqZXFI0lqZfFIklpZPBvA4uLi2BFWbZ6zg/nHZv75ZPFsAPO88s5zdjD/2Mw/nyweSVIri0eS\n1CreeROS
uBAkaRWqKrO+x+KRJLVyV5skqZXFI0lqtccUT5ITknwpyX8k+aNlpvmbJF9Jcm2Sw7oz\n7sxK+ZM8Mcm/J7kzyevHyLgzU+Q/Ncl1w+PTSZ48Rs7lTJH/V4bs1yT5bJJjx8i5nGnW/2G6o5Lc\nleTkznw7M8Wyf2aS7yW5eni8aYycy5nyt2fLsO58PsnW7ow7M8Xy/4Mh+9VJPpfk7iSP2OlMq2rD\nP1gq2BuBzcDewLXAk7ab5nnAR4bho4HLxs49Y/5HAU8B3gK8fuzMq8h/DLDvMHzCHC7/h04MPxn4\n4ti5Z8k/Md2/Ah8GTh479wzL/pnAh8bOugv59wW+AOw/PH/U2LlnXXcmpn8B8ImV5runbPE8FfhK\nVX2tqu4CLgBO3G6aE4H3AFTV5cC+SRZ6Yy5rxfxVdUtVXQXcPUbAFUyT/7Kq+v7w9DJg/+aMOzNN\n/jsmnj4MuLcx30qmWf8BXgd8ALi5M9wKps0+85lVTabJfypwYVV9E5b+Ljdn3Jlpl/99Xgr800oz\n3VOKZ3/g6xPPv8GP/7BtP803dzDNWKbJv57Nmv83gUvWNNFspsqf5KQkXwQuBl7ZlG0aK+ZP8tPA\nSVV1LuvrR3zadedpwy7yjyQ5uCfaVKbJfxCwX5KtSa5I8vK2dCub+u9ukoewtLfiwpVmumm3RJN2\nkyS/DJwOPH3sLLOqqouAi5I8HfhT4DkjR5rF24HJ/ffrqXxWchVwQFXdkeR5wEUs/ZjPi03AEcDx\nwD7AZ5J8pqpuHDfWzF4IfLqqvrfShHtK8XwTOGDi+c8M47af5vErTDOWafKvZ1PlT3IocB5wQlV9\ntynbNGZa/lX16SRPSLJfVd265ulWNk3+I4ELkoSl44XPS3JXVX2oKeNyVsxeVbdNDF+S5Jw5W/bf\nAG6pqjuBO5N8CvhFlo6tjG2Wdf8lTLGbDdhjTi7Yi/sPkD2YpQNkv7DdNM/n/pMLjmF9HdxeMf/E\ntGcAbxg78yqW/wHAV4Bjxs67yvw/NzF8BPD1sXOvZv0Zpj+f9XNywTTLfmFi+KnATWPnnjH/k4BL\nh2kfCnwOOHjs7LOsOyydIPHfwEOmme8escVTVfckeS3wcZaOa72zqr6Y5NVLL9d5VfUvSZ6f5Ebg\ndpZ296wL0+QfToS4EvhJ4N4kv8fSynvb8nPuMU1+4I+B/YBzhn9131VVTx0v9f2mzH9KktOA/wN+\nCPzaeIkfaMr8D3hLe8hlTJn9RUleA9zF0rJ/8XiJH2jK354vJfkYcD1wD3BeVd0wYuwfmWHdOQn4\nWFX9cJr5eskcSVKrPeWsNknSOmHxSJJaWTySpFYWjySplcUjSWpl8UiSWlk8kqRWFo8kqdX/A6Hb\nNRfMVkR9AAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x11af0c850>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Feature importances of the fitted gradient boosting model, smallest first.\n",
"# DataFrame.sort was removed in pandas 0.20; sort_values is the replacement.\n",
"df_fi = pd.DataFrame(model.feature_importances_, index=df_train[inputs].columns)\n",
"df_fi = df_fi.sort_values(by=0)\n",
"ax = df_fi.plot(kind='barh', legend=False)\n",
"ax.set_xlabel('importance');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## パラメーターチューニング"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# RandomizedSearchCV for random hyper-parameter search\n",
"# (sklearn.grid_search was removed in 0.20; the class lives in sklearn.model_selection)\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"# scipy.stats supplies the distributions the parameters are sampled from\n",
"import scipy.stats as stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### GradientBoostingClassifier Ver"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Search space for GradientBoostingClassifier. Library defaults for reference:\n",
"# n_estimators=100, min_samples_split=2, min_samples_leaf=1, max_features=None,\n",
"# learning_rate=0.1, max_depth=3, subsample=1.0.\n",
"# stats.randint(low, high) samples integers in [low, high).\n",
"# min_samples_split starts at 2 because scikit-learn rejects the integer value 1.\n",
"param_dist = {\n",
"    \"n_estimators\": np.arange(75, 125),\n",
"    \"min_samples_split\": stats.randint(2, 11),\n",
"    \"min_samples_leaf\": stats.randint(1, 5),\n",
"    \"max_features\": stats.randint(1, 3)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Randomized search: 10 sampled parameter sets, each scored with 10-fold CV\n",
"random_search_gbc = RandomizedSearchCV(GradientBoostingClassifier(random_state=42),\n",
"                                       param_distributions=param_dist, cv=10,\n",
"                                       n_iter=10, random_state=42, n_jobs=-1)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Run the search and time it\n",
"tic = time.time()\n",
"random_search_gbc.fit(X_train, y_train)\n",
"toc = time.time()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best score: 0.824915824916\n",
"Execution time: 2.87 sec\n",
"Best param:\n",
"{'max_features': 1,\n",
" 'min_samples_leaf': 4,\n",
" 'min_samples_split': 6,\n",
" 'n_estimators': 82}\n"
]
}
],
"source": [
"# Show the search results\n",
"from pprint import pprint\n",
"print(\"Best score: {0}\\nExecution time: {1:.2f} sec\".format(random_search_gbc.best_score_, toc - tic))\n",
"print(\"Best param:\")\n",
"pprint(random_search_gbc.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'max_features': 1, 'min_samples_split': 6, 'n_estimators': 82, 'min_samples_leaf': 4}\n"
]
}
],
"source": [
"# Keep the best parameters under their own name\n",
"gbc_best_params = random_search_gbc.best_params_\n",
"print(gbc_best_params)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Build the model with the best parameters found by the search\n",
"best_gbc_model = GradientBoostingClassifier(random_state=42, **gbc_best_params)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"GradientBoostingClassifier(init=None, learning_rate=0.1, loss='deviance',\n",
" max_depth=3, max_features=1, max_leaf_nodes=None,\n",
" min_samples_leaf=4, min_samples_split=6,\n",
" min_weight_fraction_leaf=0.0, n_estimators=82,\n",
" presort='auto', random_state=42, subsample=1.0, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train the tuned model on the full training set\n",
"best_gbc_model.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean accuracy (train): 0.8316\n"
]
}
],
"source": [
"# Accuracy on the training data (optimistic; see best_score_ above for the CV estimate)\n",
"print(\"mean accuracy (train): {0:.4f}\".format(best_gbc_model.score(X_train, y_train)))"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Predict the test set with the tuned model before writing its submission file\n",
"df_test['Survived'] = best_gbc_model.predict(X_test)\n",
"df_test[['Survived']].to_csv('gbc_pred.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LogisticRegression Ver"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Search space for LogisticRegression. Library defaults for reference:\n",
"# penalty='l2', C=1.0, class_weight=None, solver='liblinear', max_iter=100.\n",
"param_dist = {\n",
"    \"class_weight\": ['balanced', None],\n",
"    \"max_iter\": np.arange(75, 125)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Randomized search: 10 sampled parameter sets, each scored with 10-fold CV\n",
"random_search_lr = RandomizedSearchCV(LogisticRegression(random_state=42),\n",
"                                      param_distributions=param_dist, cv=10,\n",
"                                      n_iter=10, random_state=42, n_jobs=-1)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Run the search and time it\n",
"tic = time.time()\n",
"random_search_lr.fit(X_train, y_train)\n",
"toc = time.time()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best score: 0.791245791246\n",
"Execution time: 0.48 sec\n",
"Best param:\n",
"{'class_weight': None, 'max_iter': 76}\n"
]
}
],
"source": [
"# Show the search results\n",
"from pprint import pprint\n",
"print(\"Best score: {0}\\nExecution time: {1:.2f} sec\".format(random_search_lr.best_score_, toc - tic))\n",
"print(\"Best param:\")\n",
"pprint(random_search_lr.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'max_iter': 76, 'class_weight': None}\n"
]
}
],
"source": [
"# Keep the best parameters under their own name\n",
"lr_best_params = random_search_lr.best_params_\n",
"print(lr_best_params)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Build the model with the best parameters found by the search\n",
"best_lr_model = LogisticRegression(random_state=42, **lr_best_params)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
" intercept_scaling=1, max_iter=76, multi_class='ovr', n_jobs=1,\n",
" penalty='l2', random_state=42, solver='liblinear', tol=0.0001,\n",
" verbose=0, warm_start=False)"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train the tuned model on the full training set\n",
"best_lr_model.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean accuracy (train): 0.7912\n"
]
}
],
"source": [
"# Accuracy on the training data (optimistic; see best_score_ above for the CV estimate)\n",
"print(\"mean accuracy (train): {0:.4f}\".format(best_lr_model.score(X_train, y_train)))"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Predict the test set with the tuned logistic regression before writing its submission file\n",
"df_test['Survived'] = best_lr_model.predict(X_test)\n",
"df_test[['Survived']].to_csv('lr_pred.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### RandomForestClassifier Ver"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# RandomForestClassifier default arguments at the time of writing, kept for reference:\n",
"# \"n_estimators\": 10,\n",
"# \"criterion\": 'gini',\n",
"# \"max_depth\": None,\n",
"# \"min_samples_split\": 2,\n",
"# \"min_samples_leaf\": 1,\n",
"# \"min_weight_fraction_leaf\": 0.0,\n",
"# \"max_features\": 'auto',\n",
"# \"max_leaf_nodes\": None,\n",
"# \"bootstrap\": True,\n",
"# \"oob_score\": False,\n",
"# \"n_jobs\": 1,\n",
"# \"random_state\": None,\n",
"# \"verbose\": 0,\n",
"# \"warm_start\": False,\n",
"# \"class_weight\": None\n",
"\n",
"# Distributions to sample from in the hyperparameter space (here: for a random forest).\n",
"# scipy.stats.randint(low, high) excludes high.\n",
"# min_samples_split starts at 2: a single-sample node cannot be split, and newer scikit-learn rejects 1.\n",
"param_dist = {\n",
"    \"n_estimators\": np.arange(75, 125),  # 75..124 trees\n",
"    \"min_samples_split\": stats.randint(2, 11),  # 2..10\n",
"    \"min_samples_leaf\": stats.randint(1, 5),  # 1..4\n",
"    \"max_features\": stats.randint(1, 3)  # 1 or 2 features considered per split\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Build the randomized-search CV object: 10 sampled settings, each scored with cv=10, using all cores (n_jobs=-1)\n",
"random_search_rf = RandomizedSearchCV(\n",
"    estimator=RandomForestClassifier(random_state=42),\n",
"    param_distributions=param_dist,\n",
"    n_iter=10,\n",
"    cv=10,\n",
"    random_state=42,\n",
"    n_jobs=-1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Run the random search and measure its wall-clock duration\n",
"tic = time.time()  # start timer\n",
"random_search_rf.fit(X=X_train, y=y_train)\n",
"toc = time.time()  # stop timer"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best score: 0.819304152637\n",
"Execution time: 20.27 sec\n",
"Best param:\n",
"{'max_features': 2,\n",
" 'min_samples_leaf': 3,\n",
" 'min_samples_split': 1,\n",
" 'n_estimators': 93}\n"
]
}
],
"source": [
"# Show the random-search result: best mean cross-validated score, elapsed time (toc - tic), and the winning parameters\n",
"from pprint import pprint\n",
"\n",
"elapsed = toc - tic\n",
"print(\"Best score: {0}\\nExecution time: {1:.2f} sec\".format(random_search_rf.best_score_, elapsed))\n",
"print(\"Best param:\")\n",
"pprint(random_search_rf.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'max_features': 2, 'min_samples_split': 1, 'n_estimators': 93, 'min_samples_leaf': 3}\n"
]
}
],
"source": [
"# Keep the best parameters under their own name for the cells below (an alias of the same dict, not a copy)\n",
"rf_best_params = random_search_rf.best_params_\n",
"print(rf_best_params)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Build the final random forest using the best parameters found by the random search\n",
"best_rf_model = RandomForestClassifier(\n",
"    random_state=42,\n",
"    **rf_best_params\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n",
" max_depth=None, max_features=2, max_leaf_nodes=None,\n",
" min_samples_leaf=3, min_samples_split=1,\n",
" min_weight_fraction_leaf=0.0, n_estimators=93, n_jobs=1,\n",
" oob_score=False, random_state=42, verbose=0, warm_start=False)"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit the final random forest on the training data\n",
"best_rf_model.fit(X=X_train, y=y_train)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean accuracy (train): 0.8395\n"
]
}
],
"source": [
"# Mean accuracy on the training data. This is optimistic: the model was fit on these same rows.\n",
"# print() call form works on both Python 2 and 3 and matches the other cells.\n",
"print(\"mean accuracy (train): {0:.4f}\".format(best_rf_model.score(X_train, y_train)))"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Predict the test set with the tuned random forest and write those predictions to CSV.\n",
"# Build a fresh frame: df_test['Survived'] was filled by an earlier cell, not by best_rf_model.\n",
"# NOTE(review): assumes X_test holds the test features prepared like X_train in an earlier cell, and that\n",
"# df_test's index is what the submission needs (to_csv writes it as the first column) - confirm.\n",
"rf_pred = pd.DataFrame({'Survived': best_rf_model.predict(X_test)}, index=df_test.index)\n",
"rf_pred.to_csv('rf_pred.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## モデルアンサンブルによる予測"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# VotingClassifierの読み込み\n",
"from sklearn.ensemble import VotingClassifier"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Models to combine; each is rebuilt with its tuned hyperparameters (gbc_best_params comes from an earlier cell)\n",
"classifiers = [('gbc', GradientBoostingClassifier(random_state=42, **gbc_best_params)),\n",
"               ('lr', LogisticRegression(random_state=42, **lr_best_params)),\n",
"               ('rf', RandomForestClassifier(random_state=42, **rf_best_params))]"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Hard-voting ensemble: each model casts one equally weighted vote for the predicted label.\n",
"# Derive the weights from the list so they stay in sync if a model is added or removed.\n",
"models = VotingClassifier(classifiers, voting='hard', weights=[1] * len(classifiers))"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"VotingClassifier(estimators=[('gbc', GradientBoostingClassifier(init=None, learning_rate=0.1, loss='deviance',\n",
" max_depth=3, max_features=1, max_leaf_nodes=None,\n",
" min_samples_leaf=4, min_samples_split=6,\n",
" min_weight_fraction_leaf=0.0, n_estimators=82,\n",
" presort='aut...stimators=93, n_jobs=1,\n",
" oob_score=False, random_state=42, verbose=0, warm_start=False))],\n",
" voting='hard', weights=[1, 1, 1])"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit every model in the ensemble on the training data\n",
"models.fit(X=X_train, y=y_train)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean accuracy (train): 0.8339\n"
]
}
],
"source": [
"# Training-set accuracy of the voting ensemble (optimistic: the ensemble was fit on these same rows)\n",
"ensemble_train_accuracy = models.score(X_train, y_train)\n",
"print(\"mean accuracy (train): {0:.4f}\".format(ensemble_train_accuracy))"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Predict the test set with the voting ensemble and write those predictions to CSV.\n",
"# Build a fresh frame: df_test['Survived'] was filled by an earlier cell, not by this ensemble.\n",
"# NOTE(review): assumes X_test holds the test features prepared like X_train in an earlier cell, and that\n",
"# df_test's index is what the submission needs (to_csv writes it as the first column) - confirm.\n",
"voting_pred = pd.DataFrame({'Survived': models.predict(X_test)}, index=df_test.index)\n",
"voting_pred.to_csv('voting_pred.csv')"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [Root]",
"language": "python",
"name": "Python [Root]"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment