Skip to content

Instantly share code, notes, and snippets.

@iwatobipen
Created November 28, 2019 02:25
Show Gist options
  • Save iwatobipen/8d865c617b63c8775245f491af379edf to your computer and use it in GitHub Desktop.
Save iwatobipen/8d865c617b63c8775245f491af379edf to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"RDKit WARNING: [11:24:44] Enabling RDKit 2019.09.1 jupyter extensions\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"from rdkit import Chem\n",
"from rdkit.Chem import DataStructs\n",
"from rdkit.Chem import AllChem\n",
"from rdkit import RDPaths\n",
"from rdkit.Chem.Draw import IPythonConsole\n",
"from rdkit.Chem import Draw\n",
"from rdkit.Chem import PandasTools\n",
"import numpy as np\n",
"import pandas as pd\n",
"from IPython.display import HTML"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Solubility train/test SD files located under the RDKit docs directory (RDPaths.RDDocsDir).\n",
"train_sdf = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.train.sdf')\n",
"test_sdf = os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.test.sdf')\n",
"\n",
"# Load each file into a DataFrame; the ROMol column is used later for fingerprinting.\n",
"traindf = PandasTools.LoadSDF(train_sdf)\n",
"testdf = PandasTools.LoadSDF(test_sdf)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ID</th>\n",
" <th>NAME</th>\n",
" <th>SOL</th>\n",
" <th>SOL_classification</th>\n",
" <th>smiles</th>\n",
" <th>ROMol</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>n-pentane</td>\n",
" <td>n-pentane</td>\n",
" <td>-3.18</td>\n",
" <td>(A) low</td>\n",
" <td>CCCCC</td>\n",
" <td><img data-content=\"rdkit/molecule\" src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAMgAAADICAIAAAAiOjnJAAAABmJLR0QA/wD/AP+gvaeTAAAH5UlEQVR4nO3cSUxTWxzH8UMpKAoiSpwiEMc4RRRFnOIQXSiyhR1blk1cseyWjUlXJCzZEZbVqAkENcZIkIiIsWjUIALiiIqijPctTl4fYRBo++u9vHw/S3N7e6xfLvee/t9LcRzHAInmc3sB+H8iLEgQFiQICxKEBQnCggRhQYKwIEFYkCAsSBAWJAgLEoQFCcKCBGFBgrAgQViQICxIEBYkCAsShAUJwoIEYUGCsCBBWJAgLEgQFiQICxKEBQnCggRhQYKwIEFYkCAsSBAWJAgLEoQFCcKCBGFBgrAgQViQICxIEBYkCAsShAUJwoIEYUGCsCBBWJAgLEgQFiQICxKEBQnCggRhQYKwIEFYkCAsSBAWJAgLEoQFCcKCBGFBgrAgQViQICxIEBYkCAsShAUJwoIEYUGCsCBBWJAgLEgQFiQICxKEBQnCggRhQYKwIEFYkCAsSBAWJDwXVldXV19fXyQScXshy0MkEunr6+vq6nJ7IbM4nvHly5dAIOD3+4uLi/1+f1VV1adPn9xelHcNDQ1VV1enp6cXFxf7fL7KysoPHz64vaj/eCKssbGxUCiUk5NjjPH7/UVFRampqcaY3Nzc2traiYkJtxfoLRMTE7W1tbm5ucaY1NTUoqIiv99vjMnJyQmFQmNjY24v0HG8EFZzc/OBAwfs5fPChQtPnz51HCcSiVy6dMn+4Z49e27duuX2Mr2ipaWlsLDQfjLnzp178uSJ4zgvXry4cuWK/cPdu3ffuHHD7WW6GtbLly/Ly8vtx7Fz587GxsYZB4TD4e3bt9sDysrKXr165co6PaK3t7eystJ+Gnl5efX19TMOaGpq2rdvnz3g4sWLz549c2WdljthDQ8PB4PBFStWGGNWr14dDAb//Pkz55Gjo6OhUGjNmjXGmLS0tEAg8P379ySv1nU/f/4MBoMrV640xqxatSoYDP7+/XvOI+1NRXZ2dvTj+vbtW5JXayU7rMnJyfr6+o0bNxpj7C3n4ODggq8aGBioqqry+XzGmM2bN9fV1U1OTiZhta6bmppqbGzMz883xqSkpJSXl799+3bBV33+/DkQCNj71PXr14dCoeTfpyY1rNbW1pKSEnutPnbs2MOHD5f08vb29lOnTtmXHzly5P79+6J1ekRbW9vJkyft3/fo0aMPHjxY0ssfP3585swZ+/LDhw/fu3dPtM45JSmsd+/eVVZWpqSkGGO2bt1aX18/NTUVw3nsT3BBQUH0J7inpyfhq3Vdf39/9Aq9ZcuWeK7Q4XB427Zt0fvUN2/eJHap85GH9evXr5qamszMTHt/UF1dPTw8HP85g8FgRkZG9J5jZGQkIat1nb2nzMrKMsakp6cHAoEfP37Eec6RkZGamhp7zoyMjOrq6vjPuSBhWNOvLvbHJbFXl0RdBb1DenXp7++PflxxXgUXQxVWe3v76dOn7WdUVFSkux+6e/fuoUOH7BuVlJS0traK3kiqo6Mjej+0d+/e27dvi96ora3txIkT9o2Ki4uXet+2eIkPyz7B2UcS+wSnfiSJ7UnTI6Y/wa1bty4JT3D2N0leXt6SnjSXKpFh2U0Ut/acpu+NZWZm/mVvzCNm7zkNDQ0l7d2n743ZrcT59sZik7CwwuHwjh07XN8ln76bv2vXrtm7+R7R1NS0f/9+13fJF9zNj1kCwopEIpcvX45+r3fz5s34zxmnpqam6d8/dnV1ub2i/8z4Xu/69etur8hpaWk5ePCgXdL58+c7OzvjP2dcYUUHXcy/X62Pj4/Hv6aEGB8fr6ursyMAHhnCiQ66GGPWrl1bU1PjnV/W9j51w4YN0fvUOIdwYgxr9j/bx48f41mHiE1/+q2xK+kn/J9N5
OvXrzPSHx0dje1UsYQ156CLlz1//tzFIZw7d+7MHnTxsu7u7viHcJYW1oKDLl42Ywjn9evX6nfU3RonQZxDOIsNa/GDLl42fQjHfmEi2hBZ/KCLl8UzhLNwWLO3H9+/fx/fgl0mHcKJbdDFy2IbwlkgrNbW1uPHj9vrYQyDLl6mGMJ59OhRPIMuXrbUIZx5wxoYGKioqLAnys/Pb2hoWO5f8c6WwCGcBA66eNbU1FRDQ4O9GBtjKioqBgYG5jt43rAGBwezs7MTNejiZXEO4SgGXbwsOoSTlZUVS1iO44TD4d7eXsHavCi2IZwZgy5JeNL0iN7e3nA4/JcD3P/Pvzxl8UM4HR0dZ8+etUdKB12WKcKaacEhnOQPuixHhDU3+73ejCGcGfs6Xvj+0bMI62+6u7tLS0vt77uCgoLomHVpaWl3d7fbq/O0FMdxDP6qubn56tWrPT09Pp9v06ZN165dKysrc3tRXkdYizI+Pt7Z2WmMKSwsTEtLc3s5ywBhQcJz/+M1/D8QFiQICxKEBQnCggRhQYKwIEFYkCAsSBAWJAgLEoQFCcKCBGFBgrAgQViQICxIEBYkCAsShAUJwoIEYUGCsCBBWJAgLEgQFiQICxKEBQnCggRhQYKwIEFYkCAsSBAWJAgLEoQFCcKCBGFBgrAgQViQICxIEBYkCAsShAUJwoIEYUGCsCBBWJAgLEgQFiQICxKEBQnCggRhQYKwIEFYkCAsSBAWJAgLEoQFCcKCBGFBgrAgQViQICxIEBYkCAsShAUJwoIEYUGCsCBBWJAgLEgQFiQICxKEBQnCggRhQYKwIEFYkCAsSBAWJAgLEv8AerkQZXqa/Q0AAAAASUVORK5CYII=\" alt=\"Mol\"/></td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>cyclopentane</td>\n",
" <td>cyclopentane</td>\n",
" <td>-2.64</td>\n",
" <td>(B) medium</td>\n",
" <td>C1CCCC1</td>\n",
" <td><img data-content=\"rdkit/molecule\" src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAMgAAADICAIAAAAiOjnJAAAABmJLR0QA/wD/AP+gvaeTAAAQ90lEQVR4nO2dWVAUVxuGPzYBQSEBhkF2CPu+RcREEUURBANqFZBoVW68xMtccp2LVHlrqlIpk1SRUkQWcUEkFi4EGTAQdFQSwLDNxj4O68z8F/1n/vkTgZ7pPr1+z6X29PlgXs7b/fXp97hYrVZAELZx5bsARJqgsBAioLAQIqCwECKgsBAioLAQIqCwECKgsBAioLAQIqCwECKgsBAioLDoYjAYDAYD31WIBhTWzqysrFy6dCk9PT0rK+vChQtGo5HvikSAC65u2J7BwcGamhq1Wr1r1y4AWF9fT0pKamhoyMjI4Ls0QYMz1pZYrdZvv/22oKBArVYnJyf39fX99ttvGRkZarV6//79X3/9tcVi4btGAWNF3odery8vL6d+RefPn3/37h317ysrK3V1dS4uLgBw/PjxmZkZfusULCis99DV1RUaGgoAfn5+P//8878PaG5uDggIAIDg4OA7d+5wX6HwQWH9HxsbG/X19W5ubgCQn58/Ojq61ZETExOHDx8GABcXl7q6urW1NS7rFD4orP8xPj5+8OBBAHBzc6uvr9/c3Nz+eIvFcvnyZQ8PDwDIy8sbGRnhpk5RgML6L42NjR988AEAhIeHd3d30/9gb29vTEwMAOzdu/enn34iV6G4QGFZTSZTXV0ddZ1eWVk5Ozvr6BkWFxdrampsV/rLy8sk6hQXchfW8PBwamoqAHh5eV2+fJnJqa5everj4wMACQkJAwMDbFUoUuQrLOoKydPTEwCSk5OHhoaYn1OtVmdmZgKAh4dHfX292Wxmfk6RIlNh6fX6U6dO/btNxRz7RldxcbFsG11yFNaDBw/27dsHAP7+/teuXSMxREtLC9XoUigUt2/fJjGEwJGXsKg2laurKwAUFhZOTk6SG0uj0Zw4cUK2jS4ZCcvWpnJ3d6fTpmKOfaMrNzdXVo0uuQjr+vXrVJsqIiLi0aNHXA7d29sbGxsLAHv27Pnxxx+5HJpHpC8s+zZVVVXV3Nwc9zUsLi7W1tbKqtElcWH19/fHx8cDgLe3N8M2FXNsja7o6Oienh5+iyGNZIVl36ZKSUlhpU3FHLVanZWVJYdGlzSFpdPpysrKqDuyixcvstimYs7q6upXX31F3ZkeO3Zsenqa74qIIEFh2dpUgYGBra2tfJfzfu7du6dUKiXc6JKUsOzbVEeOHCHapmKORqMpKSmRaqNLOsIaGxsrKCjgsk3FHOpCkHpNIzc3982bN3xXxBoSEda1a9f8/f2pNtXjx4/5Lscxnj17Zmt0/fDDD3yXww6iF9bS0tLFixepFtGZM2d4aVMxZ3Fx8fPPP6d+inPnzs3Pz/NdEVPELSyVShUXFyeQNhVzrl696uvrCwBRUVFPnz7luxxGiFVY9lcnKSkpv//+O98VscOrV6+oRhd1pSjeRpcohaXT6UpLS233U6urq3xXxCbr6+u2e9ujR49OTU3xXZEziE9YnZ2dISEhVJuqra2N73JI0dHRQTW6goKC2tvb+S7HYcQkrH+0qUT6p0wfrVYr3kaXaIQ1NjZ24MABCVx8OIT9pWROTo6IGl3iEJatTRUZGSm6NhVz+vr6PvroI6rRdeXKFb7LoYXQhbW0tHT+/HmqwXP27FmRtqmYs7S09MUXX4io0SVoYfX19UmpTcUc+0bXkydP+C5nOwQqLPtri+zs7NevX/NdkVAYHR3Nz88X/rWmEIWl1WpPnjwp1TYVc+wbXUVFRcK8OxacsO7fv0+1qYKCgm7dusV3OcJF4L8oAQlLFH+IgkLIU7tQhCWWSwehIdiLUUEIS0Q3O8JEgLfPPAtLdO0ZwSK0hh+fwhJjQ1ngCOcRBT/CEu8jMOEjkIeqPAjrH/cy4
npoLwrsl4HwdX/NtbDEvsxIRPC7cI07YUljYaS44HGpLUfCksxSbtHB18sBXAhLSi+fiBTuX2ciKyzpvS4nXjh+AZOgsJ49e2ZrU0nmBV+xw1mji4iwJBxJIAG4CblgX1jSDlGRBhzE8rAsLMnHPkkJokFirAlLJkF1EoNc9CE7wrK1qSQfrSk9CIW1siAsW5tKDmHAUoX1eGlGwpJhfLmEYTcQ33lhyXPDBcnD1hYezghLzlvEyAFWNh1yWFgy39RKJjDfJs0xYeE2fLKCycaOdIWFG4fKE6e3oqUlLNzqWM44t3n2zsLCzdkRq9U6PDycmpoKAF5eXnQaXdsJa3FxsaamBttUCIV9o6uysnJ2dnabg7cU1uDgYGRkJAD4+fk1NDQQqBMRJQ0NDX5+ftSKrsHBwa0Oc4UtMBgM8/PzGRkZKpWqurp6q8MQuVFdXT04OJibmzs7O6vVarc6bEthXblyZWlp6cKFC9QqUASxERkZWV1dbTQav/vuu62O2VJYZ86cAYDm5mYipSEihxIGJZL34mK1Wt/7H0ajUaFQrK2tTU5OUu89IgiFVqsNDQ318PDQ6XR79ux57zFbzli+vr7Hjx+3WCwtLS3EKkRESVNTk9lsLikp2UpVsI2w4O+J7saNG+yXhogZShLb+CBsY4UAsLCwEBwcbDabZ2ZmgoKC2C8QESGzs7NKpdLFxUWr1VILbN7LdjOWv7//0aNHzWZzW1sbgQoRUdLc3Ly5uVlcXLyNqmB7YQG6IfIv6PggbG+FQHveQ2TC4uKiQqGgc3W0w4wVEBBw6NChjY2N27dvs1ohIkra2trW19cLCwt3vObeQViAbojYQdMHYUcrBACNRhMWFrZr1y6dTke95oXIE5PJFBQUtLq6OjExQa0s3YadZyylUnngwIGVlZU7d+6wVCEiSm7dumUymQ4ePLijqoCOsADdEAEAR3wQ6FghAExOTkZERPj4+Oh0Om9vb6YFIiJkdXVVoVAYjcbx8fGIiIgdj6c1Y4WFheXm5hqNxo6ODsYVIqLk7t27y8vLeXl5dFQFNIUF6IayxyEfBJpWCACjo6OxsbH+/v5arZaK6kPkw8bGRnBw8Pz8/MjICM2Fn3RnrJiYmIyMjIWFha6uLgYVIqKks7Nzfn4+MzOT/nJiusICdEMZ46gPAn0rBICXL1+mpKQEBgbOzMy4u7s7UyAiQsxmc0hIiF6vf/nyZVJSEs1POTBjJScnJyUlGQyG7u5upypERMnDhw/1en1CQgJ9VYFDwgKAqqoqQDeUGdTXfe7cOYc+5YAVAsDz58+zs7OVSuXU1BSVcYNIG4vFEhYWNjMz8/z5cyq/gyaOiSMrKys2Nlaj0fT09DhYISJKnj59OjMzEx0d7ZCqwFFhAbqhzHDOB8EJYVH3nI2NjQ55KCJGrFbrzZs3wcFGA4XDwvr4448jIiImJiZUKpWjn0XERV9f39u3b8PCwvLy8hz9rMPCcnFx+eyzzwDdUAbY+qJUkqNDOHNnhy14mbBjQMM2ONZuoLDdgg4ODqanpzsxKiJ8BgcHMzMzg4ODp6am3NzcHP24MzOWq6vr6dOnASctSUN9uVVVVU6oCpwTFqAbygAnHjzb44wVAsDm5mZISIjBYHDowSQiFl6/fp2YmBgQEKDRaJxbcODkjOXu7l5eXg4ATU1Nzp0BETLXr18HgNOnTzu9jMX5533ohhKGoQ+C01YIdstV//jjD2obMEQajI2NxcTE+Pn56XQ6p5ehOz9jeXh4lJaWAgDV9UckQ2NjIwCUl5czebmB0dIXdENJwtwHgYkVAsDKykpQUJDJZKL5EiMifKiXk729vfV6/e7du50+D6MZy9vb++TJk1arFVO7JcONGzesVmtZWRkTVQFDYQG6oeRgxQeBoRUCxsFLCzoB7jRhOmP5+voWFxdbLJbW1laGp0J45+bNm2az+cSJEwxVBcyFBeiGEoItHwTmVgh/x8FbLJaZmZnAwEDmNSG8wG6QM
Qszlr+/f1FR0ebmJsbBi5qWlpbNzc1jx46xEo/NzruB6IYSgEUfBFasEAAMBkNISIibm5tWq6W230TExfLyskKh2NjYYGt7G3ZmrMDAwE8//XRtba29vZ2VEyIc09raurq6evjwYbY2TWLtNXl0Q1HDrg8CW1YIABqNJjQ01NPTU6/X+/j4sHJOhBscCnCnCWszFsbBi5f29naTyVRQUMCWqoBFYQG6oWhh3QeBRSsEgImJicjISB8fH71e7+XlxdZpEaKsra0pFIrl5eWxsbHIyEi2TsvmjBUeHp6Tk2M0Gu/fv8/iaRGi3Lt3b2lpKTc3l0VVAbvCAnRDEULCB4FdKwSAkZGR+Ph4jIMXC7Y3Yl69epWQkMDimVmeseLi4tLT0xcWFn755Rd2z4yQ4MGDB/Pz8xkZGeyqClgXFqAbigpCPgisWyEAvHjxIjU1FePghY/ZbN63b59Op3vx4kVycjK7J2d/xkpJSUlMTDQYDI8fP2b95AiLdHd363S6+Ph41lUFJIQFGIArEpwOrqUD+1YIAAMDAzk5ORgHL2QsFkt4ePj09PTAwEBWVhbr5yfyrWdnZ1Nx8L/++iuJ8yPM6enpmZ6ediLAnSakppPKykpANxQw1Fdz9uxZJ4Jr6UBKWNQdbFNTE8bBCxMmwbV0IHKNBQBWqzUyMpKKg8/JySExBOI0KpUqLy8vNDR0YmJCZDMWxsELGSYB7jQheMtm2xyF3BCIczi9kQl9SFkh2MXBDw0NpaWlERoFcZShoaGMjAynA9xpQnDGcnV1raioAHRDgUF9HZWVleRUBUSFBfhAWpCQe/BsD0ErBLs4eLVanZiYSG4ghCZv3rxJSEhgEuBOE7Izlru7+6lTpwDj4AUDFeBeUVFBeuEJ8Qd56IaCghsfBNJWCABra2vBwcGLi4t//vlnTEwM0bGQ7RkfH4+JifH19dXpdKRfoyI+Y3l6epaVlQHGwQsAasPliooKDl7O42JNC7qhQODMB4EDKwQAk8mkUChMJtPbt2/Dw8NJD4e8l6mpqfDwcG9vb51Ox0G4Bhcz1u7du0tKSjAOnl+oAPfS0lJuIls4Wt6Jbsg7XPogcGOFAGA0GoOCgtbX16emppRKJQcjIvZQAe7u7u46nW7v3r0cjMjRjIVx8PzS3NxMBbhzoyrgTFiAbsgrHPsgcGaF8HccvNVq1Wg0H374ITeDIgAwPz+vVCo5/s1zN2P5+/sXFhZubGygG3JMc3Pz+vp6UVERl3/PnL70h27IC9z7IHBphWAXB8/ZvQliC3Cfnp5WKBScjcvpjBUYGPjJJ59gHDyXtLW1ra6uHjp0iEtVAcfCAnRDzuHFB4FjK4S/4+C9vLy4eWIlc6intCsrKywGuNOE6xlLqVTu37/fZDLdvXuX46FlyO3bt9+9e5efn8+xqoB7YQG6IYfw5YPAvRUCt+sY5QwV4L60tDQ6OhodHc3x6DzMWFFRUdnZ2cvLy52dndyPLh86OjqoAHfuVQW8CAvQDTmBRx8EXqwQMA6ePBsbG0qlcm5ujvUAd5rwM2PFxcWlpqYuLCw8fPiQlwIkT1dX19zcXFpaGi+qAr6EBeiGhOHXB4EvKwSA4eHhtLQ0jIMngS3AfXh4OCUlhZcaeJuxUlNTqTj4J0+e8FWDVHn06BEV4M6XqoBHYQEG4BLDFlzLYw28WSEA9Pf35+bmhoaG/vXXXxgHzxa29Nf+/v7s7Gy+yuDz68zJyYmJiZmamurt7eWxDInR09MzMTERFRVFYlsA+vB81VxZWfnNN99cunTpyJEj/FYiGagN/cgFuNPFyisqlaq2tpbPn1+K1NbWqlQqfr9ZPq+xKHQ63ffff89vDRLjyy+/5Hi96L/hX1iIJMF7MYQIKCyECCgshAgoLIQIKCyECCgshAgoLIQIKCyECCgshAgoLIQIKCyECP8BXl82fmLSMUEAAAAASUVORK5CYII=\" alt=\"Mol\"/></td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show the first two training rows as HTML so the ROMol column renders as images.\n",
"preview = traindf.head(2)\n",
"HTML(preview.to_html())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Integer class label for each SOL_classification category.\n",
"cls2lab = {\n",
"    '(A) low': 0,\n",
"    '(B) medium': 1,\n",
"    '(C) high': 2,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def fp2np(fp):\n",
"    \"\"\"Convert an RDKit fingerprint into a NumPy array.\n",
"\n",
"    A zero-length placeholder array is passed to\n",
"    ``DataStructs.ConvertToNumpyArray`` together with ``fp``; that call\n",
"    populates the placeholder in place, and the populated array is returned.\n",
"    \"\"\"\n",
"    vec = np.zeros((0,))\n",
"    DataStructs.ConvertToNumpyArray(fp, vec)\n",
"    return vec"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Morgan fingerprint bit vectors (radius 2) for every molecule in each split.\n",
"morgan_radius = 2\n",
"trainfp = [AllChem.GetMorganFingerprintAsBitVect(mol, morgan_radius) for mol in traindf.ROMol]\n",
"testfp = [AllChem.GetMorganFingerprintAsBitVect(mol, morgan_radius) for mol in testdf.ROMol]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Convert every fingerprint with fp2np and collect the results into one array per split.\n",
"trainX = np.array(list(map(fp2np, trainfp)))\n",
"testX = np.array(list(map(fp2np, testfp)))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Translate each SOL_classification string into its integer label via cls2lab.\n",
"# An unknown class name raises KeyError.\n",
"trainY = np.array([cls2lab[name] for name in traindf['SOL_classification'].to_list()])\n",
"testY = np.array([cls2lab[name] for name in testdf['SOL_classification'].to_list()])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[MLENS] backend: threading\n"
]
}
],
"source": [
"from mlens.ensemble import SuperLearner\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n",
"from sklearn.metrics import r2_score, accuracy_score\n",
"from sklearn.svm import SVR, SVC"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Base model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7198443579766537"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Baseline: a single random forest fitted on the training fingerprints,\n",
"# scored by accuracy on the test split.\n",
"rf = RandomForestClassifier(random_state=794, n_estimators=100).fit(trainX, trainY)\n",
"pred = rf.predict(testX)\n",
"accuracy_score(y_true=testY, y_pred=pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SuperLearner is a stacking model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SuperLearner(array_check=None, backend=None, folds=2,\n",
" layers=[Layer(backend='threading', dtype=<class 'numpy.float32'>, n_jobs=-1,\n",
" name='layer-1', propagate_features=None, raise_on_exception=True,\n",
" random_state=3251, shuffle=False,\n",
" stack=[Group(backend='threading', dtype=<class 'numpy.float32'>,\n",
" indexer=FoldIndex(X=None, folds=2, raise_on_ex...81782f0>)],\n",
" n_jobs=-1, name='group-1', raise_on_exception=True, transformers=[])],\n",
" verbose=1)],\n",
" model_selection=False, n_jobs=None, raise_on_exception=True,\n",
" random_state=794, sample_size=20,\n",
" scorer=<function accuracy_score at 0x7f61481782f0>, shuffle=False,\n",
" verbose=2)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Stacking ensemble: a random forest and an SVC in the first layer,\n",
"# with logistic regression as the meta learner.\n",
"base_learners = [\n",
"    RandomForestClassifier(n_estimators=100, random_state=794),\n",
"    SVC(gamma='auto', C=1000),\n",
"]\n",
"meta_learner = LogisticRegression(solver='lbfgs', multi_class='auto')\n",
"\n",
"ensemble = SuperLearner(scorer=accuracy_score, random_state=794, verbose=2)\n",
"ensemble.add(base_learners)\n",
"ensemble.add_meta(meta_learner)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Fitting 2 layers\n",
"Processing layer-1 done | 00:00:01\n",
"Processing layer-2 done | 00:00:00\n",
"Fit complete | 00:00:01\n",
"\n",
"Predicting 2 layers\n",
"Processing layer-1 done | 00:00:00\n",
"Processing layer-2 done | 00:00:00\n",
"Predict complete | 00:00:00\n"
]
},
{
"data": {
"text/plain": [
"0.7159533073929961"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit the stacked ensemble on the training split, predict the test split,\n",
"# and report test-set accuracy.\n",
"ensemble.fit(trainX, trainY)\n",
"pred = ensemble.predict(testX)\n",
"accuracy_score(y_true=testY, y_pred=pred)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" score-m score-s ft-m ft-s pt-m pt-s\n",
"layer-1 randomforestclassifier 0.55 0.01 0.37 0.01 0.03 0.00\n",
"layer-1 svc 0.56 0.03 0.60 0.10 0.40 0.02"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ensemble.data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Blending approaches"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from mlens.ensemble import BlendEnsemble"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"BlendEnsemble(array_check=None, backend=None,\n",
" layers=[Layer(backend='threading', dtype=<class 'numpy.float32'>, n_jobs=-1,\n",
" name='layer-1', propagate_features=None, raise_on_exception=True,\n",
" random_state=None, shuffle=False,\n",
" stack=[Group(backend='threading', dtype=<class 'numpy.float32'>,\n",
" indexer=BlendIndex(X=None, raise_on_exception=...81782f0>)],\n",
" n_jobs=-1, name='group-3', raise_on_exception=True, transformers=[])],\n",
" verbose=1)],\n",
" model_selection=False, n_jobs=None, raise_on_exception=True,\n",
" random_state=None, sample_size=20,\n",
" scorer=<function accuracy_score at 0x7f61481782f0>, shuffle=False,\n",
" test_size=0.2, verbose=2)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Blend ensemble: a random forest and an SVC as base learners, logistic regression\n",
"# as the meta learner; test_size=0.2 sets the share of the data held out per layer.\n",
"# random_state is set explicitly so this ensemble is seeded like the SuperLearner cell.\n",
"# NOTE(review): n_estimators=794 equals the seed value and SVC uses the default C here,\n",
"# whereas the SuperLearner cell uses n_estimators=100 and C=1000 - confirm this is intended.\n",
"ensemble2 = BlendEnsemble(scorer=accuracy_score, test_size=0.2, random_state=794, verbose=2)\n",
"ensemble2.add([RandomForestClassifier(n_estimators=794, random_state=794),\n",
"               SVC(gamma='auto')])\n",
"ensemble2.add_meta(LogisticRegression(solver='lbfgs', multi_class='auto'))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Fitting 2 layers\n",
"Processing layer-1 done | 00:00:03\n",
"Processing layer-2 done | 00:00:00\n",
"Fit complete | 00:00:03\n"
]
},
{
"data": {
"text/plain": [
"BlendEnsemble(array_check=None, backend=None,\n",
" layers=[Layer(backend='threading', dtype=<class 'numpy.float32'>, n_jobs=-1,\n",
" name='layer-1', propagate_features=None, raise_on_exception=True,\n",
" random_state=None, shuffle=False,\n",
" stack=[Group(backend='threading', dtype=<class 'numpy.float32'>,\n",
" indexer=BlendIndex(X=None, raise_on_exception=...81782f0>)],\n",
" n_jobs=-1, name='group-3', raise_on_exception=True, transformers=[])],\n",
" verbose=1)],\n",
" model_selection=False, n_jobs=None, raise_on_exception=True,\n",
" random_state=None, sample_size=20,\n",
" scorer=<function accuracy_score at 0x7f61481782f0>, shuffle=False,\n",
" test_size=0.2, verbose=2)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ensemble2.fit(trainX, trainY)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Predicting 2 layers\n",
"Processing layer-1 done | 00:00:00\n",
"Processing layer-2 done | 00:00:00\n",
"Predict complete | 00:00:00\n"
]
}
],
"source": [
"pred_b = ensemble2.predict(testX)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" score-m score-s ft-m ft-s pt-m pt-s\n",
"layer-1 randomforestclassifier 0.60 0.00 3.21 0.00 0.08 0.00\n",
"layer-1 svc 0.38 0.00 1.72 0.00 0.39 0.00"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ensemble2.data"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.669260700389105"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test-set accuracy of the blend ensemble. sklearn metrics take (y_true, y_pred),\n",
"# so the true labels go first.\n",
"accuracy_score(testY, pred_b)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@iwatobipen
Copy link
Author

Example of ensemble learning with ml-ens

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment