@canard0328
Last active August 22, 2018 09:16
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Jupyter notebook tips \n",
"- Tab: show code-completion candidates\n",
"- Shift + Tab: show the method's documentation\n",
"- Shift + Tab + Tab: expand the documentation pane\n",
"- Two trailing spaces: line break within Markdown"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Download the Abalone dataset from the UCI Machine Learning Repository.  \n",
"The data file is a CSV without a header row, so we define the column names up front."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"# Download the data once and cache it locally; later runs reuse the cached CSV\n",
"if not os.path.exists('abalone.csv'):\n",
"    col_names = ('sex', 'length', 'diameter', 'height', 'whole_weight',\n",
"                 'shucked_weight', 'viscera_weight', 'shell_weight', 'rings')\n",
"    data = pd.read_csv(('http://archive.ics.uci.edu/ml/machine-learning-databases'\n",
"                        '/abalone/abalone.data'), header=None, names=col_names)\n",
"    data.to_csv('abalone.csv', index=False)\n",
"else:\n",
"    data = pd.read_csv('abalone.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inspecting and visualizing the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's quickly check that the data was loaded correctly."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>M</td>\n",
" <td>0.455</td>\n",
" <td>0.365</td>\n",
" <td>0.095</td>\n",
" <td>0.5140</td>\n",
" <td>0.2245</td>\n",
" <td>0.1010</td>\n",
" <td>0.150</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>M</td>\n",
" <td>0.350</td>\n",
" <td>0.265</td>\n",
" <td>0.090</td>\n",
" <td>0.2255</td>\n",
" <td>0.0995</td>\n",
" <td>0.0485</td>\n",
" <td>0.070</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>F</td>\n",
" <td>0.530</td>\n",
" <td>0.420</td>\n",
" <td>0.135</td>\n",
" <td>0.6770</td>\n",
" <td>0.2565</td>\n",
" <td>0.1415</td>\n",
" <td>0.210</td>\n",
" <td>9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>M</td>\n",
" <td>0.440</td>\n",
" <td>0.365</td>\n",
" <td>0.125</td>\n",
" <td>0.5160</td>\n",
" <td>0.2155</td>\n",
" <td>0.1140</td>\n",
" <td>0.155</td>\n",
" <td>10</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>I</td>\n",
" <td>0.330</td>\n",
" <td>0.255</td>\n",
" <td>0.080</td>\n",
" <td>0.2050</td>\n",
" <td>0.0895</td>\n",
" <td>0.0395</td>\n",
" <td>0.055</td>\n",
" <td>7</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sex length diameter height whole_weight shucked_weight viscera_weight \\\n",
"0 M 0.455 0.365 0.095 0.5140 0.2245 0.1010 \n",
"1 M 0.350 0.265 0.090 0.2255 0.0995 0.0485 \n",
"2 F 0.530 0.420 0.135 0.6770 0.2565 0.1415 \n",
"3 M 0.440 0.365 0.125 0.5160 0.2155 0.1140 \n",
"4 I 0.330 0.255 0.080 0.2050 0.0895 0.0395 \n",
"\n",
" shell_weight rings \n",
"0 0.150 15 \n",
"1 0.070 7 \n",
"2 0.210 9 \n",
"3 0.155 10 \n",
"4 0.055 7 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>4177.000000</td>\n",
" <td>4177.000000</td>\n",
" <td>4177.000000</td>\n",
" <td>4177.000000</td>\n",
" <td>4177.000000</td>\n",
" <td>4177.000000</td>\n",
" <td>4177.000000</td>\n",
" <td>4177.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>0.523992</td>\n",
" <td>0.407881</td>\n",
" <td>0.139516</td>\n",
" <td>0.828742</td>\n",
" <td>0.359367</td>\n",
" <td>0.180594</td>\n",
" <td>0.238831</td>\n",
" <td>9.933684</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>0.120093</td>\n",
" <td>0.099240</td>\n",
" <td>0.041827</td>\n",
" <td>0.490389</td>\n",
" <td>0.221963</td>\n",
" <td>0.109614</td>\n",
" <td>0.139203</td>\n",
" <td>3.224169</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>0.075000</td>\n",
" <td>0.055000</td>\n",
" <td>0.000000</td>\n",
" <td>0.002000</td>\n",
" <td>0.001000</td>\n",
" <td>0.000500</td>\n",
" <td>0.001500</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>0.450000</td>\n",
" <td>0.350000</td>\n",
" <td>0.115000</td>\n",
" <td>0.441500</td>\n",
" <td>0.186000</td>\n",
" <td>0.093500</td>\n",
" <td>0.130000</td>\n",
" <td>8.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>0.545000</td>\n",
" <td>0.425000</td>\n",
" <td>0.140000</td>\n",
" <td>0.799500</td>\n",
" <td>0.336000</td>\n",
" <td>0.171000</td>\n",
" <td>0.234000</td>\n",
" <td>9.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>0.615000</td>\n",
" <td>0.480000</td>\n",
" <td>0.165000</td>\n",
" <td>1.153000</td>\n",
" <td>0.502000</td>\n",
" <td>0.253000</td>\n",
" <td>0.329000</td>\n",
" <td>11.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>0.815000</td>\n",
" <td>0.650000</td>\n",
" <td>1.130000</td>\n",
" <td>2.825500</td>\n",
" <td>1.488000</td>\n",
" <td>0.760000</td>\n",
" <td>1.005000</td>\n",
" <td>29.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" length diameter height whole_weight shucked_weight \\\n",
"count 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 \n",
"mean 0.523992 0.407881 0.139516 0.828742 0.359367 \n",
"std 0.120093 0.099240 0.041827 0.490389 0.221963 \n",
"min 0.075000 0.055000 0.000000 0.002000 0.001000 \n",
"25% 0.450000 0.350000 0.115000 0.441500 0.186000 \n",
"50% 0.545000 0.425000 0.140000 0.799500 0.336000 \n",
"75% 0.615000 0.480000 0.165000 1.153000 0.502000 \n",
"max 0.815000 0.650000 1.130000 2.825500 1.488000 \n",
"\n",
" viscera_weight shell_weight rings \n",
"count 4177.000000 4177.000000 4177.000000 \n",
"mean 0.180594 0.238831 9.933684 \n",
"std 0.109614 0.139203 3.224169 \n",
"min 0.000500 0.001500 1.000000 \n",
"25% 0.093500 0.130000 8.000000 \n",
"50% 0.171000 0.234000 9.000000 \n",
"75% 0.253000 0.329000 11.000000 \n",
"max 0.760000 1.005000 29.000000 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<font color=\"red\">Note that some records have a height of 0.</font>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's visualize the data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The variable 'sex' is categorical, so we count the frequency of each category."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"M 1528\n",
"I 1342\n",
"F 1307\n",
"Name: sex, dtype: int64"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data['sex'].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The other variables are quantitative, so we look at their distributions with a boxplot."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Render plots inline in the notebook\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x130d11da7f0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ax = data.boxplot(rot=45)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at histograms as well."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x130d128d588>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"axes = data.hist(bins=20, xlabelsize=7, ylabelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the correlations between variables."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>length</th>\n",
" <td>1.000000</td>\n",
" <td>0.986812</td>\n",
" <td>0.827554</td>\n",
" <td>0.925261</td>\n",
" <td>0.897914</td>\n",
" <td>0.903018</td>\n",
" <td>0.897706</td>\n",
" <td>0.556720</td>\n",
" </tr>\n",
" <tr>\n",
" <th>diameter</th>\n",
" <td>0.986812</td>\n",
" <td>1.000000</td>\n",
" <td>0.833684</td>\n",
" <td>0.925452</td>\n",
" <td>0.893162</td>\n",
" <td>0.899724</td>\n",
" <td>0.905330</td>\n",
" <td>0.574660</td>\n",
" </tr>\n",
" <tr>\n",
" <th>height</th>\n",
" <td>0.827554</td>\n",
" <td>0.833684</td>\n",
" <td>1.000000</td>\n",
" <td>0.819221</td>\n",
" <td>0.774972</td>\n",
" <td>0.798319</td>\n",
" <td>0.817338</td>\n",
" <td>0.557467</td>\n",
" </tr>\n",
" <tr>\n",
" <th>whole_weight</th>\n",
" <td>0.925261</td>\n",
" <td>0.925452</td>\n",
" <td>0.819221</td>\n",
" <td>1.000000</td>\n",
" <td>0.969405</td>\n",
" <td>0.966375</td>\n",
" <td>0.955355</td>\n",
" <td>0.540390</td>\n",
" </tr>\n",
" <tr>\n",
" <th>shucked_weight</th>\n",
" <td>0.897914</td>\n",
" <td>0.893162</td>\n",
" <td>0.774972</td>\n",
" <td>0.969405</td>\n",
" <td>1.000000</td>\n",
" <td>0.931961</td>\n",
" <td>0.882617</td>\n",
" <td>0.420884</td>\n",
" </tr>\n",
" <tr>\n",
" <th>viscera_weight</th>\n",
" <td>0.903018</td>\n",
" <td>0.899724</td>\n",
" <td>0.798319</td>\n",
" <td>0.966375</td>\n",
" <td>0.931961</td>\n",
" <td>1.000000</td>\n",
" <td>0.907656</td>\n",
" <td>0.503819</td>\n",
" </tr>\n",
" <tr>\n",
" <th>shell_weight</th>\n",
" <td>0.897706</td>\n",
" <td>0.905330</td>\n",
" <td>0.817338</td>\n",
" <td>0.955355</td>\n",
" <td>0.882617</td>\n",
" <td>0.907656</td>\n",
" <td>1.000000</td>\n",
" <td>0.627574</td>\n",
" </tr>\n",
" <tr>\n",
" <th>rings</th>\n",
" <td>0.556720</td>\n",
" <td>0.574660</td>\n",
" <td>0.557467</td>\n",
" <td>0.540390</td>\n",
" <td>0.420884</td>\n",
" <td>0.503819</td>\n",
" <td>0.627574</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" length diameter height whole_weight shucked_weight \\\n",
"length 1.000000 0.986812 0.827554 0.925261 0.897914 \n",
"diameter 0.986812 1.000000 0.833684 0.925452 0.893162 \n",
"height 0.827554 0.833684 1.000000 0.819221 0.774972 \n",
"whole_weight 0.925261 0.925452 0.819221 1.000000 0.969405 \n",
"shucked_weight 0.897914 0.893162 0.774972 0.969405 1.000000 \n",
"viscera_weight 0.903018 0.899724 0.798319 0.966375 0.931961 \n",
"shell_weight 0.897706 0.905330 0.817338 0.955355 0.882617 \n",
"rings 0.556720 0.574660 0.557467 0.540390 0.420884 \n",
"\n",
" viscera_weight shell_weight rings \n",
"length 0.903018 0.897706 0.556720 \n",
"diameter 0.899724 0.905330 0.574660 \n",
"height 0.798319 0.817338 0.557467 \n",
"whole_weight 0.966375 0.955355 0.540390 \n",
"shucked_weight 0.931961 0.882617 0.420884 \n",
"viscera_weight 1.000000 0.907656 0.503819 \n",
"shell_weight 0.907656 1.000000 0.627574 \n",
"rings 0.503819 0.627574 1.000000 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.drop('sex', axis=1).corr()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A visualization like the following may also be easy to read."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x130d120b6d8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"axes = data.drop('rings', axis=1).plot(kind='hist', bins=50, subplots=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Preprocessing the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's remove the records whose height is 0."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sex</th>\n",
" <th>length</th>\n",
" <th>diameter</th>\n",
" <th>height</th>\n",
" <th>whole_weight</th>\n",
" <th>shucked_weight</th>\n",
" <th>viscera_weight</th>\n",
" <th>shell_weight</th>\n",
" <th>rings</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1257</th>\n",
" <td>I</td>\n",
" <td>0.430</td>\n",
" <td>0.34</td>\n",
" <td>0.0</td>\n",
" <td>0.428</td>\n",
" <td>0.2065</td>\n",
" <td>0.0860</td>\n",
" <td>0.1150</td>\n",
" <td>8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3996</th>\n",
" <td>I</td>\n",
" <td>0.315</td>\n",
" <td>0.23</td>\n",
" <td>0.0</td>\n",
" <td>0.134</td>\n",
" <td>0.0575</td>\n",
" <td>0.0285</td>\n",
" <td>0.3505</td>\n",
" <td>6</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sex length diameter height whole_weight shucked_weight \\\n",
"1257 I 0.430 0.34 0.0 0.428 0.2065 \n",
"3996 I 0.315 0.23 0.0 0.134 0.0575 \n",
"\n",
" viscera_weight shell_weight rings \n",
"1257 0.0860 0.1150 8 \n",
"3996 0.0285 0.3505 6 "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[data.height==0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"data_tidy = data[data.height != 0]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sex 4175\n",
"length 4175\n",
"diameter 4175\n",
"height 4175\n",
"whole_weight 4175\n",
"shucked_weight 4175\n",
"viscera_weight 4175\n",
"shell_weight 4175\n",
"rings 4175\n",
"dtype: int64"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_tidy.count()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's convert the categorical variable sex into numeric variables using dummy variables."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"data_tidy = pd.get_dummies(data_tidy, columns=['sex'], drop_first=True)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['length', 'diameter', 'height', 'whole_weight', 'shucked_weight',\n",
" 'viscera_weight', 'shell_weight', 'rings', 'sex_I', 'sex_M'],\n",
" dtype='object')"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_tidy.columns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we split the data into the explanatory variables X and the target variable y."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"X = data_tidy.drop('rings', axis=1)\n",
"y = data_tidy.loc[:, 'rings']"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['length', 'diameter', 'height', 'whole_weight', 'shucked_weight',\n",
" 'viscera_weight', 'shell_weight', 'sex_I', 'sex_M'],\n",
" dtype='object')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.columns"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'rings'"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.name"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training a model and making predictions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's train a model on the data at hand and make predictions with the trained model."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"regr = LinearRegression()\n",
"regr.fit(X, y)\n",
"y_pred = regr.predict(X)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 9.22143205, 7.8513656 , 11.09261694, ..., 10.94817007,\n",
" 9.73916518, 10.99364565])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0,0.5,'$\\\\hat{y}$')"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVUAAAFBCAYAAADHSzyjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3X9wXeV9JvDnkSwa4TDIbh1iBA6EZUygNhZoiGe82wE3iUlogoCwWw9k6DQzTmeSbjJktDVZdnAa03jrJuw/O2mdhS27JOCAjUKBqUuATDaeYiLHNsZ1XH6EX7ILTogJJE6w5e/+cc8V916dc95zdc7R+957ns+MR9J7daXX19Lj9/dLM4OIiBSjx3cFRES6iUJVRKRAClURkQIpVEVECqRQFREpkEJVRKRA3kKV5LtIPklyD8l9JL8clZ9NcgfJZ0huJnmSrzqKiLTLZ0v1twBWmtmFAJYBuJzkcgD/HcBtZnYugF8A+LTHOoqItMVbqFrNW9GHfdEfA7ASwH1R+Z0ARjxUT0RkRryOqZLsJbkbwGsAHgHwHIAjZnY8+pRXAAz6qp+ISLvm+PzmZjYJYBnJAQD3A/hA3KfFPZfkGgBrAGDu3LkXn3feeaXVU0SqaefOnT8zswXtPMdrqNaZ2RGS3wewHMAAyTlRa/UMAAcTnrMJwCYAGB4etvHx8dmqrohUBMkX232Oz9n/BVELFST7AXwIwH4AjwP4ZPRpNwD4rp8aioi0z2dLdSGAO0n2ohbu3zGzB0n+C4B7SK4HsAvA7R7rKCLSFm+hamZPARiKKX8ewCWzXyMRkfy0o0pEpEAKVRGRAilURUQKpFAVESmQQlVEpEAKVRGRAilURUQKpFAVESmQQlVEpEAKVRGRAilURUQKpFAVESmQQlVEpEAKVRGRAilURUQKpFAVESmQQlVEpEAKVRGRAilURUQKpFAVESmQQlVEpEAKVRGRAilURUQKpFAVESmQQlVEpEAKVRGRAilURUQKpFAVESmQQlVEpEAKVRGRAnkLVZJnknyc5H6S+0h+PipfR3KC5O7oz8d81VFEpF1zPH7v4wC+aGY/JnkKgJ0kH4keu83M/sZj3UREZsRbqJrZIQCHovffJLkfwKCv+oiIFCGIMVWSZwEYArAjKvocyadI3kFynreKiYi0yXuoknw3gC0AvmBmvwTwDQDnAFiGWkv2awnPW0NynOT44cOHZ62+IiJpvIYqyT7UAvVbZrYVAMzsVTObNLMTAL4J4JK455rZJjMbNrPhBQsWzF6lRURS+Jz9J4DbAew3s683lC9s+LSrADw923UTEZkpn7P/KwB8CsBekrujsi8BWE1yGQAD8AKAz/ipnohI+3zO/v8QAGMeeni26yIiUhTvE1UiIt1EoSoiUiCFqohIgRSqIiIFUqiKiBRIoSoiUiCFqohIgXwu/heRFGO7JrBx2wEcPHIUpw/0Y3TVYowM6SC30ClURQI0tmsCN23di6PHJgEAE0eO4qatewFAwRo4df9FArRx24GpQK07emwSG7cd8FQjyUqhKhKgg0eOtlUu4VCoigTo9IH+tsolHApVkQCNrlqM/r7eprL+vl6MrlrsqUaSlSaqRAJUn4zS7H/nUaiKBGpkaFAh2oHU/RcRKZBCVUSkQApVEZECKVRFRAqkUBURKZBCVUSkQApVEZECKVRFRAqkUBURKZB2VImURIdMV5NCVaQEWQ6ZVuh2J3X/RUrgOmS6HroTR47C8E7oju2a8FBbKZJaqiIlcB0ynRa6IbRW1YqeOYWqSAlOH+jHREyw1g+ZDvlk/064Hyvk0Ff3X6QErkOmQz7ZP/T7sUIfOvEWqiTPJPk4yf0k95H8fFQ+n+QjJJ+J3s7zVUeRmRoZGsQ1Fw+ilwQA9JK45uJ3zkcN+WT/kFvRQPih77OlehzAF83sAwCWA/gsyfMBrAXwqJmdC+DR6GORjjK2awJbdk
5g0gwAMGmGLTsnplpTI0OD+OrVSzA40A8CGBzox1evXhJEFzbkVjQQfuh7G1M1s0MADkXvv0lyP4BBAFcCuDT6tDsBfB/AX3ioosiMZZmICvVk/9FVi5vGVIFwWtGAe7zatyDGVEmeBWAIwA4Ap0WBWw/e9/irmVTZ2K4JrNjwGM5e+xBWbHisrTG70FtTaUJuRQNhD50AAcz+k3w3gC0AvmBmv2Q0BpXheWsArAGARYsWlVdBqaS8M+Cht6ZcQm1FA+FfikiLxny8fHOyD8CDALaZ2dejsgMALjWzQyQXAvi+maX+FzQ8PGzj4+PlV1g6Sp5lNys2PBYbioMD/di+dmWm7x3XhQ6pxSduJHea2XA7z/HWUmWtSXo7gP31QI08AOAGABuit9/1UD3pcHlbmnm776G3psoW8jrSsvns/q8A8CkAe0nujsq+hFqYfofkpwG8BOBaT/WTDpZ3x1IR3feQu9Bl6oTNA2XyNlFlZj80M5rZUjNbFv152Mx+bmZ/aGbnRm9f91VH6Vx5W5qhT4aELPR1pGULYvZfpGh511qGPgMesk5e+VAE77P/ImUoYq1l3u67a1yxW8cdO33lQ15qqUpX8t3SdO1PD33/eh5VHzpRS1UkQZ6WpGuiLMtE2s1je3H3jpcxaYZeEqs/eCbWjywp5i9XoqqvfFCoSml8dm/zzkCXvSTL9fjNY3tx1xMvTZVPmk19nCVYfQ8tVHXlA6Duv5TEd/c27wx03ue7Jspcj9+94+XYx5PKG/l+7atOoSql8L2sJu8MdBFLsvp6mrdc9/VwalzRNe44mbDTMam8ke/XvuoUqlIK38tq8i6pKuT4u9ZjLBo+dp232ptwBkZSeaMsr32ew2IknUJVSuH7TM68M9B5n79x2wEcm2xuVR6btKaL/9LOW139wTNjv25SeSPXa6/hgXIpVKUUvpfV5F1Slff5eS7+A4Dh982f9svZE5W7uF57DQ+US7P/UorZWFbjmuHOOwOd5/mn9vfhyNFjseVAttA90fLYiajcVaeRoUGMv/h603KsxqEF30MzRfC9uiGNQlVKU+aymtAP7Uga+qyXD5zch1/8enroDpycLXTTJA0tDL9vPkaGBoPY8ZQnFEP/t1f3XzpS6F3YIzGB2VieNIlfL88zJu16bXwPzeQd0w39316hKh2piC6sawY8zwy5KxTjhgYay/MEn+u1ca08KFveUAx9+EKhKh0pS0suLRTL3pvvCkXXkqksE2VJf78ss/9pKw/KljcUfa8scVGoSkcaXbUYfb0ti+t731lc7wpFV2spS2vq5rG9OOemh3HW2odwzk0P4+axvVOPjQwN4qJFpzY9/6JFp06FYpbF/SNDg9i+diV+uuEKbF+7clqgJv39Qp/9zxuKvocvXBSq0rlac6nhY1dwFLU3v7G1d9cTL00F681je7H9uebz1bc/9/rU44MJAZJU3sp1IEtaKzdukiqtvGh5Q9H3CWQumv2XjrRx2wEcO9GyuP6ETYWKKxRdM+Cux9P25q8fWeJ8fHTVYozet6dpg0BjS9vFFYxpKy96ydiWcpbdWlmlze4Xsdwu5ANbFKrSkVyh+a6+Hhw91rrSs1YOAGf9bnxonvW7tdB0HXLt6r5n2ruf0tJ2cQVjWqjlOVcgiyxLnkIOxbzU/ZeO5BqX+01MoDaW//Pz8Vef1ctdXUzXRJPr8bSWdhZpwegaTx6INiC0Sipvl+8xW98UqtKRXONySW2uevmJhE9oLE+bKHLtzXc9nndcM21M1hVqro0JeYW+5KlsClXpSL4nK9aPLMH1yxc1tUyvX75o6gBp1+NZTqFKWxKWtvrBFWqujQl5hb7kqWwaU5Vg5dnK2J8wptofjan29QBxIwR9Dc0M1/dfP7Ik9RT+4ffNx+M/OYyDR47ivae+q+kwFNe4ZqatmAljsnkn4fIq4tLFTqaWqgRpbNcERu/d0zQuOHrvnsyL87969dLYU56+evXSzN//xs27m77+jZ
t3Z14gX/ZWzLQxWdfQSBHrPNNa0b57Eb6ppSpBWvfAvtjQWPfAvkwX57mW7STMY02V37T1qdhTom7a+tTU10hryWa52C9NnnW0I0ODuHf8paZ1so0bD/Iuaar67L6LQlWC5Nobn2UyJM8vdtzQQWO5K1hcE1GuJVF5uvBpGw/qwxV5Xpu8/2F0O3X/pSPlnQxJmujOOgHu6p67JqJcqwPydOHzXBqYRdVn913UUpUgzUs4b3RedN7oZectaLrCue6y8xZMvX/z2N6mg5pXf/DMqZbanISJqjlRM4OIX5ZVj0pXS9Q1ETX8vvn49o6XmpZw9fCdk/1dXfS0x7+weXfq9wbck3Bpr53rLNiqU6hKkK5YujA2NK9YuhAA8OCeQ7HPe3DPIawfWTK1N7+uvjcfqM3au8ZUXTuy8m713LjtwLS1sies+WR/Vxc96fEsu61G790zNWZdnwSsf03Xa+c6C7bq1P2XID30VHxo1stdY655u8CuMdW8Wz3L7EK7hhbSJgEB92vneu2rzmuokryD5Gskn24oW0dyguTu6M/HfNZR/IjrXqaVtyp7f3veMdkyF8i7Lg10haLrtctzfXYoyryi23dL9e8BXB5TfpuZLYv+PDzLdZIukDf0XFzbYF1c58FmkRQMaZcGZuEKzbL/wypb2Vd0ew1VM/sBgPiTLaTSTu6L/9Gsl89LmBSpl598Um/s40nlRcsU6o5TqmZ6c4FrEs312rmGD8o+kKVsZR/44rulmuRzJJ+KhgfmxX0CyTUkx0mOHz58eLbrJyU7nnDiSb38lo9fENvSu+XjFwAAfv325LTnNpb3JKReUnm7XC1Z1ylVeW4ucB2YcsvHL5j29+whpl4717kFZR/IUrayl4SFOPv/DQBfQe3n7ysAvgbgT1s/ycw2AdgEAMPDw53R75Amact63p6M/yetl7uWHLkWz/cy/qSq3lkKBldr0rXAPi0YEgO94YHeHuJEw2vc25KyaecalH0gS9nKPvsguJaqmb1qZpNmdgLANwFc4rtOUrwixrXGX3wd//bGb2AA/u2N32D8xXdGkhrXqzaql7uWVOWVt4vsak0lrQnNslZ047YDTTcOAMCxyexnuc7GKVRlTiSVfcdVcKFKcmHDh1cBeDrpc6VcZf5g5x3Xct0R9fhP4oeEksqLtu4T8V3sdZ+4INPzXWPCedaK5u3+lh1KZU8klX3gi9fuP8m7AVwK4PdIvgLgFgCXklyGWvf/BQCf8VbBCst09FwOeX+xvxWzMaBevn5kSe5DoF07qrKIW9yf1a8SxoTr5W8kLItKKm90an9f7LKqUzO2oou4YyrNbJwtUOaBL15D1cxWxxTfPusVkWnK/sHOO66Vd0lT2V//xoStojdu3l366/er3x6PDc360EMRE01lhlKnny0QXPdfwlD2D3bod7fnlTQ0W9CQbeqYcdIQQ7089ImmTr85QKEqscr+wR4ZGsQ1Fw82Ldu55uLqnsHZri07X0ksb5ywa1QvDz20Ov0/XIWqxJqNyYgtOyeaJpq27JwodDKsm6WdTRB3EA2AqXLXygjfOv3mgBDXqUoAumEyQuKlHVZTX5ua536wInTyzQEKVUmU9wc77Rcz7+y8zJzrsJqyV350O4WqlGJs1wRG79sztch84shRjN73zpmdec8jrbq022KThgayytKLSDvEuuo0piql+PI/7IvdtfPlf6id2dnpJx35ds3FZ7RV3si128u18sO18aLqFKpSClcXs9NPOirb78yJ/9Wsl29+Mn4yKqm80R9duDC13LU6oOw7sDqdQlW86PSTjlxcRxe6/PZ4fBe+Xp7n7ALXFl7Xyg/1MtIpVMWL0Beg5/VXVy9tq7xIrvNSXd1715Kmbjj5v0yaqBIvXLeZdrq0Bfhlz6Cfv/AUbH9u+vc/f+EpALJtEU5b+bH6g2fGroVNOty6arrkR1g6TdlH7/nmc9zxied/kVqeZWNH2gll60eWYMU585uev+Kc+Zr9j7
QVqiS/R/LCsioj0i18jju6vrdri7Dr6L2xXRN48qfNwf3kT3+h3XCR1FAleT7JuxqK/guA20j+75ZzT6WC0loz3T4RFTLXa+/aIuw669Z1xXXVuVqqjwK4uf6Bmf3YzFYCeBDAP5K8hWQYpzDIrHK1Zv7dgrmxz0sql+L0JwxM18tdoena7ea64rrqXKH6EQC3NhaQJIADqN0l9ecAniH5qXKqJ6Fy/WI+d/hXsc9LKpfi/DphYLpe7pr91+x+PqmhamZ7zey6+sckfwhgAsBtAAYB/AlqJ/dfQnJTedWU0Lh+MZNOuW/n9HuZGVf337W43zUm61qyVXXtzv7/GYBBM/uwmf03M3vQzJ41sz8H8B9KqJ8EKunqjaxXckh5XPdXuWb/BxNCt17uuh686toKVTN72ixx+vKKAuojHcLVGkr6wdIaPv9ci/tdoTsyNIiNn7yw6fkbP3mhTrCKFLb438yeL+prSfhcO6Is4eY8q8iw3Em9xNuT01+Ak3rLfwEGEi72azxXIW1xf5azdDv5vNOyaUeVzIhrV06eK5S7QVygppUXad0nLsCNm3c33YfVg+zXYwMKzTzUG5MZGV21GL0tF9v39rBj7hHqdr0tLeLWj6U8ClWZkfEXX8dky1T+5AlL3PMu7cmzeWLjtgOxZ9nWl7sVIW3jR9UpVGVGdKZmufIMn5R9vbhr40fVKVRlRnSmZrjKvoLatfGj6hSqIh1o7km9ieVlXy9edku402n2XxL5vqa4ylwXI55I6BGcMMu0JCrPv22W81irTC1VieUaN9MdU+Va/v55qeVJN6ZmuUk175ho2S3hTqdQlViucbM3fxO/+D+pXNrzws/ju9JJ5Y1coZl3TNS1I6vq1P2XWK7j35LWsM/C2vZKcI1bzju5L/bG2nkn96WG5sjQYCFjotockMxrS5XkHSRfI/l0Q9l8ko+QfCZ6G98PklIlLYfUEvJsWq8baS1PulS1Xu6awb9iafwZ8VcsXegMzYGE06SSyqU9vrv/fw/g8paytQAeNbNzUTske+1sV0pit+2nlkuzsxe8O7U84QbqqfLLzlsQ+3i9PO2aaVcgV30Lcdm8hqqZ/QBA6xacKwHcGb1/J4CRWa2USAG+vWP6baON5a7gSwtNIH14wDWR9EbCCf1J5dIe3y3VOKeZ2SEAiN6+J+6TSK4hOU5y/PDh+B9AEV9ch3S7gs81pp0Wyq6JpLI3B1Rdx05UmdkmAJsAYHh4WB0X6SgjQ4MYf/F13L3jZUyaTbvRNOHkxKkx7dFVi3HT1r1NE1KtZ54mTSS5niv5hNhSfbV+U2v09jXP9akk3VOUzjXR1J/wCfXysV0T2Pzky003mm5+8uWpZU+uMe08y5q0JKpcIbZUHwBwA4AN0dvv+q1ONS1//zxsf276iVNJi9KrJmmNfb38movPwF1PTB9XvebiMwCkX/OcNdzyLGvSkqjy+F5SdTeAfwawmOQrJD+NWph+mOQzAD4cfSyzbNdLR9oql2YPPXUotdx1zbOWtHUury1VM1ud8NAfzmpFZBrXNceSLm5hflp5q+uWL4pt6V63fFGuekn5QhxTFel6rmue148swfXLF02NYfeSuH75IqwfWTJrdZSZCXFMVcQ71+y7i+vyvVs+fgFG79vTdEJ/6zXP60eWKEQ7kFqqIjFcs+9p55kCtUv2+lru8Orr4dTle7rmuXuppVphOi81mes801uvWoIv3run6Z6u3h7i1qtqLUtd81xdCtWKqh8PV18AXj8eDoB+0eG+LiZLaEo1qftfUbpnKF1PwuBpUnkrXY5XXQrVitI9Q+lce/fLPghaOpdCtaJ0qEY+rtDUf1rVpVCtqNFVi6d1ZXsIHaqRkSs09Z9WdSlUK2r8xdendXFPWK1c3FyhqcvxqkuhWlGuQ5QlnSs0dRJUdWlJVRdLW4fqmojpdnNP6sWv3p6MLQfcO6K0DlWSKFS7VNXXoQ4O9Meenj8Ydc/jArWxfN0nLsDovX
uajudr3BEFKDQlnrr/XarqS3pc3XPXIdwjQ4PYeG3LNtJrtY1U3NRS7VJVX9Ljuq7EtWOq/jUUotIutVS71Kn98UfLJZV3m7FdE9iyc6LpupItOyemFucPJszeJ5WLZKVQ7VJJV0lV5Yop1/DH6KrF6OttOUWql1ryJLkpVLvUkYQT5pPKu02m4Y/WEYCKrHyQcilUu5TrNs/QubrnAwnDGPVy1+L8jdsOxF68V5WJPClPZ/yGSduOJtwllVQeGtfs/R9duDD2efVy1/OrPpEn5dHsf5dynVwfOtfi+sd/cjj2efVy1/NPTVjcX5WJPCmPQlW8cJ2sD6QvacrS0kx7ftUn8qQ8CtUO1snXoWRZJ5rm9IQdU1lPgar6RJ6UR2OqHWps1wRG793TdEjy6L17OuZkedeOJpfLzlvgLB/bNYEVGx7D2WsfwooNjzW9NjqaT8qiUO1Q6x7YFzt7ve6BfZ5q1J68LVXXmKrrZH4dzSdlUah2qLhJlrTy0ORtqbrGVF2L/0eGBnHNxYNT3691G6vITClUpRRJ0Vgvz9tSHTg5YZ1qVO4KXdc2VpGZUqh2qN+ZE/9Pl1Q+265bvii1PO/e+6TsrZdnWfxf5VO8pDxh/AZK294+Hr+IP6m8aK6W6PqRJbh++aKm7vX1yxdh/cgSAPnHNN9IGOaol2vxv/gS7JIqki8AeBPAJIDjZjbst0Zh8b24/7rli3DXE9OvXmlsoa4fWTIVoq2ynJyfxrWkyvX18y7JEkkSbKhGLjOzn/muhEznmn3PIs95paOrFjfdbABMb+mmff0szxeZidBDVTw5ua8Hv445J+Dk6EAW393nvC3dvM8XSRJyqBqAfyJpAP7OzDb5rlCVxAVqY3kI3ee8J/PrZH8pQ8gTVSvM7CIAHwXwWZJ/0PggyTUkx0mOHz6cvcspxdDieZF4wYaqmR2M3r4G4H4Al7Q8vsnMhs1seMGC+C2LMnP1q5qTynWvvUi8ILv/JOcC6DGzN6P3PwLgLz1Xa9b5PDClr7cHtYUXceU16j6LTBdkqAI4DcD9rK1xnAPg22b2j36rNLvqe9frs9P1vesACgmywYQx0frie9c6UBGJF2T338yeN7MLoz8XmNmtvus028re8eMaE9UpTiIzE2SoSvlLllxjokVMRKUdvSfSrULt/leea8mSax0pEb+7qnF7adqYaN51nGUPX4iESi1Vj9Jacq6W4l9dvRQ9LRvwe1grB4Db/tOy2O+ZVB5nZGgQ29euxE83XIHta1e2FYY6sESqSi1VT1wtOVdLMe/jZfO940rEF1rG8ytDNjw8bOPj476r0ZYVGx5LnH3fvnalhxoVq9v/flINJHe2e5iTuv+edHtLTjuupKoUqp50+5Il7biSqtKYqidVOHpOO66kihSqnvieSBKRcihUPVJLTqT7aExVRKRAaqnKjPk8RUskVApVmRFtQxWJp+6/zIi2oYrEU6jKjHT75gWRmVKoyox0++YFkZlSqMqMaBuqSDxNVMmMaPOCSDyFqsyYNi+ITKfuv4hIgRSqIiIFUqiKiBRIoSoiUiCFqohIgRSqIiIFUqiKiBRIoSoiUiCFqohIgRSqIiIFUqiKiBQo2FAleTnJAySfJbnWd31ERLIIMlRJ9gL4nwA+CuB8AKtJnu+3ViIibkGGKoBLADxrZs+b2dsA7gFwpec6iYg4hRqqgwBebvj4lahMRCRooYYqY8qs6RPINSTHSY4fPnx4lqolIpIu1FB9BcCZDR+fAeBg4yeY2SYzGzaz4QULFsxq5UREkoQaqj8CcC7Js0meBOCPATzguU4iIk5BXqdiZsdJfg7ANgC9AO4ws32eqyUi4hRkqAKAmT0M4GHf9RARaUeo3X8RkY6kUBURKVCw3X8p39iuCWzcdgAHjxzF6QP9GF21WFdOi+SkUK2osV0TuGnrXhw9NgkAmDhyFDdt3QsAClaRHNT9r6iN2w5MBWrd0WOT2LjtgKcaiXQHhWpFHTxytK
1yEclGoVpRpw/0t1UuItkoVCtqdNVi9Pf1NpX19/VidNViTzUS6Q6aqKqo+mSUZv9FiqVQrbCRoUGFqEjB1P0XESmQQlVEpEAKVRGRAmlMtcK0TVWkeArVitI2VZFyqPtfUdqmKlIOhWpFaZuqSDkUqhWlbaoi5VCoVpS2qYqUQxNVFaVtqiLlUKhWmLapihRP3X8RkQIpVEVECqRQFREpkEJVRKRAClURkQIpVEVECqRQFREpkEJVRKRAClURkQIFF6ok15GcILk7+vMx33USEckq1G2qt5nZ3/iuhIhIu4JrqYqIdLJQQ/VzJJ8ieQfJeb4rIyKSFc1s9r8p+T0A74156L8CeALAzwAYgK8AWGhmfxrzNdYAWBN9+PsAni6ntoX4PdT+TqFS/fIJuX4h1w0Iv36LzeyUdp7gJVSzInkWgAfN7PcdnzduZsOzUqkZUP3yUf1mLuS6Ad1Zv+C6/yQXNnx4FcJugYqINAlx9v+vSS5Drfv/AoDP+K2OiEh2wYWqmX1qBk/bVHhFiqX65aP6zVzIdQO6sH5Bj6mKiHSa4MZURUQ6WdeEaqjbW0leTvIAyWdJrvVdn1YkXyC5N3rNxgOozx0kXyP5dEPZfJKPkHwmeutl7XJC3YL5uSN5JsnHSe4nuY/k56PyUF6/pPoF8RqSfBfJJ0nuier35aj8bJI7otdvM8mTUr9Ot3T/Sa4D8FZI21tJ9gL4VwAfBvAKgB8BWG1m/+K1Yg1IvgBg2MyCWCtI8g8AvAXg/9SX0pH8awCvm9mG6D+meWb2F4HUbR0C+bmLVs4sNLMfkzwFwE4AIwD+BGG8fkn1+48I4DUkSQBzzewtkn0Afgjg8wBuBLDVzO4h+bcA9pjZN5K+Tte0VAN1CYBnzex5M3sbwD0ArvRcp6CZ2Q8AvN5SfCWAO6P370TtF3HWJdQtGGZ2yMx+HL3/JoD9AAYRzuuXVL8gWM1b0Yd90R8DsBLAfVG58/XrtlANbXvrIICXGz5+BQH9EEUMwD+R3BntUgvRaWZ2CKj9YgJ4j+f6tArt566+cWYIwA4E+Pq11A8I5DUk2UtyN4DXADwC4DkAR8zsePQpzt/hjgpVkt8j+XTMnysBfAPAOQCWATgE4GteK1tItyb3AAAChElEQVTDmLLQxltWmNlFAD4K4LNRF1eyC+7njuS7AWwB8AUz+6Xv+rSKqV8wr6GZTZrZMgBnoNbT/EDcp6V9jeDWqaYxsw9l+TyS3wTwYMnVyeIVAGc2fHwGgIOe6hLLzA5Gb18jeT9qP0g/8FuraV4ludDMDkXjcq/5rlCdmb1afz+En7toLHALgG+Z2daoOJjXL65+ob2GAGBmR0h+H8ByAAMk50StVefvcEe1VNMEur31RwDOjWYPTwLwxwAe8FynKSTnRhMGIDkXwEcQxuvW6gEAN0Tv3wDgux7r0iSkn7toouV2APvN7OsNDwXx+iXVL5TXkOQCkgPR+/0APoTauO/jAD4ZfZrz9eum2f//i1r3YWp7a30cyadoecj/ANAL4A4zu9VzlaaQfD+A+6MP5wD4tu/6kbwbwKWonV70KoBbAIwB+A6ARQBeAnCtmc36hFFC3S5FID93JP89gP8HYC+AE1Hxl1Abtwzh9Uuq32oE8BqSXIraRFQvag3O75jZX0a/J/cAmA9gF4Drzey3iV+nW0JVRCQEXdP9FxEJgUJVRKRAClURkQIpVEVECqRQFREpkEJVRKRAClURkQIpVKUSSH6lfn5n9PGtJP+zzzpJd9Lif6mE6FSkrWZ2EckeAM8AuMTMfu61YtJ1OupAFZGZMrMXSP6c5BCA0wDsUqBKGRSqUiX/C7VT8N8L4A6/VZFupe6/VEZ0Uthe1E50P9fMJj1XSbqQWqpSGWb2NsnHUTvJXYEqpVCoSmVEE1TLAVzruy7SvbSkSiqB5PkAngXwqJk947s+0r00pioiUiC1VEVECqRQFREpkEJVRKRAClURkQIpVEVECqRQFREp0P
8HzZs+2NjlVLAAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x130d3919b38>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.figure(figsize=(5, 5))\n",
"plt.scatter(y, y_pred)\n",
"plt.xlim(-5, 30)\n",
"plt.ylim(-5, 30)\n",
"plt.xlabel('y')\n",
"plt.ylabel(r'$\\hat{y}$')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluating performance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's evaluate the predictions using the mean absolute error."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error:  1.577885656351277\n"
]
}
],
"source": [
"from sklearn.metrics import mean_absolute_error\n",
"print('Error: ', mean_absolute_error(y, y_pred))"
]
},
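{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Overfitting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's split the data into training data and evaluation data.  \n",
"The evaluation data stands in for the unseen data that the predictive model would receive when deployed in a real environment,  \n",
"so note that <font color='red'>it is not appropriate to tune the model by trial and error to improve performance on the evaluation data</font>."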
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.4, random_state=1, shuffle=True)"
]
},
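{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at overfitting.  \n",
"Pick a regression algorithm and tune its hyperparameters <font color=\"red\">so that predictive performance on the training data is as good as possible</font>.  \n",
"Then check the resulting model's predictive performance on the evaluation data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Example with linear regression:"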
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error on training data:  1.5448696958937491\n",
"Error on evaluation data:  1.6071922396127398\n"
]
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_absolute_error\n",
"regr = LinearRegression()\n",
"regr.fit(X_train, y_train)\n",
"y_pred = regr.predict(X_train)\n",
"print('Error on training data: ', mean_absolute_error(y_train, y_pred))\n",
"y_pred = regr.predict(X_test)\n",
"print('Error on evaluation data: ', mean_absolute_error(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Regression tree example:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MAE on training data:  0.0\n",
"MAE on test data:  2.065868263473054\n"
]
}
],
"source": [
"from sklearn.tree import DecisionTreeRegressor\n",
"from sklearn.metrics import mean_absolute_error\n",
"regr = DecisionTreeRegressor(max_depth=None, random_state=1)\n",
"regr.fit(X_train, y_train)\n",
"y_pred = regr.predict(X_train)\n",
"print('MAE on training data: ', mean_absolute_error(y_train, y_pred))\n",
"y_pred = regr.predict(X_test)\n",
"print('MAE on test data: ', mean_absolute_error(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Support vector machine example:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MAE on training data:  0.09793777638387188\n",
"MAE on test data:  1.9349815512908115\n"
]
}
],
"source": [
"from sklearn.svm import SVR\n",
"from sklearn.metrics import mean_absolute_error\n",
"regr = SVR(kernel='rbf', C=1000, gamma=1000)\n",
"regr.fit(X_train, y_train)\n",
"y_pred = regr.predict(X_train)\n",
"print('MAE on training data: ', mean_absolute_error(y_train, y_pred))\n",
"y_pred = regr.predict(X_test)\n",
"print('MAE on test data: ', mean_absolute_error(y_test, y_pred))"
]
},
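{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gap between training error and held-out error can be made visible by varying model complexity. The following is a sketch on synthetic data: the `X_demo` and `y_demo` arrays and the noisy sine target are assumptions for illustration, not part of this notebook's data, and a separate validation split is used instead of the test data.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import mean_absolute_error\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Synthetic stand-in data (hypothetical, for illustration only)\n",
"rng = np.random.RandomState(1)\n",
"X_demo = rng.rand(300, 1)\n",
"y_demo = np.sin(6 * X_demo[:, 0]) + 0.3 * rng.randn(300)\n",
"X_tr, X_val, y_tr, y_val = train_test_split(\n",
"    X_demo, y_demo, test_size=0.4, random_state=1)\n",
"\n",
"# Deeper trees fit the training split more closely;\n",
"# compare the two error columns as max_depth grows.\n",
"for depth in (1, 3, 5, None):\n",
"    tree = DecisionTreeRegressor(max_depth=depth, random_state=1)\n",
"    tree.fit(X_tr, y_tr)\n",
"    print(depth,\n",
"          mean_absolute_error(y_tr, tree.predict(X_tr)),\n",
"          mean_absolute_error(y_val, tree.predict(X_val)))\n",
"```\n",
"\n",
"A training error that keeps shrinking while the validation error stops improving is the typical sign of overfitting."
]
},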
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Grid search and cross-validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try grid search and cross-validation."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV, KFold"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=KFold(n_splits=5, random_state=1, shuffle=True),\n",
" error_score='raise',\n",
" estimator=SVR(C=1.0, cache_size=200, coef0=0.0, degree=3, epsilon=0.1, gamma='auto',\n",
" kernel='rbf', max_iter=-1, shrinking=True, tol=0.001, verbose=False),\n",
" fit_params=None, iid=True, n_jobs=4,\n",
" param_grid={'C': [30, 100, 300], 'gamma': [0.1, 0.3, 1]},\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
" scoring='neg_mean_absolute_error', verbose=0)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.svm import SVR\n",
"p_grid = {'C': [30, 100, 300], 'gamma': [0.1, 0.3, 1]}\n",
"svr = SVR(kernel=\"rbf\")\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=svr, param_grid=p_grid, cv=cv, n_jobs=4,\n",
"                    scoring='neg_mean_absolute_error')\n",
"regr.fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check the cross-validation error and the hyperparameters that gave the best score.  \n",
"If a selected value lies at the edge of the grid search range,  \n",
"change the range and run the cross-validation again."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cross-validation error:  1.4613603393492052\n",
"Best hyperparameters:  {'C': 100, 'gamma': 0.3}\n"
]
}
],
"source": [
"print('Cross-validation error: ', -regr.best_score_)\n",
"print('Best hyperparameters: ', regr.best_params_)"
]
},
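{
"cell_type": "markdown",
"metadata": {},
"source": [
"The edge check described above can be automated. The helper below is a hypothetical sketch (`params_on_grid_edge` is not a scikit-learn function); it assumes every grid entry is a list of numeric candidates.\n",
"\n",
"```python\n",
"def params_on_grid_edge(best_params, param_grid):\n",
"    # Return the names of parameters whose selected value is\n",
"    # the smallest or largest candidate in the grid.\n",
"    on_edge = []\n",
"    for name, candidates in param_grid.items():\n",
"        if len(candidates) > 1 and best_params[name] in (min(candidates), max(candidates)):\n",
"            on_edge.append(name)\n",
"    return on_edge\n",
"\n",
"p_grid = {'C': [30, 100, 300], 'gamma': [0.1, 0.3, 1]}\n",
"print(params_on_grid_edge({'C': 100, 'gamma': 0.3}, p_grid))  # []\n",
"print(params_on_grid_edge({'C': 300, 'gamma': 0.3}, p_grid))  # ['C'] (made-up result)\n",
"```\n",
"\n",
"An empty list, as for the result above, suggests the grid already brackets the best values; otherwise the range for the listed parameters should be extended."
]
},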
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final model is trained on all of the training data (not just the individual cross-validation folds) using the parameters above."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"regr = SVR(kernel=\"rbf\", gamma=0.3, C=100)\n",
"rtn = regr.fit(X_train, y_train)"
]
},
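{
"cell_type": "markdown",
"metadata": {},
"source": [
"When `GridSearchCV` is run with `refit=True` (the setting shown in the repr printed earlier), the search object itself already holds a model retrained with the best parameters on all data passed to `fit`, available as `best_estimator_`. A minimal sketch on synthetic stand-in data (`X_demo`, `y_demo` and the small grid are assumptions for illustration):\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import GridSearchCV, KFold\n",
"from sklearn.svm import SVR\n",
"\n",
"# Synthetic stand-in data (hypothetical, for illustration only)\n",
"rng = np.random.RandomState(1)\n",
"X_demo = rng.rand(100, 2)\n",
"y_demo = X_demo[:, 0] - X_demo[:, 1] + 0.05 * rng.randn(100)\n",
"\n",
"search = GridSearchCV(SVR(kernel='rbf'), {'C': [1, 10], 'gamma': [0.1, 1]},\n",
"                      cv=KFold(n_splits=5, shuffle=True, random_state=1),\n",
"                      scoring='neg_mean_absolute_error')\n",
"search.fit(X_demo, y_demo)\n",
"\n",
"# best_estimator_ is an SVR already refit on all of X_demo with best_params_\n",
"print(search.best_params_)\n",
"print(search.best_estimator_.predict(X_demo[:3]))\n",
"```\n",
"\n",
"Refitting by hand, as in the cell above, gives an equivalent model and makes the chosen parameters explicit."
]
},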
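{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's look at the error on the test data.  \n",
"At the risk of repeating myself, this is a value that you normally cannot know until the model is actually in production."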
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MAE on test data:  1.4961550762118891\n"
]
}
],
"source": [
"y_pred = regr.predict(X_test)\n",
"print('MAE on test data: ', mean_absolute_error(y_test, y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Normalization"
]
},
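{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see that, for some algorithms, the scale of the explanatory variables affects the result  \n",
"(while other algorithms are hardly affected at all),  \n",
"and that normalization can suppress this effect.\n",
"- Example of an algorithm that is sensitive to scale: k-nearest neighbors\n",
"- Example of an algorithm that is insensitive to scale: decision tree"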
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best hyperparameters:  {'n_neighbors': 16}\n",
"Error on the original data:  1.5114770459081837\n"
]
}
],
"source": [
"from sklearn.neighbors import KNeighborsRegressor\n",
"p_grid = {'n_neighbors': [15, 16, 17, 18, 19]}\n",
"knr = KNeighborsRegressor()\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=knr, param_grid=p_grid, cv=cv, n_jobs=4,\n",
"                    scoring='neg_mean_absolute_error')\n",
"regr.fit(X_train, y_train)\n",
"print('Best hyperparameters: ', regr.best_params_)\n",
"print('Error on the original data: ', -regr.best_score_)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best hyperparameters:  {'n_neighbors': 20}\n",
"Error after changing the scale:  1.988622754491018\n"
]
}
],
"source": [
"X_train_tmp = X_train.copy(deep=True)\n",
"X_train_tmp['shucked_weight'] *= 1000\n",
"p_grid = {'n_neighbors': [19, 20, 21, 22, 23]}\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=knr, param_grid=p_grid, cv=cv, n_jobs=4,\n",
"                    scoring='neg_mean_absolute_error')\n",
"regr.fit(X_train_tmp, y_train)\n",
"print('Best hyperparameters: ', regr.best_params_)\n",
"print('Error after changing the scale: ', -regr.best_score_)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best hyperparameters:  {'n_neighbors': 15}\n",
"Error after standardization:  1.535249500998004\n"
]
}
],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"X_train_scale = StandardScaler().fit_transform(X_train_tmp)\n",
"p_grid = {'n_neighbors': [14, 15, 16, 17, 18]}\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=knr, param_grid=p_grid, cv=cv, n_jobs=4,\n",
"                    scoring='neg_mean_absolute_error')\n",
"regr.fit(X_train_scale, y_train)\n",
"print('Best hyperparameters: ', regr.best_params_)\n",
"print('Error after standardization: ', -regr.best_score_)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best hyperparameters:  {'max_depth': 6}\n",
"Error on the original data:  1.6090607846224858\n"
]
}
],
"source": [
"from sklearn.tree import DecisionTreeRegressor\n",
"p_grid = {'max_depth': [3, 4, 5, 6, 7]}\n",
"dtr = DecisionTreeRegressor(random_state=1)\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=dtr, param_grid=p_grid, cv=cv, n_jobs=4,\n",
"                    scoring='neg_mean_absolute_error')\n",
"regr.fit(X_train, y_train)\n",
"print('Best hyperparameters: ', regr.best_params_)\n",
"print('Error on the original data: ', -regr.best_score_)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best hyperparameters:  {'max_depth': 6}\n",
"Error after changing the scale:  1.6092798830866515\n"
]
}
],
"source": [
"X_train_tmp = X_train.copy(deep=True)\n",
"X_train_tmp['shucked_weight'] *= 1000\n",
"p_grid = {'max_depth': [3, 4, 5, 6, 7]}\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=dtr, param_grid=p_grid, cv=cv, n_jobs=4,\n",
"                    scoring='neg_mean_absolute_error')\n",
"regr.fit(X_train_tmp, y_train)\n",
"print('Best hyperparameters: ', regr.best_params_)\n",
"print('Error after changing the scale: ', -regr.best_score_)"
]
},
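{
"cell_type": "markdown",
"metadata": {},
"source": [
"When a scaler is fit on the whole training set before cross-validation, each validation fold has already influenced the scaling statistics. One common way to avoid this is to put the scaler and the estimator in a `Pipeline`, so that scaling is refit inside each fold. A minimal sketch on synthetic stand-in data (`X_demo`, `y_demo` and the step names `scale` and `knr` are assumptions chosen for illustration):\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import GridSearchCV, KFold\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Synthetic stand-in data (hypothetical, for illustration only)\n",
"rng = np.random.RandomState(1)\n",
"X_demo = rng.rand(200, 3)\n",
"X_demo[:, 2] *= 1000  # one feature on a much larger scale\n",
"y_demo = X_demo[:, 0] + 0.1 * rng.randn(200)\n",
"\n",
"# The scaler is fit only on the training part of each fold\n",
"pipe = Pipeline([('scale', StandardScaler()), ('knr', KNeighborsRegressor())])\n",
"p_grid = {'knr__n_neighbors': [3, 5, 10, 15]}\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=pipe, param_grid=p_grid, cv=cv,\n",
"                    scoring='neg_mean_absolute_error')\n",
"regr.fit(X_demo, y_demo)\n",
"print(regr.best_params_, -regr.best_score_)\n",
"```\n",
"\n",
"Parameters of a pipeline step are addressed as `<step name>__<parameter>`, which is why the grid key is `knr__n_neighbors`."
]
},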
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Bias and variance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's plot learning curves."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import learning_curve\n",
"def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None, n_jobs=1):\n",
"    plt.figure()\n",
"    plt.title(title)\n",
"    if ylim is not None:\n",
"        plt.ylim(*ylim)\n",
"    plt.xlabel(\"Training examples\")\n",
"    plt.ylabel(\"Score\")\n",
"    train_sizes, train_scores, test_scores = learning_curve(\n",
"        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=np.linspace(.1, 1.0, 5),\n",
"        scoring='neg_mean_absolute_error')\n",
"    train_scores_mean = np.mean(train_scores, axis=1)\n",
"    train_scores_std = np.std(train_scores, axis=1)\n",
"    test_scores_mean = np.mean(test_scores, axis=1)\n",
"    test_scores_std = np.std(test_scores, axis=1)\n",
"\n",
"    plt.fill_between(train_sizes, train_scores_mean - train_scores_std,\n",
"                     train_scores_mean + train_scores_std, alpha=0.1, color=\"r\")\n",
"    plt.fill_between(train_sizes, test_scores_mean - test_scores_std,\n",
"                     test_scores_mean + test_scores_std, alpha=0.1, color=\"g\")\n",
"    plt.plot(train_sizes, train_scores_mean, 'o-', color=\"r\", label=\"Training score\")\n",
"    plt.plot(train_sizes, test_scores_mean, 'o-', color=\"g\", label=\"Cross-validation score\")\n",
"    plt.legend(loc=\"best\")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x130d394a8d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"p_grid = {'C': [10, 22, 46, 100, 215, 464],\n",
"          'gamma': [0.1, 0.22, 0.46, 1, 2.15]}\n",
"svr = SVR(kernel=\"rbf\")\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=svr, param_grid=p_grid, cv=cv,\n",
"                    scoring='neg_mean_absolute_error')\n",
"plot_learning_curve(regr, 'Learning Curves (SVM, RBF kernel)',\n",
"                    X_train, y_train, n_jobs=4, ylim=(-1.7,-1.35))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x130d3a05a20>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"regr = LinearRegression()\n",
"plot_learning_curve(regr, 'Learning Curves (Linear regression)',\n",
"                    X_train, y_train, n_jobs=4, ylim=(-1.7,-1.35))"
]
},
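{
"cell_type": "markdown",
"metadata": {},
"source": [
"The support vector machine model appears to have somewhat high variance.  \n",
"However, adding more data does not look likely to reduce the cross-validation error.  \n",
"So let's try reducing the number of explanatory variables instead."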
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cross-validation error:  1.4475913767431021\n",
"Best hyperparameters:  {'C': 10, 'gamma': 3}\n"
]
}
],
"source": [
"X_train_sel = X_train.drop(['length', 'sex_M'], axis=1)\n",
"p_grid = {'C': [1, 3, 10, 30, 100], 'gamma': [0.3, 1, 3, 10, 30]}\n",
"svr = SVR(kernel=\"rbf\")\n",
"cv = KFold(n_splits=5, shuffle=True, random_state=1)\n",
"regr = GridSearchCV(estimator=svr, param_grid=p_grid, cv=cv, n_jobs=4,\n",
"                    scoring='neg_mean_absolute_error')\n",
"regr.fit(X_train_sel, y_train)\n",
"print('Cross-validation error: ', -regr.best_score_)\n",
"print('Best hyperparameters: ', regr.best_params_)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MAE on test data:  1.4836160865447574\n"
]
}
],
"source": [
"regr = SVR(kernel=\"rbf\", gamma=3, C=10)\n",
"rtn = regr.fit(X_train_sel, y_train)\n",
"X_test_sel = X_test.drop(['length', 'sex_M'], axis=1)\n",
"y_pred = regr.predict(X_test_sel)\n",
"print('MAE on test data: ', mean_absolute_error(y_test, y_pred))"
]
},
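{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before feature selection, the cross-validation error was 1.461  \n",
"and the error on the test data was 1.496.  \n",
"After dropping `length` and `sex_M`, they are 1.448 and 1.484, so both improved slightly."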
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}