@yamasakih
Last active December 16, 2018 13:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LogS QSPR with TPOT"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from tpot import TPOTRegressor\n",
"\n",
"from rdkit import Chem\n",
"from rdkit.Chem import Descriptors\n",
"from rdkit.ML.Descriptors import MoleculeDescriptors\n",
"\n",
"from sklearn.metrics import mean_absolute_error\n",
"from sklearn.model_selection import cross_validate, KFold, train_test_split\n",
"\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Load dataset. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load the dataset of 1,290 aqueous solubility values. The file was downloaded from [here](https://datachemeng.com/pythonassignment/). \n",
"Reference \n",
"`T. J. Hou, K. Xia, W. Zhang and X. J. Xu, ADME Evaluation in Drug Discovery. 4. Prediction of Aqueous Solubility Based on Atom Contribution Approach, J. Chem. Inf. Comput. Sci., 44(1), 266–275, 2004.`"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RDKit successfully read 1290 compounds in total.\n"
]
}
],
"source": [
"supp = Chem.SDMolSupplier('logSdataset1290_2d.sdf')\n",
"mols = [mol for mol in supp if mol]\n",
"print(f'RDKit successfully read {len(mols)} compounds in total.')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Make y. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'CAS_Number': '60-35-5', 'logS': 1.58}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mol = mols[0]\n",
"mol.GetPropsAsDict()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The target variable logS is stored in a Prop named logS."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.58'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mol.GetProp('logS')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the value is a string. Extract logS from every Mol object, convert it to a number, and store the result in the variable y as a numpy.array."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1.58, 1.34, 1.22, 1.15, 1.12])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = np.array([float(mol.GetProp('logS')) for mol in mols])\n",
"y[:5]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 2., 4., 27., 41., 81., 204., 309., 361., 191., 70.]),\n",
" array([-11.62, -10.3 , -8.98, -7.66, -6.34, -5.02, -3.7 , -2.38,\n",
" -1.06, 0.26, 1.58]),\n",
" <a list of 10 Patch objects>)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAEaRJREFUeJzt3X+MZWV9x/H3pwuiqVagO9B1d+2gXVvRxoWMK4ltqmD5ZeNiWgz8IcSSrFpspNFWwCZqUhKoP2htWpK1UKGhIhUoG8FWRKi1CeCAy8KyUlZBWXfLjj9ACBGz8O0f92ydLrNz7/y8s0/fr+TmnvOc55zzHZj7uWee+9yzqSokSe36hWEXIElaWAa9JDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEHDbsAgOXLl9fo6Oiwy5CkA8rdd9/9g6oa6ddvSQT96Ogo4+Pjwy5Dkg4oSb47SD+HbiSpcQa9JDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXFL4puxkpaO0fNvGtq5H7n4rUM7d8u8opekxvUN+iQvTHJXknuTbE3ysa79s0keTrK5e6zt2pPk00m2J9mS5NiF/iEkSfs3yNDNM8DxVfVUkoOBryf5UrftT6vqC/v0PwVY0z3eAFzWPUuShqDvFX31PNWtHtw9appd1gNXdfvdARyaZMXcS5UkzcZAY/RJliXZDOwGbqmqO7tNF3XDM5cmOaRrWwk8Omn3HV2bJGkIBgr6qnq2qtYCq4B1SV4LXAD8BvB64HDgQ133THWIfRuSbEgynmR8YmJiVsVLkvqb0aybqnocuB04uap2dcMzzwD/AKzruu0AVk/abRWwc4pjbayqsaoaGxnp+w+kSJJmaZBZNyNJDu2WXwS8BfjW3nH3JAFOA+7vdtkEnNXNvjkOeKKqdi1I9ZKkvgaZdbMCuDLJMnpvDNdW1ReTfDXJCL2hms3Ae7r+NwOnAtuBp4F3zX/ZkqRB9Q36qtoCHDNF+/H76V/AuXMvTZI0H/xmrCQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktS4Qf5xcElDMHr+TcMuQY3oe0Wf5IVJ7kpyb5KtST7WtR+V5M4kDyX5fJIXdO2HdOvbu+2jC/sjSJKmM8jQzTPA8VX1OmAtcHKS44BLgEurag3wY+Ccrv85wI+r6teAS7t+kqQh6Rv01fNUt3pw9yjgeOALXfuVwGnd8vpunW77CUkybxVLkmZkoA9jkyxLshnYDdwCfBt4vKr2dF12ACu75ZXAowDd9ieAX57PoiVJgxso6Kvq2apaC6wC1gGvnqpb9zzV1Xvt25BkQ5LxJOMTExOD1itJmqEZTa+sqseB24HjgEOT7J21swrY2S3vAFYDdNtfCvxoimNtrKqxqhobGRmZXfWSpL4GmXUzkuTQbvlFwFuAbcBtwB903c4GbuyWN3XrdNu/WlXPu6KXJC2OQebRrwCuTLKM3hvDtVX1xSQPANck+Qvgm8DlXf/LgX9Msp3elfwZC1C3JGlAfYO+qrYAx0zR/h164/X7tv8UOH1eqpMkzZm3QJCkxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIaZ9BLUuP6Bn2S1UluS7ItydYk7+/aP5rk+0k2d49TJ+1zQZLtSR5MctJC/gCSpOkdNECfPcAHquqeJC8B7k5yS7ft0qr6xOTOSY4GzgBeA7wM+EqSV1XVs/NZuCRpMH2v6KtqV1Xd0y0/CWwDVk6zy3rgmqp6pqoeBrYD6+ajWEnSzM1ojD7JKHAMcGfX9L4kW5
JckeSwrm0l8Oik3XYwxRtDkg1JxpOMT0xMzLhwSdJgBg76JC8GrgPOq6qfAJcBrwTWAruAT+7tOsXu9byGqo1VNVZVYyMjIzMuXJI0mIGCPsnB9EL+6qq6HqCqHquqZ6vqOeAz/Hx4ZgewetLuq4Cd81eyJGkmBpl1E+ByYFtVfWpS+4pJ3d4O3N8tbwLOSHJIkqOANcBd81eyJGkmBpl180bgncB9STZ3bRcCZyZZS29Y5hHg3QBVtTXJtcAD9GbsnOuMG0kanr5BX1VfZ+px95un2eci4KI51CVJmid+M1aSGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqXN+gT7I6yW1JtiXZmuT9XfvhSW5J8lD3fFjXniSfTrI9yZYkxy70DyFJ2r9Bruj3AB+oqlcDxwHnJjkaOB+4tarWALd26wCnAGu6xwbgsnmvWpI0sIP6daiqXcCubvnJJNuAlcB64E1dtyuB24EPde1XVVUBdyQ5NMmK7jjSAWf0/JuGXYI0JzMao08yChwD3AkcuTe8u+cjum4rgUcn7baja5MkDcHAQZ/kxcB1wHlV9ZPpuk7RVlMcb0OS8STjExMTg5YhSZqhgYI+ycH0Qv7qqrq+a34syYpu+wpgd9e+A1g9afdVwM59j1lVG6tqrKrGRkZGZlu/JKmPQWbdBLgc2FZVn5q0aRNwdrd8NnDjpPazutk3xwFPOD4vScPT98NY4I3AO4H7kmzu2i4ELgauTXIO8D3g9G7bzcCpwHbgaeBd81qxJGlGBpl183WmHncHOGGK/gWcO8e6JEnzZJAreklaFMOayvrIxW8dynkXi7dAkKTGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUOINekhrXN+iTXJFkd5L7J7V9NMn3k2zuHqdO2nZBku1JHkxy0kIVLkkazCBX9J8FTp6i/dKqWts9bgZIcjRwBvCabp+/S7JsvoqVJM1c36Cvqq8BPxrweOuBa6rqmap6GNgOrJtDfZKkOZrLGP37kmzphnYO69pWAo9O6rOja5MkDclsg/4y4JXAWmAX8MmuPVP0rakOkGRDkvEk4xMTE7MsQ5LUz6yCvqoeq6pnq+o54DP8fHhmB7B6UtdVwM79HGNjVY1V1djIyMhsypAkDWBWQZ9kxaTVtwN7Z+RsAs5IckiSo4A1wF1zK1GSNBcH9euQ5HPAm4DlSXYAHwHelGQtvWGZR4B3A1TV1iTXAg8Ae4Bzq+rZhSldkjSIvkFfVWdO0Xz5NP0vAi6aS1GSpPnjN2MlqXEGvSQ1zqCXpMYZ9JLUOINekhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGGfSS1DiDXpIaZ9BLUuMMeklqnEEvSY0z6CWpcQa9JDXOoJekxhn0ktS4vkGf5Ioku5PcP6nt8CS3JHmoez6sa0+STyfZnmRLkmMXsnhJUn+DXNF/Fjh5n7bzgVurag1wa7cOcAqwpntsAC6bnzIlSbPVN+ir6mvAj/ZpXg9c2S1fCZw2qf2q6rkDODTJivkqVpI0c7Mdoz+yqnYBdM9HdO0rgUcn9dvRtT1Pkg1JxpOMT0xMzLIMSVI/8/1hbKZoq6k6VtXGqhqrqrGRkZF5LkOStNdsg/6xvUMy3fPurn0HsHpSv1XAztmXJ0maq9kG/Sbg7G75bODGSe1ndbNvjgOe2DvEI0kajoP6dUjyOeBNwPIkO4CPABcD1yY5B/gecHrX/WbgVGA78DTwrgWoWZI0A32DvqrO3M+mE6boW8C5cy1KkjR//GasJDXOoJekxhn0ktQ4g16SGmfQS1Lj+s66kZaC0fNvGnYJ0gHLK3pJapxBL0mNM+glqX
EGvSQ1zqCXpMYZ9JLUOKdXSvp/b5jTdx+5+K0Lfg6v6CWpcQa9JDXOoJekxhn0ktQ4g16SGmfQS1Lj5jS9MskjwJPAs8CeqhpLcjjweWAUeAR4R1X9eG5lSpJmaz6u6N9cVWuraqxbPx+4tarWALd265KkIVmIoZv1wJXd8pXAaQtwDknSgOYa9AV8OcndSTZ0bUdW1S6A7vmIOZ5DkjQHc70FwhurameSI4Bbknxr0B27N4YNAC9/+cvnWIYkaX/mdEVfVTu7593ADcA64LEkKwC659372XdjVY1V1djIyMhcypAkTWPWQZ/kF5O8ZO8ycCJwP7AJOLvrdjZw41yLlCTN3lyGbo4Ebkiy9zj/VFX/muQbwLVJzgG+B5w+9zIlSbM166Cvqu8Ar5ui/YfACXMpSpI0f7wfvWZkmPftljQ73gJBkhpn0EtS4wx6SWqcQS9JjTPoJalxBr0kNc6gl6TGOY/+AORcdkkz4RW9JDXOoJekxhn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ1zqCXpMYZ9JLUuAUL+iQnJ3kwyfYk5y/UeSRJ01uQoE+yDPhb4BTgaODMJEcvxLkkSdNbqLtXrgO2V9V3AJJcA6wHHlig8w2Fd5GUdCBYqKBfCTw6aX0H8IaFOJFhK0nTW6igzxRt9X86JBuADd3qU0ke3M+xlgM/mMfaFot1Ly7rXlzWPU9yyUDd9lf3rw6y80IF/Q5g9aT1VcDOyR2qaiOwsd+BkoxX1dj8lrfwrHtxWffisu7FNde6F2rWzTeANUmOSvIC4Axg0wKdS5I0jQW5oq+qPUneB/wbsAy4oqq2LsS5JEnTW7B/M7aqbgZunodD9R3eWaKse3FZ9+Ky7sU1p7pTVf17SZIOWN4CQZIatySDPsnpSbYmeS7J2KT2301yd5L7uufjh1nnvvZXd7ftgu52EA8mOWlYNQ4iydokdyTZnGQ8ybph1zSoJH/c/TfemuQvh13PTCT5YJJKsnzYtQwiyceTfCvJliQ3JDl02DVN50C8LUuS1UluS7Kt+51+/6wOVFVL7gG8Gvh14HZgbFL7McDLuuXXAt8fdq0D1n00cC9wCHAU8G1g2bDrnebn+DJwSrd8KnD7sGsasO43A18BDunWjxh2TTOofTW9yQvfBZYPu54Baz4ROKhbvgS4ZNg1TVPrsu519wrgBd3r8ehh1zVA3SuAY7vllwD/NZu6l+QVfVVtq6rnfYGqqr5ZVXvn428FXpjkkMWtbv/2Vze92z9cU1XPVNXDwHZ6t4lYqgr4pW75pezzHYgl7L3AxVX1DEBV7R5yPTNxKfBn7PPFwqWsqr5cVXu61TvofV9mqfrf27JU1c+AvbdlWdKqaldV3dMtPwlso3fngRlZkkE/oN8Hvrn3Rb3ETXVLiBn/z1pE5wEfT/Io8AnggiHXM6hXAb+d5M4k/57k9cMuaBBJ3kbvr9N7h13LHPwh8KVhFzGNA+01+DxJRumNatw5030XbHplP0m+AvzKFJs+XFU39tn3NfT+VDxxIWrrc+7Z1N33lhCLbbqfAzgB+JOqui7JO4DLgbcsZn3706fug4DDgOOA1wPXJnlFdX/3DlOfui9kCL/Lgxjk9z3Jh4E9wNWLWdsMLbnX4EwkeTFwHXBeVf1kpvsPLeiralbBkWQVcANwVlV9e36r6m+Wdfe9JcRim+7nSHIVsPdDn38G/n5RihpAn7rfC1zfBftdSZ6jd4+QicWqb3/2V3eS36T3uc29SaD3u3FPknVV9d+LWOKU+v2+Jzkb+D3ghKXwhjqNJfcaHFSSg+mF/NVVdf1sjnFADd10n+rfBFxQVf857HpmYBNwRpJDkhwFrAHuGnJN09kJ/E63fDzw0BBrmYl/oVcvSV5F70O3JXUDq31V1X1VdURVjVbVKL1AOnYphHw/SU4GPgS8raqeHnY9fRyQt2VJ793/cmBbVX1q1sdZim/CSd4O/A0wAjwObK6qk5L8Ob
3x4snBc+JS+dBtf3V32z5MbxxzD70/v5bseGaS3wL+mt5ffD8F/qiq7h5uVf11L+ArgLXAz4APVtVXh1vVzCR5hN6MrSX9BgWQZDu9mWQ/7JruqKr3DLGkaSU5Ffgrfn5blouGXFJf3WvxP4D7gOe65gurd+eBwY+zFINekjR/DqihG0nSzBn0ktQ4g16SGmfQS1LjDHpJapxBL0mNM+glqXEGvSQ17n8Av1+rG0gccOgAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.hist(y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The distribution is slightly skewed but roughly normal."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Make X. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the descriptors that RDKit can calculate."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['MaxEStateIndex',\n",
" 'MinEStateIndex',\n",
" 'MaxAbsEStateIndex',\n",
" 'MinAbsEStateIndex',\n",
" 'qed']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"descriptor_names = [d[0] for d in Descriptors._descList]\n",
"descriptor_names[:5]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 200 descriptors in total\n"
]
}
],
"source": [
"print(f'There are {len(descriptor_names)} descriptors in total')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"calc = MoleculeDescriptors.MolecularDescriptorCalculator(descriptor_names)\n",
"descs = [calc.CalcDescriptors(mol) for mol in mols]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the values easier for a human to read, add the column names and build a DataFrame."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MaxEStateIndex</th>\n",
" <th>MinEStateIndex</th>\n",
" <th>MaxAbsEStateIndex</th>\n",
" <th>MinAbsEStateIndex</th>\n",
" <th>qed</th>\n",
" <th>MolWt</th>\n",
" <th>HeavyAtomMolWt</th>\n",
" <th>ExactMolWt</th>\n",
" <th>NumValenceElectrons</th>\n",
" <th>NumRadicalElectrons</th>\n",
" <th>...</th>\n",
" <th>fr_sulfide</th>\n",
" <th>fr_sulfonamd</th>\n",
" <th>fr_sulfone</th>\n",
" <th>fr_term_acetylene</th>\n",
" <th>fr_tetrazole</th>\n",
" <th>fr_thiazole</th>\n",
" <th>fr_thiocyan</th>\n",
" <th>fr_thiophene</th>\n",
" <th>fr_unbrch_alkane</th>\n",
" <th>fr_urea</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>9.222222</td>\n",
" <td>-0.333333</td>\n",
" <td>9.222222</td>\n",
" <td>0.333333</td>\n",
" <td>0.401031</td>\n",
" <td>59.068</td>\n",
" <td>54.028</td>\n",
" <td>59.037114</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.597222</td>\n",
" <td>1.652778</td>\n",
" <td>4.597222</td>\n",
" <td>1.652778</td>\n",
" <td>0.273315</td>\n",
" <td>46.073</td>\n",
" <td>40.025</td>\n",
" <td>46.053098</td>\n",
" <td>20</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>9.000000</td>\n",
" <td>-0.833333</td>\n",
" <td>9.000000</td>\n",
" <td>0.833333</td>\n",
" <td>0.429883</td>\n",
" <td>60.052</td>\n",
" <td>56.020</td>\n",
" <td>60.021129</td>\n",
" <td>24</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3.222222</td>\n",
" <td>1.250000</td>\n",
" <td>3.222222</td>\n",
" <td>1.250000</td>\n",
" <td>0.434794</td>\n",
" <td>71.123</td>\n",
" <td>62.051</td>\n",
" <td>71.073499</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>9.229167</td>\n",
" <td>-0.939815</td>\n",
" <td>9.229167</td>\n",
" <td>0.939815</td>\n",
" <td>0.256644</td>\n",
" <td>76.055</td>\n",
" <td>72.023</td>\n",
" <td>76.027277</td>\n",
" <td>30</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 200 columns</p>\n",
"</div>"
],
"text/plain": [
" MaxEStateIndex MinEStateIndex MaxAbsEStateIndex MinAbsEStateIndex \\\n",
"0 9.222222 -0.333333 9.222222 0.333333 \n",
"1 4.597222 1.652778 4.597222 1.652778 \n",
"2 9.000000 -0.833333 9.000000 0.833333 \n",
"3 3.222222 1.250000 3.222222 1.250000 \n",
"4 9.229167 -0.939815 9.229167 0.939815 \n",
"\n",
" qed MolWt HeavyAtomMolWt ExactMolWt NumValenceElectrons \\\n",
"0 0.401031 59.068 54.028 59.037114 24 \n",
"1 0.273315 46.073 40.025 46.053098 20 \n",
"2 0.429883 60.052 56.020 60.021129 24 \n",
"3 0.434794 71.123 62.051 71.073499 30 \n",
"4 0.256644 76.055 72.023 76.027277 30 \n",
"\n",
" NumRadicalElectrons ... fr_sulfide fr_sulfonamd fr_sulfone \\\n",
"0 0 ... 0 0 0 \n",
"1 0 ... 0 0 0 \n",
"2 0 ... 0 0 0 \n",
"3 0 ... 0 0 0 \n",
"4 0 ... 0 0 0 \n",
"\n",
" fr_term_acetylene fr_tetrazole fr_thiazole fr_thiocyan fr_thiophene \\\n",
"0 0 0 0 0 0 \n",
"1 0 0 0 0 0 \n",
"2 0 0 0 0 0 \n",
"3 0 0 0 0 0 \n",
"4 0 0 0 0 0 \n",
"\n",
" fr_unbrch_alkane fr_urea \n",
"0 0 0 \n",
"1 0 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 1 \n",
"\n",
"[5 rows x 200 columns]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"descs_df = pd.DataFrame(descs, columns=descriptor_names)\n",
"descs_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the descriptors will be passed to Automated ML, convert them to a numpy.array, as with y, and store the result in X. Leaving them as a pandas.DataFrame would also work."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 9.22222222, -0.33333333, 9.22222222, 0.33333333, 0.40103092],\n",
" [ 4.59722222, 1.65277778, 4.59722222, 1.65277778, 0.27331504],\n",
" [ 9. , -0.83333333, 9. , 0.83333333, 0.42988288],\n",
" [ 3.22222222, 1.25 , 3.22222222, 1.25 , 0.43479446],\n",
" [ 9.22916667, -0.93981481, 9.22916667, 0.93981481, 0.25664361]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = np.array(descs)\n",
"X[:5, :5]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1290, 200)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm once more, there are 1,290 samples and 200 explanatory variables."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Split dataset to training set and test set."
]
},
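{
"cell_type": "markdown",
"metadata": {},
"source": [
"Of the 1,290 samples, 1,000 form the training set used to build the prediction model and 290 form the test set used to evaluate it. \n",
"To keep the results reproducible, the random seed is fixed here. Passing an integer to the `random_state` argument of `train_test_split` or TPOT fixes the seed."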
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"random_state = 20181216"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"training_X, test_X, training_y, test_y = train_test_split(\n",
" X, y, test_size=290, random_state=random_state)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1000, 200)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training_X.shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(290, 200)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_X.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data were split into 1,000 and 290 samples. Let's look at the histograms as well."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 3., 5., 11., 21., 30., 41., 71., 62., 31., 15.]),\n",
" array([-8.65 , -7.673, -6.696, -5.719, -4.742, -3.765, -2.788, -1.811,\n",
" -0.834, 0.143, 1.12 ]),\n",
" <a list of 10 Patch objects>)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAD8CAYAAAC4lecIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvqOYd8AAAFm9JREFUeJzt3X+wZGWd3/H3Z0HYRHEBmSEjPzKQoJFNSnDvUqTIJsr4A8FyMCUWVEonStVsDG5B4tYyaFW0YqwasqtstrLBGoU4JiwwixCmhLggC6H8A3QGkV8DYcQRhmFnxh+oG6uwBr75o8/FZuzL7Znbffth+v2q6upznj7n9Pf27XM/95zz9NOpKiRJUpt+Y9IFSJKkuRnUkiQ1zKCWJKlhBrUkSQ0zqCVJaphBLUlSwwxqSZIaZlBLktQwg1qSpIYdPOkCAI466qhavnz5pMuQmrd58+YfVtWSSdfxctyfpeEMuz83EdTLly9n06ZNky5Dal6SH0y6hvm4P0vDGXZ/9tS3JEkNM6glSWqYQS1JUsMMakmSGmZQS5LUMINakqSGGdSSJDXMoJYkqWEGtSSSvDHJ/X23nyW5JMmRSW5P8nh3f8Ska5WmTRMjk0nDWr7mlpFsZ9vac0aynQNFVT0GnAKQ5CDgaeAmYA1wR1WtTbKmm790YoVq4twHF59H1JL2tgL4XlX9AFgJrO/a1wPnTqwqaUoZ1JL2dj5wbTd9dFU9A9DdL51YVdKUMqglvSjJIcB7gb/cx/VWJ9mUZNPu3bvHU5w0pQxqSf3eDdxXVTu7+Z1JlgF097sGrVRV66pqpqpmlixp+ls4pVccg1pSvwv41WlvgI3Aqm56FXDzolckTTmDWhIASf4u8A7gxr7mtcA7kjzePbZ2ErVJ08yPZ0kCoKp+Abxur7Yf0esFLmlCPKKWJKlhBrUkSQ0zqCVJaphBLUlSwwxqSZIaZlBLktQwg1qSpIYZ1JIkNWzeoE5yXJI7k2xJ8nCSi7v2Tyd5uu+L5s/uW+eyJFuTPJbkXeP8ASRJOpANMzLZHuDjVXVfksOAzUlu7x67oqr+pH/hJCfT+5q83wZeD3wjyRuq6vlRFi5J0jSY94i6qp6pqvu66Z8DW4BjXmaVlcB1VfVcVX0f2AqcNopiJUmaNvt0jTrJcuBU4N6u6WNJHkhydZIjurZjgKf6VtvOywe7JEmaw9BBneQ1wFeBS6rqZ8CVwD8ATgGeAT43u+iA1WvA9vyieUmS5jFUUCd5Fb2QvqaqbgSoqp1V9XxVvQB8kV+d3t4OHNe3+rHAjr236RfNS5I0v2F6fQe4CthSVZ/va1/Wt9j7gIe66Y3A+UkOTXICcBLwrdGVLEnS9Bim1/cZwAeBB5Pc37V9ArggySn0TmtvA34foKoeTrIBeIRej/GL7PEtSdL+mTeoq+qbDL7ufOvLrPNZ4LMLqEuSJOHIZJIkNW2YU9+SJI3U8jW3jGQ729aeM5LttMwjakmSGmZQSwIgyeFJbkjyaDe2/z9NcmSS25M83t0fMf+WJI2SQS1p1n8Bvl5V/wh4M73hgtcAd1TVScAd3bykRWRQSyLJa4F/Tm/MBKrql1X1LL2x+9d3i60Hzp1MhdL0MqglAZwI7Ab+e5LvJPlSklcDR1fVM9D7gh5g6SSLlKaRQS0Jep8AeQtwZVWdCvw/9uE0t2P3S+NjUEuC3hj926tq9pvxbqAX3Dtnhwvu7ncNWtmx+6Xx8XPUGrtRfV5S41NVf5PkqSRvrKrHgBX0hgF+BFgFrO3ub55gmdJUMqglzfoD4JokhwBPAB+md9ZtQ5ILgSeB8yZYnzSVDGpJAFTV/cDMgIdWLHYtkn7Fa9SSJDXMoJYkqWEGtSRJDTOoJUlqmEEtSVLDDGpJkhpmUEuS1DCDWpKkhhnUkiQ1zKCWJKlhBrUkSQ0zqCVJaphBLUlSwwxqSZIaZlBLktQwg1qSpIYZ1J
IkNWzeoE5yXJI7k2xJ8nCSi7v2I5PcnuTx7v6Irj1J/izJ1iQPJHnLuH8ISZIOVMMcUe8BPl5VbwJOBy5KcjKwBrijqk4C7ujmAd4NnNTdVgNXjrxqSZKmxLxBXVXPVNV93fTPgS3AMcBKYH232Hrg3G56JfCV6rkHODzJspFXLknSFNina9RJlgOnAvcCR1fVM9ALc2Bpt9gxwFN9q23v2vbe1uokm5Js2r17975XLmmkkmxL8mCS+5Ns6toGXuKStHiGDuokrwG+ClxSVT97uUUHtNWvNVStq6qZqppZsmTJsGVIGq+3VdUpVTXTzc91iUvSIhkqqJO8il5IX1NVN3bNO2dPaXf3u7r27cBxfasfC+wYTbmSFtlcl7gkLZJhen0HuArYUlWf73toI7Cqm14F3NzX/qGu9/fpwE9nT5FLaloBtyXZnGR11zbXJS5Ji+TgIZY5A/gg8GCS+7u2TwBrgQ1JLgSeBM7rHrsVOBvYCvwC+PBIK5Y0LmdU1Y4kS4Hbkzw67IpdsK8GOP7448dVnzSV5g3qqvomg687A6wYsHwBFy2wLkmLrKp2dPe7ktwEnEZ3iauqntnrEtfe664D1gHMzMz8Wp8USfvPkckkkeTVSQ6bnQbeCTzE3Je4JC2SYU59SzrwHQ3c1OuSwsHAX1TV15N8m8GXuCQtEoNaElX1BPDmAe0/YsAlLkmLx1PfkiQ1zKCWJKlhBrUkSQ0zqCVJaphBLUlSwwxqSZIaZlBLktQwg1qSpIYZ1JIkNcygliSpYQa1JEkNM6glSWqYQS1JUsMMakmSGmZQS5LUMINakqSGHTzpAtSu5WtumXQJkjT1PKKWJKlhBrUkSQ0zqCVJaphBLelFSQ5K8p0kX+vmT0hyb5LHk1yf5JBJ1yhNGzuTSep3MbAFeG03fzlwRVVdl+QLwIXAlZMqTvvHjqGvbB5RSwIgybHAOcCXuvkAZwI3dIusB86dTHXS9DKoJc36U+CPgBe6+dcBz1bVnm5+O3DMJAqTpplBLYkk7wF2VdXm/uYBi9Yc669OsinJpt27d4+lRmlazRvUSa5OsivJQ31tn07ydJL7u9vZfY9dlmRrkseSvGtchUsaqTOA9ybZBlxH75T3nwKHJ5nty3IssGPQylW1rqpmqmpmyZIli1GvNDWGOaL+MnDWgPYrquqU7nYrQJKTgfOB3+7W+W9JDhpVsZLGo6ouq6pjq2o5vX34r6vqXwF3Au/vFlsF3DyhEqWpNW9QV9XdwI+H3N5K4Lqqeq6qvg9sBU5bQH2SJutS4N8n2UrvmvVVE65HmjoL+XjWx5J8CNgEfLyqfkKvo8k9fcvY+UR6hamqu4C7uukn8J9taaL2N6ivBD5Dr2PJZ4DPAR9hHzufAKsBjj/++P0sQ9o/o/pc6ba154xkO5I0l/3q9V1VO6vq+ap6Afgiv/qPeztwXN+idj6RJGkB9iuokyzrm30fMNsjfCNwfpJDk5wAnAR8a2ElSpI0veY99Z3kWuCtwFFJtgOfAt6a5BR6p7W3Ab8PUFUPJ9kAPALsAS6qqufHU7okSQe+eYO6qi4Y0Dxnz8+q+izw2YUUJUmSehyZTJKkhhnUkiQ1zKCWJKlhBrUkSQ0zqCVJaphBLUlSwwxqSZIaZlBLktQwg1qSpIYZ1JIkNcygliSpYQa1JEkNM6glSWqYQS1JUsMMakmSGmZQS5LUsIMnXYCkyUvym8DdwKH0/i7cUFWfSnICcB1wJHAf8MGq+uXkKpVeavmaW0a2rW1rzxnZtkbJI2pJAM8BZ1bVm4FTgLOSnA5cDlxRVScBPwEunGCN0lQyqCVRPX/bzb6quxVwJnBD174eOHcC5UlTzVPfkgBIchCwGfiHwJ8D3wOerao93SLbgWPmWHc1sBrg+OOPH3+xU2KUp3X1yuURtSQAqur5qjoFOBY4DXjToMXmWHddVc1U1cySJUvGWaY0dQxqSS9RVc8CdwGnA4cnmT3zdiywY1J1SdPKoJ
ZEkiVJDu+m/w7wdmALcCfw/m6xVcDNk6lQml5eo5YEsAxY312n/g1gQ1V9LckjwHVJ/hPwHeCqSRYpTSODWhJV9QBw6oD2J+hdr5Y0IZ76liSpYQa1JEkNM6glSWrYvEGd5Ooku5I81Nd2ZJLbkzze3R/RtSfJnyXZmuSBJG8ZZ/GSJB3ohjmi/jJw1l5ta4A7uvF/7+jmAd4NnNTdVgNXjqZMSZKm07xBXVV3Az/eq3klvXF/4aXj/64EvtKNG3wPvcESlo2qWEmSps3+XqM+uqqeAejul3btxwBP9S0359jAkiRpfqPuTJYBbQPHBk6yOsmmJJt279494jIkSTow7G9Q75w9pd3d7+ratwPH9S0359jADuIvSdL89jeoN9Ib9xdeOv7vRuBDXe/v04Gfzp4ilyRJ+27eIUSTXAu8FTgqyXbgU8BaYEOSC4EngfO6xW8Fzga2Ar8APjyGmiVJmhrzBnVVXTDHQysGLFvARQstSpIk9TgymSRJDTOoJUlqmEEtSVLDDGpJkhpmUEuS1DCDWpKkhhnUkiQ1zKCWJKlhBrUkkhyX5M4kW5I8nOTirv3IJLcneby7P2LStUrTZt6RySRNhT3Ax6vqviSHAZuT3A78a+COqlqbZA2wBrh0gnU2b/maWyZdgg4wHlFLoqqeqar7uumfA1vofZf8SmB9t9h64NzJVChNL4Na0kskWQ6cCtwLHD37DXjd/dI51vH75aUxMaglvSjJa4CvApdU1c+GXc/vl5fGx6CWBECSV9EL6Wuq6saueWeSZd3jy4Bdk6pPmlYGtSSSBLgK2FJVn+97aCOwqpteBdy82LVJ085e35IAzgA+CDyY5P6u7RPAWmBDkguBJ4HzJlSfNLUMaklU1TeBzPHwisWsRdJLeepbkqSGGdSSJDXMoJYkqWEGtSRJDTOoJUlqmEEtSVLDDGpJkhpmUEuS1DCDWpKkhjkymSRJwPI1t4xkO9vWnjOS7cwyqA8wo3qjSZLa4KlvSZIatqAj6iTbgJ8DzwN7qmomyZHA9cByYBvwgar6ycLKlNo0yjMYoz5dJunAMIoj6rdV1SlVNdPNrwHuqKqTgDu6eUmStB/Gcep7JbC+m14PnDuG55AkaSosNKgLuC3J5iSru7ajq+oZgO5+6aAVk6xOsinJpt27dy+wDEmSDkwL7fV9RlXtSLIUuD3Jo8OuWFXrgHUAMzMztcA6JEk6IC3oiLqqdnT3u4CbgNOAnUmWAXT3uxZapCRJ02q/gzrJq5McNjsNvBN4CNgIrOoWWwXcvNAiJUmaVgs59X00cFOS2e38RVV9Pcm3gQ1JLgSeBM5beJmSJE2n/Q7qqnoCePOA9h8BKxZSlKTFleRq4D3Arqr6x12bYyJIDXAIUUkAXwb+K/CVvrbZMRHWJlnTzV86gdoWhcPvqlUOISqJqrob+PFezY6JIDXAoJY0l6HGRJA0Xga1pAVzACNpfAxqSXMZekyEqlpXVTNVNbNkyZJFK1CaBga1pLk4JoLUAHt9N8Iep5qkJNcCbwWOSrId+BSwFsdEkCbOoJZEVV0wx0OOiSBNmKe+JUlqmEEtSVLDDGpJkhpmUEuS1DCDWpKkhhnUkiQ1zI9nLYCffZYkjZtH1JIkNcygliSpYZ76lvSK5eUnTQOPqCVJaphBLUlSwwxqSZIaZlBLktQwg1qSpIbZ61vSorO3tjQ8j6glSWqYQS1JUsMMakmSGmZQS5LUsLEFdZKzkjyWZGuSNeN6Hknj5b4sTdZYen0nOQj4c+AdwHbg20k2VtUj43i+fWWPU2k4re/L0jQY18ezTgO2VtUTAEmuA1YC+71zG67SRIx8X5a0b8Z16vsY4Km++e1dm6RXFvdlacLGdUSdAW31kgWS1cDqbvZvkzw2plqGcRTwwwk+fz9rGeyAryWXD7XY3x/1885j3n0ZFmV/bun338+6htdiTTCGuobcl2HI/XlcQb0dOK5v/lhgR/8CVbUOWDem598nSTZV1c
yk6wBrmYu1TMy8+zKMf39u9TW3ruG1WBO0W1e/cZ36/jZwUpITkhwCnA9sHNNzSRof92VpwsZyRF1Ve5J8DPgr4CDg6qp6eBzPJWl83JelyRvbl3JU1a3ArePa/og1cQq+Yy2DWcuENLIvt/qaW9fwWqwJ2q3rRan6tX4hkiSpEQ4hKklSw6YmqJOcl+ThJC8kmdnrscu64REfS/KuOdY/Icm9SR5Pcn3XsWYUdV2f5P7uti3J/XMsty3Jg91ym0bx3AOe49NJnu6r5+w5lhv7kJJJ/jjJo0keSHJTksPnWG5sr8t8P2eSQ7vf39buvbF8lM+vl0pySpJ7Zn/XSU6bdE2zkvxB9155OMl/nnQ9/ZL8YZJKclQDtQy1Xy9iPa+M4XGraipuwJuANwJ3ATN97ScD3wUOBU4AvgccNGD9DcD53fQXgI+OocbPAf9hjse2AUeN+TX6NPCH8yxzUPcanQgc0r12J4+hlncCB3fTlwOXL+brMszPCfxb4Avd9PnA9eP8/Uz7DbgNeHc3fTZw16Rr6mp5G/AN4NBufumka+qr7Th6HQF/MO6/H0PWM9R+vUi1LMrfslHcpuaIuqq2VNWgQRhWAtdV1XNV9X1gK71hE1+UJMCZwA1d03rg3FHW1z3HB4BrR7ndMXhxSMmq+iUwO6TkSFXVbVW1p5u9h97ndxfTMD/nSnrvBei9N1Z0v0eNRwGv7aZ/iwGf556QjwJrq+o5gKraNeF6+l0B/BEDBqmZhAb2636L8rdsFKYmqF/GMEMkvg54tu8NNo5hFH8P2FlVj8/xeAG3JdncjQI1Lh/rTktdneSIAY9PYkjJjwD/e47HxvW6DPNzvrhM9974Kb33isbjEuCPkzwF/Alw2YTrmfUG4Pe6yx//J8nvTroggCTvBZ6uqu9OupY5vNx+vRheMcPjju3jWZOQ5BvA3xvw0Cer6ua5VhvQtvd/n0MNo7jAui7g5Y+mz6iqHUmWArcnebSq7h62hmFqAa4EPkPvZ/sMvVPxH9l7EwPW3a//1od5XZJ8EtgDXDPHZkbyugwqb0DbSN8X+nXzvD9XAP+uqr6a5APAVcDbG6jrYOAI4HTgd4ENSU6s7vzqBOv6BL1TzYtqRPv1YnjF7L8HVFBX1f7stMMMkfhD4PAkB3dHTgOHUdzfupIcDPxL4HdeZhs7uvtdSW6id9pmnwNp2NcoyReBrw14aKghJUdRS5JVwHuAFXP90RvV6zLAMD/n7DLbu9/hbwE/HsFzT62Xe08k+QpwcTf7l8CXFqUo5q3ro8CN3Xv0W0leoDd+9O5J1ZXkn9Drc/Pd7mrMscB9SU6rqr+ZRE19tc27Xy+Skf0tGzdPffeGQzy/68F7AnAS8K3+Bbo3053A+7umVcBcR+j74+3Ao1W1fdCDSV6d5LDZaXr/JT80wueffZ5lfbPvm+M5FmVIySRnAZcC762qX8yxzDhfl2F+zo303gvQe2/89YT/8BzodgD/ops+E5jrMtFi+1/06iHJG+h1TJrol09U1YNVtbSqllfVcnqh9JZxh/R8htmvF9ErZ3jcSfdmW6wbveDZDjwH7AT+qu+xT9Lr/fcYXa/Srv1W4PXd9In0Anwrvf/mDx1hbV8G/s1eba8Hbu177u92t4fpnUIax2v0P4AHgQfovWGX7V1LN3828H+712xctWyld/3o/u72hb1rGffrMujnBP4jvT8yAL/ZvRe2du+NEyf9Pj+Qb8A/AzZ3v+97gd+ZdE1dXYcA/5PeP4n3AWdOuqYBNW6jjV7fA/frCdYz9r9lo7g5MpkkSQ3z1LckSQ0zqCVJaphBLUlSwwxqSZIaZlBLktQwg1qSpIYZ1JIkNcygliSpYf8ftvmYl/tDWlYAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 576x288 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(8, 4))\n",
"plt.subplot(1, 2, 1)\n",
"plt.hist(training_y)\n",
"plt.subplot(1, 2, 2)\n",
"plt.hist(test_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training set and the test set have nearly the same distribution."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Automated ML with TPOT."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's run Automated ML with TPOT. \n",
"`KFold(n_splits=5, shuffle=True)` is used for cross-validation."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"kfold = KFold(n_splits=5, shuffle=True, random_state=random_state)"
]
},
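{
"cell_type": "markdown",
"metadata": {},
"source": [
"The number of CPUs used for the calculation is set with TPOT's `n_jobs` argument. All of the runs in this notebook were done with `n_jobs=-1`. \n",
"`n_jobs=-1` uses every CPU for the calculation. On a cold winter day the computer warms up nicely, but be aware that you will not be able to do other work on it. "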
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"n_jobs = -1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 5.1 Automated ML."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To start, use `generations=3, population_size=5` to keep the computational cost low and finish quickly. \n",
"TPOT explores `population_size` pipelines side by side, then uses those results to build and explore better pipeline candidates, repeating this for `generations` rounds."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"tpot1 = TPOTRegressor(generations=3, population_size=5, scoring='neg_mean_absolute_error', cv=kfold, \n",
" n_jobs=n_jobs, random_state=random_state, verbosity=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As with machine learning in sklearn, Automated ML is run by passing the explanatory variables and the target variable of the training set to the `fit` method."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "79c2ebff902a49d6b684624077696cc7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Optimization Progress', max=20, style=ProgressStyle(descripti…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 1 - Current best internal CV score: -0.45904943717365965\n",
"Generation 2 - Current best internal CV score: -0.45904943717365965\n",
"Generation 3 - Current best internal CV score: -0.45904943717365965\n",
"\n",
"Best pipeline: ExtraTreesRegressor(input_matrix, bootstrap=False, max_features=0.6500000000000001, min_samples_leaf=5, min_samples_split=14, n_estimators=100)\n"
]
},
{
"data": {
"text/plain": [
"TPOTRegressor(config_dict=None, crossover_rate=0.1,\n",
" cv=KFold(n_splits=5, random_state=20181216, shuffle=True),\n",
" disable_update_check=False, early_stop=None, generations=3,\n",
" max_eval_time_mins=5, max_time_mins=None, memory=None,\n",
" mutation_rate=0.9, n_jobs=-1, offspring_size=None,\n",
" periodic_checkpoint_folder=None, population_size=5,\n",
" random_state=20181216, scoring='neg_mean_absolute_error',\n",
" subsample=1.0, use_dask=False, verbosity=2, warm_start=False)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tpot1.fit(training_X, training_y)"
]
},
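{
"cell_type": "markdown",
"metadata": {},
"source": [
"On my computer the calculation finished in about one minute. \n",
"\n",
"The part that reads `Current best internal CV score: -0.45904943717365965` is the score of the best pipeline in each generation. \n",
"\n",
"The score here is `neg_mean_absolute_error`, so the closer to 0 the better. Because the computational cost was kept low, the predictive performance does not appear to improve, but \n",
"TPOT compares several algorithms internally and reports only the best prediction model, so the CV score is often already reasonably good at Generation 1. \n",
"How the predictive performance improves is examined in detail in 5.4. \n",
"\n",
"Also, if xgboost is not installed, the message `Warning: xgboost.XGBRegressor is not available and will not be used by TPOT.` is displayed."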
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The best pipeline (the workflow of the prediction model) found in this run was \n",
"`ExtraTreesRegressor(input_matrix, bootstrap=False, max_features=0.6500000000000001, min_samples_leaf=5, min_samples_split=14, n_estimators=100)`. \n",
"The result can be written to a file by calling the `export` method as follows."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tpot1.export('tpot1_logS_QSPR.py')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The contents of the exported file are covered in `6. Evaluate prediction pipeline.`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 5.2 Change argument max_time_mins=3"
]
},
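{
"cell_type": "markdown",
"metadata": {},
"source": [
"When a time is given to `max_time_mins`, TPOT searches for the best prediction pipeline until that time is reached. \n",
"Search tools often give little indication of how long they will take, so being able to set a time limit is very convenient. \n",
"Note that `generations` is ignored when `max_time_mins` is specified. (To confirm that it is ignored, `generations=3` is deliberately left in here.)"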
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"tpot2 = TPOTRegressor(generations=3, population_size=5, scoring='neg_mean_absolute_error', cv=kfold, \n",
" n_jobs=n_jobs, max_time_mins=3, random_state=random_state, verbosity=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `max_time_mins=3`, the calculation stops once the generation that is running when the 3-minute mark passes has finished."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d3866c165f6b438e96b5558f1917f5b9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Optimization Progress', max=5, style=ProgressStyle(descriptio…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 1 - Current best internal CV score: -0.45904943717365965\n",
"Generation 2 - Current best internal CV score: -0.45904943717365965\n",
"Generation 3 - Current best internal CV score: -0.45904943717365965\n",
"Generation 4 - Current best internal CV score: -0.4580808516874792\n",
"Generation 5 - Current best internal CV score: -0.43517099826728833\n",
"Generation 6 - Current best internal CV score: -0.4348320502269952\n",
"Generation 7 - Current best internal CV score: -0.4315345048520923\n",
"Generation 8 - Current best internal CV score: -0.4315345048520923\n",
"\n",
"3.4183998 minutes have elapsed. TPOT will close down.\n",
"TPOT closed prematurely. Will use the current best pipeline.\n",
"\n",
"Best pipeline: ExtraTreesRegressor(GradientBoostingRegressor(input_matrix, alpha=0.99, learning_rate=0.1, loss=lad, max_depth=7, max_features=0.7000000000000001, min_samples_leaf=20, min_samples_split=7, n_estimators=100, subsample=0.9000000000000001), bootstrap=False, max_features=0.7000000000000001, min_samples_leaf=5, min_samples_split=5, n_estimators=100)\n"
]
},
{
"data": {
"text/plain": [
"TPOTRegressor(config_dict=None, crossover_rate=0.1,\n",
" cv=KFold(n_splits=5, random_state=20181216, shuffle=True),\n",
" disable_update_check=False, early_stop=None, generations=1000000,\n",
" max_eval_time_mins=5, max_time_mins=3, memory=None,\n",
" mutation_rate=0.9, n_jobs=-1, offspring_size=None,\n",
" periodic_checkpoint_folder=None, population_size=5,\n",
" random_state=20181216, scoring='neg_mean_absolute_error',\n",
" subsample=1.0, use_dask=False, verbosity=2, warm_start=False)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tpot2.fit(training_X, training_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5分後に計算が終了している。その時の最も良い Pipeline も表示してくれている。 \n",
"また 5.1 よりも計算量を多くしたので、世代を経るごとに CV score の値が0に近づいていきよくなっていっているのがわかる。 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 5.3 強制終了時の挙動"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"次の例のように途中で `Ctrl + C` や Jupyter Notebook の `Interrupt the Kernel` で強制終了したという時も、その時の最も良い Pipeline を表示してくれる。"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"tpot_mistaken = TPOTRegressor(generations=3, population_size=5, scoring='neg_mean_absolute_error', cv=kfold, \n",
" n_jobs=n_jobs, max_time_mins=5, random_state=random_state, verbosity=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"下のセルを実行した後わざと強制終了してみてほしい。"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "91e931217d3748f9b17d84731bb6492b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Optimization Progress', max=5, style=ProgressStyle(descriptio…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 1 - Current best internal CV score: -0.45904943717365965\n",
"Generation 2 - Current best internal CV score: -0.45904943717365965\n",
"Generation 3 - Current best internal CV score: -0.45904943717365965\n",
"Generation 4 - Current best internal CV score: -0.4580808516874792\n",
"Generation 5 - Current best internal CV score: -0.43517099826728833\n",
"Generation 6 - Current best internal CV score: -0.4348320502269952\n",
"Generation 7 - Current best internal CV score: -0.4315345048520923\n",
"\n",
"\n",
"TPOT closed prematurely. Will use the current best pipeline.\n",
"\n",
"Best pipeline: ExtraTreesRegressor(GradientBoostingRegressor(input_matrix, alpha=0.99, learning_rate=0.1, loss=lad, max_depth=7, max_features=0.7000000000000001, min_samples_leaf=20, min_samples_split=7, n_estimators=100, subsample=0.9000000000000001), bootstrap=False, max_features=0.7000000000000001, min_samples_leaf=5, min_samples_split=5, n_estimators=100)\n"
]
},
{
"data": {
"text/plain": [
"TPOTRegressor(config_dict=None, crossover_rate=0.1,\n",
" cv=KFold(n_splits=5, random_state=20181216, shuffle=True),\n",
" disable_update_check=False, early_stop=None, generations=1000000,\n",
" max_eval_time_mins=5, max_time_mins=5, memory=None,\n",
" mutation_rate=0.9, n_jobs=-1, offspring_size=None,\n",
" periodic_checkpoint_folder=None, population_size=5,\n",
" random_state=20181216, scoring='neg_mean_absolute_error',\n",
" subsample=1.0, use_dask=False, verbosity=2, warm_start=False)"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tpot_mistaken.fit(training_X, training_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"強制終了したもののその時点での計算をちゃんと保存しており、この時点での最も良い Pipeline を出力することもできる。"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tpot_mistaken.export('tpot_mistaken_logS_QSPR.py')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 5.4 generation, Population_size を増やした Pipeline の構築"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5.1 よりも計算コストはかかるが `generation`, `population_size` を増やして Pipeline の探索を行ってみる。 \n",
"今回は `generation=10`, `population_size=20` とした。 "
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"tpot3 = TPOTRegressor(generations=10, population_size=20, scoring='neg_mean_absolute_error', cv=kfold, \n",
" n_jobs=n_jobs, random_state=random_state, verbosity=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`max_time_mins=3` としたので3分を超えた時点の Generation が終わると計算を終了する。"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "394c364e77324929a9562ab8ef16d0dd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Optimization Progress', max=220, style=ProgressStyle(descript…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generation 1 - Current best internal CV score: -0.4745942938299195\n",
"Generation 2 - Current best internal CV score: -0.4655966251361625\n",
"Generation 3 - Current best internal CV score: -0.46018596571526854\n",
"Generation 4 - Current best internal CV score: -0.45646592009261855\n",
"Generation 5 - Current best internal CV score: -0.4461230389900289\n",
"Generation 6 - Current best internal CV score: -0.4461230389900289\n",
"Generation 7 - Current best internal CV score: -0.4380682448809523\n",
"Generation 8 - Current best internal CV score: -0.4380682448809523\n",
"Generation 9 - Current best internal CV score: -0.43677564142857134\n",
"Generation 10 - Current best internal CV score: -0.43677564142857134\n",
"\n",
"Best pipeline: ExtraTreesRegressor(ZeroCount(input_matrix), bootstrap=False, max_features=0.7000000000000001, min_samples_leaf=4, min_samples_split=5, n_estimators=100)\n"
]
},
{
"data": {
"text/plain": [
"TPOTRegressor(config_dict=None, crossover_rate=0.1,\n",
" cv=KFold(n_splits=5, random_state=20181216, shuffle=True),\n",
" disable_update_check=False, early_stop=None, generations=10,\n",
" max_eval_time_mins=5, max_time_mins=None, memory=None,\n",
" mutation_rate=0.9, n_jobs=-1, offspring_size=None,\n",
" periodic_checkpoint_folder=None, population_size=20,\n",
" random_state=20181216, scoring='neg_mean_absolute_error',\n",
" subsample=1.0, use_dask=False, verbosity=2, warm_start=False)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tpot3.fit(training_X, training_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"私のパソコンでは15分ほどで計算が終わった。その時の最も良い Pipeline も表示してくれている。 \n",
"また 5.1 よりも計算量を多くしたので、世代を経るごとに CV score の値が0に近づいていきよくなっていっているのがわかる。 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5.1 で求めた Pipeline と比較を行いたいのでこちらも `export` メソッドで結果を書き出しておく。"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tpot3.export('tpot3_logS_QSPR.py')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6. Evaluate prediction pipeline."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.1 すべての Pipeline の確認"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"すべての Pipeline を確認したい場合は `evaluated_individuals_` attribute を参照する。 \n",
"今回はより複雑な Pipeline になっていそうな 5.3 の Pipeline を確認してみよう。"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tpot3 の結果では全部で206個の Pipeline がある。\n"
]
}
],
"source": [
"print(f'tpot3 の結果では全部で{len(tpot3.evaluated_individuals_)}個の Pipeline がある。')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"全部表示すると大変なので最後の一つだけ表示してみる。(最後のが一番良いとは限らないのに注意したい)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"最後の Pipeline は ExtraTreesRegressor(GradientBoostingRegressor(input_matrix, GradientBoostingRegressor__alpha=0.85, GradientBoostingRegressor__learning_rate=0.1, GradientBoostingRegressor__loss=lad, GradientBoostingRegressor__max_depth=9, GradientBoostingRegressor__max_features=0.15000000000000002, GradientBoostingRegressor__min_samples_leaf=4, GradientBoostingRegressor__min_samples_split=16, GradientBoostingRegressor__n_estimators=100, GradientBoostingRegressor__subsample=0.45), ExtraTreesRegressor__bootstrap=True, ExtraTreesRegressor__max_features=0.7000000000000001, ExtraTreesRegressor__min_samples_leaf=6, ExtraTreesRegressor__min_samples_split=8, ExtraTreesRegressor__n_estimators=100)\n"
]
}
],
"source": [
"for key in tpot3.evaluated_individuals_.keys():\n",
" pass\n",
"print(f'最後の Pipeline は {key}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"最後の Pipeline は `GradientBoostingRegressor` で予測を行った後に `ExtraTreesRegressor` で再び予測を行う Pipeline のようだ。 \n",
"このように TPOT では、前処理も含めどのアルゴリズムがよいか、また場合によってはそれを組み合わせの検討を行い Pipeline を探索する。"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'generation': 'INVALID',\n",
" 'mutation_count': 5,\n",
" 'crossover_count': 0,\n",
" 'predecessor': ('ExtraTreesRegressor(SelectPercentile(GradientBoostingRegressor(input_matrix, GradientBoostingRegressor__alpha=0.85, GradientBoostingRegressor__learning_rate=0.1, GradientBoostingRegressor__loss=lad, GradientBoostingRegressor__max_depth=9, GradientBoostingRegressor__max_features=0.15000000000000002, GradientBoostingRegressor__min_samples_leaf=4, GradientBoostingRegressor__min_samples_split=16, GradientBoostingRegressor__n_estimators=100, GradientBoostingRegressor__subsample=0.45), SelectPercentile__percentile=56), ExtraTreesRegressor__bootstrap=True, ExtraTreesRegressor__max_features=0.7000000000000001, ExtraTreesRegressor__min_samples_leaf=6, ExtraTreesRegressor__min_samples_split=8, ExtraTreesRegressor__n_estimators=100)',),\n",
" 'operator_count': 2,\n",
" 'internal_cv_score': -0.4444842302707152}"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tpot3.evaluated_individuals_[key]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"value の方には細かい情報が保存されている。 `operator_count` で Pipeline のシンプルさでフィルターをかけたりも可能だ。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"たまに最も良い Pipeline は GradientBoosting 回帰モデルを3回パラメータを変えて行ったものであると TPOT によって求められる時があり、 \n",
"(あまり人間が介入しすぎると Automated ML とは言えなくなるが) さすがにこれはちょっとなぁという Pipeline の時は他の Pipeline をこのようにチェックしたり \n",
"generation や population_size が小さすぎた場合は大きい値にして再度計算したりすると良いと思う。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.2 Export された file の確認"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5.4 で出力した file の中身を確認してみたい。この Jupyter Notebook と同じフォルダにある tpot3_logS_QSPR.py をテキストエディタや cat コマンドで開いてみる。"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.ensemble import ExtraTreesRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"from tpot.builtins import ZeroCount\n",
"\n",
"# NOTE: Make sure that the class is labeled 'target' in the data file\n",
"tpot_data = pd.read_csv('PATH/TO/DATA/FILE', sep='COLUMN_SEPARATOR', dtype=np.float64)\n",
"features = tpot_data.drop('target', axis=1).values\n",
"training_features, testing_features, training_target, testing_target = \\\n",
" train_test_split(features, tpot_data['target'].values, random_state=20181216)\n",
"\n",
"# Average CV score on the training set was:-0.43677564142857134\n",
"exported_pipeline = make_pipeline(\n",
" ZeroCount(),\n",
" ExtraTreesRegressor(bootstrap=False, max_features=0.7000000000000001, min_samples_leaf=4, min_samples_split=5, n_estimators=100)\n",
")\n",
"\n",
"exported_pipeline.fit(training_features, training_target)\n",
"results = exported_pipeline.predict(testing_features)\n"
]
}
],
"source": [
"!cat tpot3_logS_QSPR.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"上記のように最も良い Pipeline を用いた ML を行うためのスクリプトを書き出してくれており非常に助かる。 \n",
"実際に用いるためには後は `PATH/TO/DATA/FILE`, `COLUMN_SEPARATOR`, `target` を自分のデータセットに合わせて適宜変更するだけで良い。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"最もよい Pipeline は `ZeroCount` で前処理を行った後に `ExtraTreesRegressor` で予測を行う Pipeline のようだ。 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.3 Pipeline の再構築"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"今回はこの Jupyter Notebook で準備したデータセットを使って引き続き Pipeline を用いて予測を行いたいので以下の部分だけ抜き出した。"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np #すでにimportしているがわかりやすさを重視するためにわざと残している\n",
"import pandas as pd #すでにimportしているがわかりやすさを重視するためにわざと残している\n",
"from sklearn.ensemble import ExtraTreesRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"from tpot.builtins import ZeroCount\n",
"\n",
"exported_pipeline_tpot3 = make_pipeline(\n",
" ZeroCount(),\n",
" ExtraTreesRegressor(bootstrap=False, max_features=0.7000000000000001, min_samples_leaf=4, min_samples_split=5, n_estimators=100)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`make_pipeline` 関数を用いることで最も良いとされた Pipeline を再現して予測モデルの構築ができる。 \n",
"Sklearn の ML 同様 `fit` メソッドで予測モデルの構築ができる。"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(memory=None,\n",
" steps=[('zerocount', ZeroCount()), ('extratreesregressor', ExtraTreesRegressor(bootstrap=False, criterion='mse', max_depth=None,\n",
" max_features=0.7000000000000001, max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=4, min_samples_split=5,\n",
" min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None,\n",
" oob_score=False, random_state=None, verbose=0, warm_start=False))])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"exported_pipeline_tpot3.fit(training_X, training_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pipeline が構築できた。続けて Sklearn の ML 同様 predict メソッドでデータセットの予測ができる。"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-4.55851 , -5.30955607, -1.40719952, -0.09583857, -4.0573719 ,\n",
" -3.2350931 , -1.73425476, -1.02698429, -7.44066464, -4.13389024])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predicted_training_y = exported_pipeline_tpot3.predict(training_X)\n",
"predicted_training_y[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.4 5-fold Cross Validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`make_pipeline` 関数で作成した `exported_pipeline` オブジェクトは `cross_validate` 関数の引数 `estimator` に用いることもでき、 \n",
"5-fold Cross Validation などを行うことも簡単にでき、ロバストであるかどうかの検証も行える。"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'fit_time': array([0.61394882, 0.58112717, 0.58521986, 0.59816813, 0.61214423]),\n",
" 'score_time': array([0.00735998, 0.0070889 , 0.00589609, 0.00773478, 0.00749111]),\n",
" 'test_score': array([-0.43199187, -0.45730515, -0.43283408, -0.46029806, -0.42407188]),\n",
" 'train_score': array([-0.19643681, -0.1964946 , -0.20369601, -0.19860522, -0.19826147])}"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kfold = KFold(n_splits=5, shuffle=True, random_state=random_state)\n",
"cross_validate(estimator=exported_pipeline_tpot3, X=training_X, y=training_y, scoring='neg_mean_absolute_error', \n",
" cv=kfold, n_jobs=-1, return_train_score=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"今回は5回ともおおよそ同じ `test_score`, `train_score` でありロバストと言えそうだ。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.5 tpot1, tpot3 で求めた Pipeline の比較"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"6.5.1 tpot1 で求めた Pipeline の構築"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"最後に 5.1, 5.4 で求めた Pipeline の比較を行いたい。計算コストを下げたので 5.1 の方が Pipeline がシンプルで予測性能が低いと考えられるがその通りになるだろうか。 \n",
"6.3 同様に 5.1 の方の Pipeline も構築する。"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.ensemble import ExtraTreesRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# NOTE: Make sure that the class is labeled 'target' in the data file\n",
"tpot_data = pd.read_csv('PATH/TO/DATA/FILE', sep='COLUMN_SEPARATOR', dtype=np.float64)\n",
"features = tpot_data.drop('target', axis=1).values\n",
"training_features, testing_features, training_target, testing_target = \\\n",
" train_test_split(features, tpot_data['target'].values, random_state=20181216)\n",
"\n",
"# Average CV score on the training set was:-0.45904943717365965\n",
"exported_pipeline = ExtraTreesRegressor(bootstrap=False, max_features=0.6500000000000001, min_samples_leaf=5, min_samples_split=14, n_estimators=100)\n",
"\n",
"exported_pipeline.fit(training_features, training_target)\n",
"results = exported_pipeline.predict(testing_features)\n"
]
}
],
"source": [
"!cat tpot1_logS_QSPR.py"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"exported_pipeline_tpot1 = ExtraTreesRegressor(bootstrap=False, max_features=0.6500000000000001, min_samples_leaf=5, min_samples_split=14, n_estimators=100)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ExtraTreesRegressor(bootstrap=False, criterion='mse', max_depth=None,\n",
" max_features=0.6500000000000001, max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=5, min_samples_split=14,\n",
" min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None,\n",
" oob_score=False, random_state=None, verbose=0, warm_start=False)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"exported_pipeline_tpot1.fit(training_X, training_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"tpot1 で求めた Pipeline も構築できた。 tpot1 は `ExtraTressRegressor` を行っているだけというシンプルな Pipeline のようだ。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 6.5.2 tpot1, tpot3 のそれぞれの予測性能を計算する。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"可読性を上げるために X, y , pipeline を引数に与えたら MAE を計算する関数を宣言する。"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"def calculate_MAE(X, y, pipeline):\n",
" predicted_y = pipeline.predict(X)\n",
" return mean_absolute_error(y, predicted_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"tpot1, tpot3 の Training set, Test set に対する MAE を求める。"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"MAE = np.zeros(4, )\n",
"i = 0\n",
"for pipeline in [exported_pipeline_tpot1, exported_pipeline_tpot3]:\n",
" for X, y in [(training_X, training_y), (test_X, test_y)]:\n",
" MAE[i] = calculate_MAE(X, y, pipeline)\n",
" i += 1"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Training</th>\n",
" <th>Test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>tpot1</th>\n",
" <td>0.269629</td>\n",
" <td>0.463421</td>\n",
" </tr>\n",
" <tr>\n",
" <th>tpot3</th>\n",
" <td>0.195962</td>\n",
" <td>0.455865</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Training Test\n",
"tpot1 0.269629 0.463421\n",
"tpot3 0.195962 0.455865"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = pd.DataFrame(MAE.reshape(2, -1), index=['tpot1', 'tpot3'], columns=['Training', 'Test'])\n",
"result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"tpot3 の方が Training set は大幅に, Test set はわずかではあるがどちらもよい予測性能を示した。 \n",
"今回は計算コストをあまり大きくせず、比較的シンプルな結果を優先したためあまり差が出なかったが、 \n",
"generation, population_size を増やせばよりよい Pipeline が探索できてはいたことも伝えておく。ぜひ自分自身で探索してほしい。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"実測値と予測値の散布図もプロットしてみる。こちらも可読性を上げるために実測値と予測値を与えたら散布図をプロットする関数を宣言する。"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"def plot_experimented_and_predicted_value(training_y, predicted_training_y,\n",
" test_y, predicted_test_y, title):\n",
" \n",
" def _plot_experimented_and_predicted_value(experimented_value, predicted_value, title):\n",
" lim = (min(experimented_value.min(), predicted_value.min()),\n",
" max(experimented_value.max(), predicted_value.max()))\n",
" plt.scatter(predicted_value, experimented_value, s=3)\n",
" plt.plot(lim, lim, c='red', alpha=0.5)\n",
" plt.xlim(lim)\n",
" plt.ylim(lim)\n",
" plt.xlabel('Predicted value', fontsize=16)\n",
" plt.ylabel('Experimented value', fontsize=16)\n",
" plt.title(title, fontsize=16)\n",
" \n",
" plt.figure(figsize=(11, 5))\n",
" plt.subplot(1, 2, 1)\n",
" _plot_experimented_and_predicted_value(training_y, predicted_training_y, f'{title}: Training set')\n",
" plt.subplot(1, 2, 2)\n",
" _plot_experimented_and_predicted_value(test_y, predicted_test_y, f'{title}: Test set')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"プロットを行う。"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 792x360 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_experimented_and_predicted_value(training_y, exported_pipeline_tpot1.predict(training_X),\n",
" test_y, exported_pipeline_tpot1.predict(test_X), 'tpot1')"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 792x360 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_experimented_and_predicted_value(training_y, exported_pipeline_tpot3.predict(training_X),\n",
" test_y, exported_pipeline_tpot3.predict(test_X), 'tpot3')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 7. まとめ"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TPOT による Automated ML の一例を示した。 \n",
"今回はチュートリアルということで計算コストを下げて行ったが実際に用いる時は generation, population_size どちらも大きな値にして行うべきだと考えられる。 \n",
"特に population_size の値は探索に大きく関わるので default の 100 などで Automated ML を行うとよいと思う。 \n",
"また、回帰と同様に判別モデルも `TPOTClassifier` を用いることで行うことができる。自動でよい予測性能を示す Pipeline を探してくれるので興味のある方はぜひ使ってみてほしい。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# EOF"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}