Skip to content

Instantly share code, notes, and snippets.

@iwatobipen
Last active December 7, 2019 12:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save iwatobipen/f05252e8d6754de5413477a199652a22 to your computer and use it in GitHub Desktop.
Save iwatobipen/f05252e8d6754de5413477a199652a22 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"RDKit WARNING: [20:34:10] Enabling RDKit 2019.09.2 jupyter extensions\n"
]
}
],
"source": [
"%matplotlib inline \n",
"import matplotlib.pyplot as plt\n",
"import os\n",
"from rdkit import Chem\n",
"from rdkit import RDPaths\n",
"\n",
"import dgl\n",
"import numpy as np\n",
"import random\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data import Dataset\n",
"from dgl import model_zoo\n",
"\n",
"from dgl.data.chem.utils import mol_to_complete_graph, mol_to_bigraph\n",
"\n",
"from dgl.data.chem.utils import atom_type_one_hot\n",
"from dgl.data.chem.utils import atom_degree_one_hot\n",
"from dgl.data.chem.utils import atom_formal_charge\n",
"from dgl.data.chem.utils import atom_num_radical_electrons\n",
"from dgl.data.chem.utils import atom_hybridization_one_hot\n",
"from dgl.data.chem.utils import atom_total_num_H_one_hot\n",
"from dgl.data.chem.utils import one_hot_encoding\n",
"from dgl.data.chem import CanonicalAtomFeaturizer\n",
"from dgl.data.chem import CanonicalBondFeaturizer\n",
"from dgl.data.chem import ConcatFeaturizer\n",
"from dgl.data.chem import BaseAtomFeaturizer\n",
"from dgl.data.chem import BaseBondFeaturizer\n",
"\n",
"from dgl.data.chem import one_hot_encoding\n",
"from dgl.data.utils import split_dataset\n",
"\n",
"from functools import partial\n",
"from sklearn.metrics import roc_auc_score"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def chirality(atom):\n",
"    \"\"\"Encode the CIP stereo label of an RDKit atom as a 3-element list.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    atom : rdkit.Chem.rdchem.Atom\n",
"        Atom to featurize.\n",
"\n",
"    Returns\n",
"    -------\n",
"    list\n",
"        ``[is_R, is_S, chirality_possible]``. The first two entries are\n",
"        False when the atom carries no ``_CIPCode`` property.\n",
"    \"\"\"\n",
"    chirality_possible = atom.HasProp('_ChiralityPossible')\n",
"    try:\n",
"        cip_code = atom.GetProp('_CIPCode')\n",
"    except KeyError:\n",
"        # GetProp raises KeyError when the property is not set on the atom;\n",
"        # catch only that instead of hiding every error behind a bare except.\n",
"        return [False, False, chirality_possible]\n",
"    return one_hot_encoding(cip_code, ['R', 'S']) + [chirality_possible]\n",
"\n",
"\n",
"def collate_molgraphs(data):\n",
"    \"\"\"Batch a list of datapoints for a DataLoader.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    data : list of 3-tuples or 4-tuples\n",
"        Each tuple is a single datapoint consisting of a SMILES string,\n",
"        a DGLGraph, the all-task labels and, optionally, a binary mask\n",
"        indicating which labels exist.\n",
"\n",
"    Returns\n",
"    -------\n",
"    smiles : list\n",
"        List of SMILES strings.\n",
"    bg : BatchedDGLGraph\n",
"        The batched DGLGraphs.\n",
"    labels : Tensor of shape (B, T)\n",
"        Stacked datapoint labels. B is len(data) and T is the total\n",
"        number of tasks.\n",
"    masks : Tensor of shape (B, T)\n",
"        Stacked binary masks indicating the existence of labels. When the\n",
"        datapoints carry no mask, a tensor of ones is returned instead.\n",
"    \"\"\"\n",
"    assert len(data[0]) in [3, 4], \\\n",
"        'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))\n",
"    if len(data[0]) == 3:\n",
"        smiles, graphs, labels = map(list, zip(*data))\n",
"        masks = None\n",
"    else:\n",
"        smiles, graphs, labels, masks = map(list, zip(*data))\n",
"\n",
"    bg = dgl.batch(graphs)\n",
"    # Features added to the batched graph later default to zeros.\n",
"    bg.set_n_initializer(dgl.init.zero_initializer)\n",
"    bg.set_e_initializer(dgl.init.zero_initializer)\n",
"    labels = torch.stack(labels, dim=0)\n",
"\n",
"    if masks is None:\n",
"        masks = torch.ones(labels.shape)\n",
"    else:\n",
"        masks = torch.stack(masks, dim=0)\n",
"    return smiles, bg, labels, masks\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"atom_featurizer = BaseAtomFeaturizer(\n",
" {'hv': ConcatFeaturizer([\n",
" partial(atom_type_one_hot, allowable_set=[\n",
" 'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],\n",
" encode_unknown=True),\n",
" partial(atom_degree_one_hot, allowable_set=list(range(6))),\n",
" atom_formal_charge, atom_num_radical_electrons,\n",
" partial(atom_hybridization_one_hot, encode_unknown=True),\n",
" lambda atom: [0], # A placeholder for aromatic information,\n",
" atom_total_num_H_one_hot, chirality\n",
" ],\n",
" )})\n",
"bond_featurizer = BaseBondFeaturizer({\n",
" 'he': lambda bond: [0 for _ in range(10)]\n",
" })"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"train=os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.train.sdf')\n",
"test=os.path.join(RDPaths.RDDocsDir, 'Book/data/solubility.test.sdf')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Read each SD file once and drop the records RDKit could not parse (the\n",
"# supplier yields None for those), so that the SMILES, the labels and the\n",
"# graphs built from these lists stay index-aligned.\n",
"train_mols = [mol for mol in Chem.SDMolSupplier(train) if mol is not None]\n",
"train_smi = [Chem.MolToSmiles(mol) for mol in train_mols]\n",
"# One solubility value per molecule, shaped (N, 1).\n",
"train_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in train_mols]).reshape(-1, 1)\n",
"\n",
"test_mols = [mol for mol in Chem.SDMolSupplier(test) if mol is not None]\n",
"test_smi = [Chem.MolToSmiles(mol) for mol in test_mols]\n",
"test_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in test_mols]).reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"train_graph =[mol_to_bigraph(mol,\n",
" atom_featurizer=atom_featurizer, \n",
" bond_featurizer=bond_featurizer) for mol in train_mols]\n",
"\n",
"test_graph =[mol_to_bigraph(mol,\n",
" atom_featurizer=atom_featurizer, \n",
" bond_featurizer=bond_featurizer) for mol in test_mols]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def run_a_train_epoch(n_epochs, epoch, model, data_loader, loss_criterion, optimizer):\n",
"    \"\"\"Run one optimisation pass over ``data_loader`` and report the mean batch loss.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    n_epochs : int\n",
"        Total number of epochs; only used in the progress message.\n",
"    epoch : int\n",
"        Zero-based index of the current epoch.\n",
"    model : torch.nn.Module\n",
"        Called as ``model(bg, bg.ndata['hv'], bg.edata['he'])``.\n",
"    data_loader : iterable\n",
"        Yields ``(smiles, bg, labels, masks)`` batches.\n",
"    loss_criterion : callable\n",
"        Loss whose output is multiplied element-wise by the label mask\n",
"        before being averaged.\n",
"    optimizer : torch.optim.Optimizer\n",
"        Optimizer stepped once per batch.\n",
"\n",
"    Returns\n",
"    -------\n",
"    float\n",
"        Mean of the per-batch losses of this epoch.\n",
"    \"\"\"\n",
"    model.train()\n",
"    losses = []\n",
"\n",
"    for smiles, bg, labels, masks in data_loader:\n",
"        if torch.cuda.is_available():\n",
"            # NOTE(review): the return value is deliberately not re-bound; the\n",
"            # DGL release this notebook was written for appears to move the\n",
"            # graph features in place - confirm before upgrading DGL.\n",
"            bg.to(torch.device('cuda:0'))\n",
"            labels = labels.to('cuda:0')\n",
"            masks = masks.to('cuda:0')\n",
"\n",
"        prediction = model(bg, bg.ndata['hv'], bg.edata['he'])\n",
"        # Zero the loss of masked-out (missing) labels before averaging.\n",
"        loss = (loss_criterion(prediction, labels) * (masks != 0).float()).mean()\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        losses.append(loss.item())\n",
"\n",
"    total_score = np.mean(losses)\n",
"    print('epoch {:d}/{:d}, training {:.4f}'.format(epoch + 1, n_epochs, total_score))\n",
"    return total_score"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Use the GPU only when one is present, matching the guards used in the\n",
"# training and inference cells, so the notebook also runs on CPU-only hosts.\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# node_feat_size / edge_feat_size have to equal the lengths of the 'hv' atom\n",
"# features and the 'he' bond features produced by the featurizers.\n",
"model = model_zoo.chem.AttentiveFP(node_feat_size=39,\n",
"                                   edge_feat_size=10,\n",
"                                   num_layers=2,\n",
"                                   num_timesteps=2,\n",
"                                   graph_feat_size=200,\n",
"                                   output_size=1,\n",
"                                   dropout=0.2)\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Reshuffle the training set every epoch so the batches differ between epochs.\n",
"# The test loader keeps the file order so predictions line up with test_sol.\n",
"train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)), batch_size=128, shuffle=True, collate_fn=collate_molgraphs)\n",
"test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)), batch_size=128, collate_fn=collate_molgraphs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1/100, training 8.8096\n",
"epoch 2/100, training 4.6130\n",
"epoch 3/100, training 3.0711\n",
"epoch 4/100, training 3.0123\n",
"epoch 5/100, training 2.6381\n",
"epoch 6/100, training 2.1367\n",
"epoch 7/100, training 1.7387\n",
"epoch 8/100, training 2.1711\n",
"epoch 9/100, training 1.8085\n",
"epoch 10/100, training 1.4492\n",
"epoch 11/100, training 0.9864\n",
"epoch 12/100, training 0.8572\n",
"epoch 13/100, training 0.8008\n",
"epoch 14/100, training 0.7456\n",
"epoch 15/100, training 0.8072\n",
"epoch 16/100, training 1.0538\n",
"epoch 17/100, training 0.7737\n",
"epoch 18/100, training 0.7873\n",
"epoch 19/100, training 0.8416\n",
"epoch 20/100, training 0.6274\n",
"epoch 21/100, training 0.6043\n",
"epoch 22/100, training 0.6074\n",
"epoch 23/100, training 0.5553\n",
"epoch 24/100, training 0.5958\n",
"epoch 25/100, training 0.6209\n",
"epoch 26/100, training 0.6549\n",
"epoch 27/100, training 0.6409\n",
"epoch 28/100, training 0.6334\n",
"epoch 29/100, training 0.6946\n",
"epoch 30/100, training 0.4690\n",
"epoch 31/100, training 0.4530\n",
"epoch 32/100, training 0.4636\n",
"epoch 33/100, training 0.4645\n",
"epoch 34/100, training 0.4646\n",
"epoch 35/100, training 0.4130\n",
"epoch 36/100, training 0.4516\n",
"epoch 37/100, training 0.3906\n",
"epoch 38/100, training 0.3860\n",
"epoch 39/100, training 0.3687\n",
"epoch 40/100, training 0.3096\n",
"epoch 41/100, training 0.3341\n",
"epoch 42/100, training 0.3655\n",
"epoch 43/100, training 0.3989\n",
"epoch 44/100, training 0.4029\n",
"epoch 45/100, training 0.5093\n",
"epoch 46/100, training 0.6887\n",
"epoch 47/100, training 0.5787\n",
"epoch 48/100, training 0.4626\n",
"epoch 49/100, training 0.4528\n",
"epoch 50/100, training 0.5052\n",
"epoch 51/100, training 0.3732\n",
"epoch 52/100, training 0.3826\n",
"epoch 53/100, training 0.3434\n",
"epoch 54/100, training 0.3236\n",
"epoch 55/100, training 0.3086\n",
"epoch 56/100, training 0.2946\n",
"epoch 57/100, training 0.2963\n",
"epoch 58/100, training 0.2978\n",
"epoch 59/100, training 0.3293\n",
"epoch 60/100, training 0.2771\n",
"epoch 61/100, training 0.2680\n",
"epoch 62/100, training 0.2831\n",
"epoch 63/100, training 0.2773\n",
"epoch 64/100, training 0.3020\n",
"epoch 65/100, training 0.3681\n",
"epoch 66/100, training 0.4934\n",
"epoch 67/100, training 0.4425\n",
"epoch 68/100, training 0.6564\n",
"epoch 69/100, training 0.6667\n",
"epoch 70/100, training 0.4864\n",
"epoch 71/100, training 0.5470\n",
"epoch 72/100, training 0.5098\n",
"epoch 73/100, training 0.4339\n",
"epoch 74/100, training 0.4412\n",
"epoch 75/100, training 0.4010\n",
"epoch 76/100, training 0.3537\n",
"epoch 77/100, training 0.3797\n",
"epoch 78/100, training 0.3168\n",
"epoch 79/100, training 0.2813\n",
"epoch 80/100, training 0.2970\n",
"epoch 81/100, training 0.2584\n",
"epoch 82/100, training 0.2962\n",
"epoch 83/100, training 0.2858\n",
"epoch 84/100, training 0.2995\n",
"epoch 85/100, training 0.2890\n",
"epoch 86/100, training 0.2670\n",
"epoch 87/100, training 0.2438\n",
"epoch 88/100, training 0.2523\n",
"epoch 89/100, training 0.3115\n",
"epoch 90/100, training 0.2887\n",
"epoch 91/100, training 0.3327\n",
"epoch 92/100, training 0.3209\n",
"epoch 93/100, training 0.4597\n",
"epoch 94/100, training 0.4024\n",
"epoch 95/100, training 0.3614\n",
"epoch 96/100, training 0.5085\n",
"epoch 97/100, training 0.3624\n",
"epoch 98/100, training 0.3706\n",
"epoch 99/100, training 0.3915\n",
"epoch 100/100, training 0.3003\n"
]
}
],
"source": [
"loss_fn = nn.MSELoss(reduction='none')\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0),)\n",
"n_epochs = 100\n",
"epochs = []\n",
"scores = []\n",
"for e in range(n_epochs):\n",
" score = run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn, optimizer)\n",
" epochs.append(e)\n",
" scores.append(score)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fb16c261518>]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAeKklEQVR4nO3dd3Td5Z3n8ff3Nl0Vq0tGxSA3XMAVBWwDwRjSMCWESeKcJWHZZEiGJJTNzCxMzm52die7JzuTnCTspLCEkEACSQgJCQQScMBgMBi54SI3XGVZtooly6q3PPuHrmTJBcvl+v6k+3mdo2PrNn0fS/7oud/f83t+5pxDRES8y5fqAkRE5P0pqEVEPE5BLSLicQpqERGPU1CLiHhcIBkvWlxc7KqqqpLx0iIio9KqVauanHMlJ7ovKUFdVVVFTU1NMl5aRGRUMrPdJ7tPrQ8REY9TUIuIeJyCWkTE4xTUIiIep6AWEfE4BbWIiMcpqEVEPM5TQf39pdtYtrUx1WWIiHiKp4L6R8ve43UFtYjIEJ4K6nDQT080nuoyREQ8xVNBnRHw0RONpboMERFP8WBQa0YtIjKYx4LaT09EQS0iMpi3gjroo1utDxGRIbwV1AGfZtQiIsfwWFD7dTBRROQYHgtqHUwUETmWp4Ja66hFRI7nqaDWOmoRkeN5K6iDOpgoInIsbwV1QK0PEZFjeSyofXRH1PoQERnMc0HdE43jnEt1KSIinuGtoA76AeiNqf0hItJvWEFtZveb2UYz22BmT5pZOBnFZAT6ylGfWkTkqFMGtZlVAPcA1c65SwE/sCQZxfTPqLXyQ0TkqOG2PgJAppkFgCygPhnFHJ1R64CiiEi/Uwa1c24f8G/AHmA/0Oac+8uxjzOzu8ysxsxqGhvP7HJaan2IiBxvOK2PAuAWYDxQDmSb2e3HPs4597Bzrto5V11SUnJGxWQE1PoQETnWcFof1wM7nXONzrkI8AywIBnFZAT7ytGe1CIiRw0nqPcA88wsy8wMuA6oTUYxA60PzahFRAYMp0f9NvA0sBpYn3jOw8koZqD1oRm1iMiAwHAe5Jz7BvCNJNeig4kiIifgqTMTw/3rqBXUIiIDPBXUR3vUan2IiPTzVlAH1foQETmWt4I6cTBRW52KiBzlsaDWjFpE5FgKahERj/NUUJsZIV3gVkRkCE8FNSSu8qIzE0VEBngwqHWBWxGRwTwX1OGgWh8iIoN5Lqj7L3ArIiJ9PBjUfp2ZKCIyiPeCOqgZtYjIYN4Laq36EBEZwoNB7dfBRBGRQTwY1Gp9iIgM5r2gDmodtYjIYJ4L6nDAp1UfIiKDeC6otepDRGQo7wV1wK/9qEVEBvFgUGtGLSIymAeD2k807ojGFNYiIuDFoE5cN7FXQS0iAngxqAeuRK6gFhEBTwZ13wVu1acWEenjuaAOB/uvm6iVHyIi4MGg1oxaRGQoDwZ1X0laSy0i0sd7QT3Q+tCMWkQEvBjU/a0PrfoQEQE8GdQ6mCgiMpj3glqtDxGRIbwX1AOrPjSjFhEBDwb1wDpq9ahFRAAPBrXWUYuIDOXBoNY6ahGRwTwb1JpRi4j0GVZQm1m+mT1tZpvNrNbM5ieroIDfh99nOpgoIpIQGObjvge86Jz7GzMLAVlJrKnvKi86mCgiAgwjqM0sF/gg8B8BnHO9QG8yi9LluEREjhpO62MC0Aj81MzWmNkjZpZ97IPM7C4zqzGzmsbGxrMqKiPgV+tDRCRhOEEdAOYCP3TOzQE6gAeOfZBz7mHnXLVzrrqkpOSsigoHNaMWEek3nKCuA+qcc28nPn+avuBOmoyAXz1qEZGEUwa1c64B2GtmUxI3XQdsSmZRGUEf3Wp9iIgAw1/18VXgF4kVHzuAO5NXklZ9iIgMNqygds6tBaqTXMuAjICfzt7o+fpyIiKe5rkzE0HL80REBvNmUGvVh4jIAG8GtdZRi4gM8GRQh4
M6mCgi0s+TQd03o1ZQi4iAZ4Pap/2oRUQSPBvUPdE4zrlUlyIiknLeDOpg3+W4emNqf4iIeDOodZUXEZEB3g5qrfwQEfFqUPdfiVwHFEVEvBnUQbU+RET6eTOo+2fUan2IiHg0qBMzau1JLSLi1aDWwUQRkQEeDWodTBQR6efRoNbBRBGRfp4M6rBWfYiIDPBkUB9d9aHWh4iIN4NaM2oRkQHeDOrEjFpbnYqIeDaoNaMWEemnoBYR8ThPBrWZEQr4tI5aRASPBjUkrvKiMxNFRLwc1LrArYgIeDiox4QDHO6OpLoMEZGU82xQF2WHaDnSm+oyRERSzrtBnROiuaMn1WWIiKSch4M6g2bNqEVEvBvUxdkhWjp7icVdqksREUkpzwZ1UU4GzsGhTs2qRSS9eTioQwBqf4hI2vNuUGdnANB0RAcURSS9eTaoS8b0zagV1CKS7jwb1P0zarU+RCTdeTao8zKD+H2mtdQikvaGHdRm5jezNWb2XDIL6ufzGYXZIc2oRSTtnc6M+l6gNlmFnEhRdogmBbWIpLlhBbWZVQKLgUeSW85QxTkZan2ISNob7oz6u8A/Aifdd9TM7jKzGjOraWxsPCfFFeWo9SEicsqgNrMbgYPOuVXv9zjn3MPOuWrnXHVJSck5Ka44J0PL80Qk7Q1nRn0lcLOZ7QKeAhaZ2RNJrSqhKCdEZ2+Mzt7o+fhyIiKedMqgds496JyrdM5VAUuAvzrnbk96ZUCx1lKLiHh3HTUM2u+jQ0EtIukrcDoPds69CryalEpOoCinf0atPrWIpC9vz6iztYOeiIi3gzrR+mjSWmoRSWOeDuqsUICskF8zahFJa54OatBaahERzwe1zk4UkXTn/aDO1oxaRNKb54O6OCekddQiktY8H9RFOSFaOnqJx12qSxERSQnvB3V2BrG4o60rkupSRERSwvtBPXAaufrUIpKePB/UJYnTyHWlFxFJV54P6qKBoNaMWkTS0wgIau33ISLpzfNBXZAVwkw76IlI+vJ8UPt9RmFWiCatpRaRNOX5oAa4IC/Mixsa+OGr79HWqWV6IpJeRkRQ/69bZzD1gjF868XNzPvfS/nZm7tSXZKIyHkzIoJ61rh8fvm383jh3quZUZHHv/55C92RWKrLEhE5L0ZEUPebVpbLfddP5khPlJdrD6S6HBGR82JEBTXAFROKGJubwe/X1Ke6FBGR82LEBbXfZ9w8q5xlWw9ySCtBRCQNjLigBrhldgWRmOP59ftTXYqISNKNyKC+pDyXSaU5PLt2X6pLERFJuhEZ1GbGrXMqeGfXIeoOdaa6HBGRpBqRQQ1w86xyAJ5dq4OKIjK6jdigHleYRfVFBfxxnYJaREa3ERvUANdNG8vmhnYOtnenuhQRkaQZ0UF91aRiAFa815ziSkREkmdEB/X08lzys4Is39aU6lJERJJmRAe132csmFjEG9ubcE5XKReR0WlEBzXAlZOKqW/rZmdTR6pLERFJihEf1P196je2q/0hIqPTiA/qCwuzqCzIZLmCWkRGqREf1GbGVZOKefO9ZmJx9alFZPQZ8UENfX3q9u4o6/e1pboUEZFzblQE9YKJRcCJ+9Q90Ri1+w+f75JERM6ZURHURTkZTC/LPeF66kde38ni77/OLq0KEZER6pRBbWbjzOwVM6s1s41mdu/5KOx0XX1xMe/saqG1c+jFBF7YsJ+4g9+t0ZaoIjIyDWdGHQW+5pybBswDvmxm05Nb1um74dIyonHHXzYdvZZifWsXG/YdxmfwzJo64jrYKCIj0CmD2jm33zm3OvH3dqAWqEh2YadrZmUelQWZPP/u0au+9F8A9+8WTmRvSxc1uw+lqjwRkTN2Wj1qM6sC5gBvn+C+u8ysxsxqGhsbz011p8HMWDyzjDe2Nw20P17adIAJxdncvXASWSE/z6yuO+91iYicrWEHtZnlAL8F7nPOHbeMwjn3sHOu2jlXXVJSci5rHLbFMxLtj40HONwd4a0dzX
xo+liyMwJ89NILeP7d/XRHYimpTUTkTA0rqM0sSF9I/8I590xySzpzMyryGFeYyXPr97NsSyORmOND08cCcNvcStp7orw0qIctIjISDGfVhwE/AWqdc99JfklnzsxYPKOcN7c38ZtVdRRlh5hzYQEA8yYUUZYXVvtDREac4cyorwQ+Cywys7WJjxuSXNcZu3FmX/vjta2NLJpait9nQN+WqB+fU8Fr25p0RRgRGVGGs+pjuXPOnHMznXOzEx9/Oh/FnYlLynO5sDALYKDt0e+2uRXE4o4/6IK4IjKCjIozEwczM26bW0l+VpCrJw89qDmpdAyzKvN4epXaHyIycoy6oAb4yqJJLPuHa8kM+Y+777bLKtnc0M7Gem3gJCIjw6gMar/PyMsMnvC+m2aWE/Qbv12lU8pFZGQYlUH9fgqyQ1w3dSzPrt1HJBZPdTkiIqeUdkENfe2P5o5elm05/2dQioicrrQM6oVTSijKDvFbrakWkREgLYM66Pdx8+xyltYePG5bVBERr0nLoAb4+OwKemNxXtlyMNWliIi8r7QN6hkVeRTnhHhVfWoR8bi0DWqfz/jgxSUs29qoq5eLiKelbVADXDullNbOCGv3tqa6FBGRk0rroL56cjE+g2XqU4uIh6V1UOdnhZh7YQGvqE8tIh6W1kENcO3UUtbva9PWpyLiWWkf1Ndc3LfD3mtbm1JciYjIiaV9UF9SnkvpmIyzWk+9tPYA63RAUkSSJO2D2sxYOKWE17Y2Ej2DTZr2tnTyd0+s5p6n1pzR80VETiXtgxpg4ZRS2rujPLP69Lc+/c5LW+mNxdnd3MmfNjQkoToRSXcKauC6aaVcMb6QB55597Su/rKxvo3fr93HFz84gUmlOfzgle04p5NnROTcUlADGQE/j915OQsmFvMPT6/jl2/vGdbz/s+LW8gNB7l74STuXjiRzQ3t/HWz1mSLyLmloE7IDPl55I5qFl5cwj/9bj1ffLyGN99rOukM+c3tTSzb2siXr51IXlaQm2aVU5Gfyb+fYFbd1hnhp2/s5PEVu5I/EBEZdQKpLsBLwkE/P/rsZTy0dDu/eHs3f954gMmlOXxu/kV8fE4FY8JB4nHHK1sO8i/P11KeF+Zz86uAvq1Tv3TNBP7rsxt5YUMDY3PDHDjczSubD/LHd+vpjvQdaJxYmsOCicUpHKWIjDSWjJ5qdXW1q6mpOeevez51R2L8cV09P1+xm/X72sgO+blhRhlr97ay7eARKvIz+dZtM7lqcvGQ51z1rVdoOtIzcFtWyM/H51TwN5dVct9Tawn4jBfuu5qMwPEX3hWR9GVmq5xz1Se8T0H9/pxzrKtr4/EVu/nju/VMKM7mS9dMZPHMMoL+4ztH79a1sqn+MGPzwowdE6aqOIusUN8bl2VbG7nj0ZXcf/3F3Hv95PM2hvrWLvIyg2RnjI43UM45onF3wn9/kZFKQX2ORGNx/D7DzM74Nb765Br+vLGBF++9mgklOWdVT1tnhLV1rWQG/VyQG6Y0N4Nw8OhMfdXuFn68bAcv1R5gwcQinvj8FWdVuxc45/jKL9fw9s4WfnT7XKqrClNdUsqt29vKu3WttHREaOuKcPPscmaPy091WXKaFNQecrC9m+u+vYyK/ExumlXOlLFjGFeYRSzuiMTihIN+Lh6bc1ygxuKOnU1H2Fh/mPV1bby1s5mN9Yc59tuXGfSTEw4Q8vvY19pFflaQ6osKeLn2ID/4D3O5YUbZeRztuffYGzv573/cRF5mkM7eKN+8dQafqh6X6rJSZsO+Nm7+v8vp31I96DeyMwL86Z6rKc/PTG1xcloU1B7zp/X7+ebztexr7Trh/ZNLc/j0B8axcEopq3cf4qXaA7yxvYnO3hgAIb+PORfmM39iEZdXFRJzjoa2bg4c7uZwd5T27ghHemJcdmE+n/rAODICfm58aDltnb0s/dpCMkMjsz++vq6N2374JldPLubbn5rFV365huXbm/jCVeN58IZp+H
0j+93C6YrFHbf+4A3qW7v53d0LuCAvTN2hLm78/utML8/lyb+dR0DtoRFDQe1R7d0Rth44Qn1rF0G/j1DAaGjr4Ter9rJmz9G9Q8rzwiyaVsqccQVML89lUmnOafdnV+5s4VM/XsFXF03iax+ecq6HknTt3RFufGg5kWic5++5moLsENFYnH95vpbH3tzFRy4Zy3c/PWfE/hI6Ez97cxff+MNGvrdkNrfMrhi4/fdr9nHfr9bylWsn8fcfmcKq3S08unwXR3qiXD25mGsuLmFS6fHv2iS1FNQj0JaGdlbubOayiwqZVjbmnPynuvepNbywoYGX77+GC4uyzkGV50dbZ4Qv/3I1K3Y086u75h3Xl/7pGzv5H89tYlZlPo/cUU1xTsbAfdsOtPPUO3up2dXCN2+dwaUVeee7/HOioyfKj1/bwSXluSyaWkrzkV6u/84y5lyYz8//0+XH/Xz8w2/W8fTqOi4tz2P9vjbyMoMU54R4r7EDgCljx/D1xdP4YGL3SBmqrSvCFx+v4dLyPL6+eNp5+aWmoBYAGtq6WfTtV4nGHRX5mVTkZxIK+Gju6OVQRy+dvVHiDuLODel9F2QFuWV2BZ+srqSyIIt9rV28uuUgOxo7uGV2OTMrk3fgaktDO3c9XkN9a9f79qNf3NDAvU+tIRz0M64wk4KsEIe7o6zb2zrQt3UOnvj8FcyoHFlh3dYZ4c7HVrI68S6rKDtEyZgMdjZ18Jf7P8hFRdnHPaezN8ptP1zBkZ4IX7hqAp+sriQrFBj43j382g52N3eyaGopdyyoIhqLc6QnSkbAx4JJxeSGg+d7mOfNvtYuHvjtu8wel8/t8y5ibG54yP3dkRife3QlK3e2AHDPokn85/PwLlRBLQNqdrXw0qYD1LV2se9QF5FYnMLsEIXZIbIzAvjNMAODgVnEe41HWL69b7/ucQVZ7GnpBMDvM2Jxx7wJhdx55XgqCzIJB/1kBv2UjMk47fZM05EeVu5soaGtm55onMPdEX725i6yMwL86Pa5XHbR+6/wWLe3lZ+t2MWhjl4OdUZwzrF4ZhmfmFtJV2+MJQ+/RXt3hCe+cMXAL5doLO7pPm5jew+fe3Ql2w+2891PzyEr5OfXNXt5ufYAf//hKXzxmoknfW4s7vAZJ5wN9kRj/PzN3Xx/6Tbae6JD7gv4jMvHF/Kh6WNZPLOM0jHh454/Uh083M2nfryC/W3d9Mbi+M342IwyPjGngvkTiwj6fdz9i1X8ZdMBvrdkDsu3NfLrmjr++eZLuGNB1ZDX6o3G2d3cQV5mkNLcs/83UlDLWas71MnTq+rYsK+NK8YXce3UUsbmZvDkyj08unwXDYeHXiHH7zPK8sJcVJTF4hnlfGJuxZClg9C3vvudXS28s6uFt3e0sO3gkeO+7uXjC3noM3OOm/Wc6RiWPPwWTUd6yMsM0toZoScaH2gnXDu1lOllucfVmQrOOZZvb+K/PbuRhrZufvzZy4a0KXqjcUKBs/8Fc6ijl9qGw+RkBMjJCNDc0cvS2oMsrT3AtoNH8BlcOamYKycVs6upgw31bTS0dXPL7Aq+cPV4yvJGzsqS5iM9LHn4Lfa1dvH456+gJCeDn6/Yxa9q9tLeHSUr5KeqKJtN+w/zjZumc+eV44nG4nzpidUs3XyAm2eV05uYQNS3drOnpZNY3JER8PHAx6Zyx/wqfGdxQFtBLUnVG41Ts6uF9p4o3ZEYHT0x6lu72Huok9r9h9l64AgFWUGWXH4hPoPa/e1sqj88EO7ZIT+XVRUyf0IR8yYUUlWUTTjoJxTwnfOVHPtau/jey1sBKMgKEfT7eHtnM6t2HyLuwAwqCzIZX5xDRX4m5XlhyvIzKcsLMzY3zAV5YcIBHzHniMX7PuJxiDnH7uYOanYd4p1dLTjgtrmVXDet9LTeWXRHYqzY0cxDS7exek8rZXlhHvrMnJSsF99+sJ1n19bz7Np69rR0kp8VZEZFHtmhAC/VHsBncOucCj42o4
wrxhcOnNgFfTP2E519G4s7DIYEmnNu4B1QflbotL/n+1q7qNnVwkVF2UwuzSEU8LHivWZe2LCft3a0kJ3hpzgngz3Nnexr7eKxOy9n/sSiged3R2K8taOZlzYdYNnWRm6bW8n9H7p4yP33PLmGdXWt5IaD5GYGKR2TwcSSHMYXZ/P8+v38dfNB5k8o4l8/OZPKgjM7/qOglpRxzrFyZwuPLN/Jy7UH8JkxsSSbaWW5zB6XzweqCpl6wZiUtx8OdfTyxntNbDtwhJ1NHexs6mBfaxctHb2n/VpVRVl0RWIcONxDcU6I66eNpTgng7zMIH6f0XC4m/rWLlo7I4SDPrJCARywtaGd7Y1HiMUd5Xlh7r52Ep+srkz5dgPOOVo6einMDg20Ufa2dPL/Xt/Br2v20h2JE/Qbl1bk0R2JU9/aRVtXhPHF2SycUsLVk4upO9TFsi2NvPleMz3RGAVZIfKzgvRE4xw83ENv4qIbZpCfGeSCvEwml+b0fYwdw5QLxnBhYdZAiPefMfyT5Tv50/r9xOJHcywz6KcrEiM75GfBpGKisTjNHb30ROL80+JpA5ffO5f/Pr96Zy//87lNhIN+Xv8v1w75pTVcCmrxhJaOXrJCfk+0FoarOxJjf1v3wDr1hsPd9Eb7zlD1+4yAz/CZ4TMozQ1TXVVA6Zgw0Vic17Y18uTKvhUnbV2RgZNSQgEfZXlhCrJC9EbjdEVixOKOSaU5TCsbw4yKPBZNHXtOWhvJ1h2J8c6uFpZvb2L17kPkhoNUFGRSmB1i3d7WRDD3hfBFRVlcc3EJeZlBWjp6OdTZS8jvG9huwe8zWjp6aenope5QJ1sPHBlyrkE46KN0TJj27giHu6PE4o4xGQGWXD6Om2aVs7+tm60N7Rxs7+GqxDLE8/mztrelk3V1rdw4s/yMnq+gFkkx5xxHeqJEYo6CrGDarGHu6o2xes8hyvMzGV98/OqUU+noibLt4BG2NrSzuaGd5o6eRPshQEV+FjfNKmPMKFmh8n5BPTp26RHxODMbNYFyOjJDfq6cdObb+mZnBJg9Lj/t9y4Z1nsrM/uomW0xs+1m9kCyixIRkaNOGdRm5gf+HfgYMB34jJlNT3ZhIiLSZzgz6suB7c65Hc65XuAp4JbkliUiIv2GE9QVwN5Bn9clbhMRkfNgOEF9osPTxy0VMbO7zKzGzGoaGxvPvjIREQGGF9R1wOCdcCqB+mMf5Jx72DlX7ZyrLinRjlwiIufKcIL6HWCymY03sxCwBPhDcssSEZF+p1xH7ZyLmtlXgD8DfuBR59zGpFcmIiJAks5MNLNGYPcZPr0YaDqH5YwE6ThmSM9xp+OYIT3Hfbpjvsg5d8K+cVKC+myYWc3JTqMcrdJxzJCe407HMUN6jvtcjtn7u76IiKQ5BbWIiMd5MagfTnUBKZCOY4b0HHc6jhnSc9znbMye61GLiMhQXpxRi4jIIApqERGP80xQp8ue12Y2zsxeMbNaM9toZvcmbi80s5fMbFviz4JU13qumZnfzNaY2XOJz9NhzPlm9rSZbU58z+eP9nGb2f2Jn+0NZvakmYVH45jN7FEzO2hmGwbddtJxmtmDiXzbYmYfOZ2v5YmgTrM9r6PA15xz04B5wJcTY30AWOqcmwwsTXw+2twL1A76PB3G/D3gRefcVGAWfeMfteM2swrgHqDaOXcpfWczL2F0jvkx4KPH3HbCcSb+jy8BLkk85weJ3Bse51zKP4D5wJ8Hff4g8GCq6zpPY38W+BCwBShL3FYGbEl1bed4nJWJH9xFwHOJ20b7mHOBnSQO2g+6fdSOm6PbIhfSt0XFc8CHR+uYgSpgw6m+t8dmGn1bcswf7tfxxIyaNN3z2syqgDnA28BY59x+gMSfpamrLCm+C/wjEB9022gf8wSgEfhpouXziJllM4rH7ZzbB/wbsAfYD7Q55/7CKB7zMU42zrPKOK8E9bD2vB5NzCwH+C1wn3PucKrrSSYzuxE46Jxblepazr
MAMBf4oXNuDtDB6HjLf1KJnuwtwHigHMg2s9tTW5UnnFXGeSWoh7Xn9WhhZkH6QvoXzrlnEjcfMLOyxP1lwMFU1ZcEVwI3m9ku+i7ltsjMnmB0jxn6fq7rnHNvJz5/mr7gHs3jvh7Y6ZxrdM5FgGeABYzuMQ92snGeVcZ5JajTZs9rMzPgJ0Ctc+47g+76A3BH4u930Ne7HhWccw865yqdc1X0fW//6py7nVE8ZgDnXAOw18ymJG66DtjE6B73HmCemWUlftavo+8A6mge82AnG+cfgCVmlmFm44HJwMphv2qqm/GDmus3AFuB94Cvp7qeJI7zKvre8rwLrE183AAU0XewbVviz8JU15qk8S/k6MHEUT9mYDZQk/h+/x4oGO3jBv4Z2AxsAB4HMkbjmIEn6evDR+ibMX/+/cYJfD2Rb1uAj53O19Ip5CIiHueV1oeIiJyEglpExOMU1CIiHqegFhHxOAW1iIjHKahFRDxOQS0i4nH/HzqBbe3e3JM0AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(epochs, scores)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AttentiveFP(\n",
" (init_context): GetContext(\n",
" (project_node): Sequential(\n",
" (0): Linear(in_features=39, out_features=200, bias=True)\n",
" (1): LeakyReLU(negative_slope=0.01)\n",
" )\n",
" (project_edge1): Sequential(\n",
" (0): Linear(in_features=49, out_features=200, bias=True)\n",
" (1): LeakyReLU(negative_slope=0.01)\n",
" )\n",
" (project_edge2): Sequential(\n",
" (0): Dropout(p=0.2, inplace=False)\n",
" (1): Linear(in_features=400, out_features=1, bias=True)\n",
" (2): LeakyReLU(negative_slope=0.01)\n",
" )\n",
" (attentive_gru): AttentiveGRU1(\n",
" (edge_transform): Sequential(\n",
" (0): Dropout(p=0.2, inplace=False)\n",
" (1): Linear(in_features=200, out_features=200, bias=True)\n",
" )\n",
" (gru): GRUCell(200, 200)\n",
" )\n",
" )\n",
" (gnn_layers): ModuleList(\n",
" (0): GNNLayer(\n",
" (project_edge): Sequential(\n",
" (0): Dropout(p=0.2, inplace=False)\n",
" (1): Linear(in_features=400, out_features=1, bias=True)\n",
" (2): LeakyReLU(negative_slope=0.01)\n",
" )\n",
" (attentive_gru): AttentiveGRU2(\n",
" (project_node): Sequential(\n",
" (0): Dropout(p=0.2, inplace=False)\n",
" (1): Linear(in_features=200, out_features=200, bias=True)\n",
" )\n",
" (gru): GRUCell(200, 200)\n",
" )\n",
" )\n",
" )\n",
" (readouts): ModuleList(\n",
" (0): GlobalPool(\n",
" (compute_logits): Sequential(\n",
" (0): Linear(in_features=400, out_features=1, bias=True)\n",
" (1): LeakyReLU(negative_slope=0.01)\n",
" )\n",
" (project_nodes): Sequential(\n",
" (0): Dropout(p=0.2, inplace=False)\n",
" (1): Linear(in_features=200, out_features=200, bias=True)\n",
" )\n",
" (gru): GRUCell(200, 200)\n",
" )\n",
" (1): GlobalPool(\n",
" (compute_logits): Sequential(\n",
" (0): Linear(in_features=400, out_features=1, bias=True)\n",
" (1): LeakyReLU(negative_slope=0.01)\n",
" )\n",
" (project_nodes): Sequential(\n",
" (0): Dropout(p=0.2, inplace=False)\n",
" (1): Linear(in_features=200, out_features=200, bias=True)\n",
" )\n",
" (gru): GRUCell(200, 200)\n",
" )\n",
" )\n",
" (predict): Sequential(\n",
" (0): Dropout(p=0.2, inplace=False)\n",
" (1): Linear(in_features=200, out_features=1, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Make sure dropout is disabled even if this cell is run on its own.\n",
"model.eval()\n",
"all_pred = []\n",
"# No gradients are needed for prediction; skipping autograd saves memory.\n",
"with torch.no_grad():\n",
"    for test_data in test_loader:\n",
"        smi_lst, bg, labels, masks = test_data\n",
"        if torch.cuda.is_available():\n",
"            # NOTE(review): relies on the in-place behaviour of DGLGraph.to in\n",
"            # the DGL release this notebook targets - confirm before upgrading.\n",
"            bg.to(torch.device('cuda:0'))\n",
"        pred = model(bg, bg.ndata['hv'], bg.edata['he'])\n",
"        all_pred.append(pred.cpu().numpy())"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"res = np.vstack(all_pred)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(257, 1)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res.shape"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(257, 1)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_sol.numpy().shape"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'exp')"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEGCAYAAABsLkJ6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3df5Cc9X0f8PfnVivYs8daMchDtXBIcW0pEQKdOUCtkmEkY0SHWr6B2AqFKdN0orFdu5UGy5ECI4RLI41l17h16o6aMJPWDBYYcsZVEtlUclqTiHDqSZYVS4mJg2DlTOSiI7Fu4fZOn/5x+5z2nnu+z/N9dp9f+zzv1wyDdm/39rtY/n6e5/P9fD9fUVUQEVHx9KU9ACIiSgcDABFRQTEAEBEVFAMAEVFBMQAQERXUgrQHEMbVV1+ty5YtS3sYREQ95dixYz9T1SXu53sqACxbtgyjo6NpD4OIqKeIyGtezzMFRERUUAwAREQFxQBARFRQDABERAXFAEBEVFAMAEREBdVTZaBERL1kZKyOfYfO4Nx4A0urFWzfuALDg7W0hzWLAYCIKAYjY3XsfP4kGs1pAEB9vIGdz58EAOsgEHcAYQAgIorAyFgdu184hfFGEwDQJ8Al13ErjeY09h06YzWJRxFAgnANgIioSyNjdWx/9sTs5A/Mn/wd58YbVr9z36Ezs5O/wwkgUWEAICLq0r5DZ9A0zfguS6sVq9eZAoVtALHBFBARUQheeXnbSblSLmH7xhVWr11araDu8XttA4gNBgAiogDOpF8fb0AAONf6Tl5+UaU8J/3jJkDoRdztG1fMWQMAwgUQGwwAREQ+3Iux7kRPozmNK8vmbHq1UsbxR++0/qz2u4t7b67hyOnzrAIiIkqD12Ks2/hEEw+sHcDXj56d83y5T7B70yqrz/Gq+nnuWB177lkd294BLgITUaGMjNWxbu9hLN9xEOv2HsbIWN339Tb5/aXVCh4fXo0nNq9BrVqBAKhVK9j3sZusJ+8kqn7ceAdARIXRSW29aTHW0Z6XHx6sdXy1nkTVjxvvAIioMDq5yt6+cQUq5dKc56T171q1ElmKxlTdE2XVjxvvAIioMDq5ynYm9zhaMrQv+lb7yyj3yZz9BFFX/bgxABBRYXRaW99NasfEnY66MNFEuSSoVsp4q9FMpHkcAwARFYZfbX1SnTvb9xS4NacV77pigXXZaLe4BkBEhTE8WMOee1bPqdTZc89qAMDO50+iPt6AYmZxeOuB4xj8/HcCq4TCcK76/RaV41z0deMdABEVilc6Z93ew561/hcmmpF24LTZUxDnoq8b7wCIqPD8rrqjrMUPurqPe9HXjXcARFR4QbX+7RN3N2sFfp9TS+HEMN4BEFHhedX6t3PSMu05fGetYOfzJ63XCbw+p1Iu4YnNa/DSjg2JHxfJOwAiyi3bq3XnufYTvRztaRm/jWQ2k3ecewo6kWoAEJG7AHwFQAnA76rq3jTHQ0T54df2AfCehIcHa75BI4p2DXHsKehUagFAREoAfgfAhwG8AeAVEXlBVf8irTERUX6YrtYf+/YpvN28ZOwH5DdBJ3FIS5LSXAO4FcCPVfWvVXUSwDcAfDTF8RBRjpiuyi9MNDvuumnK4SdZuROlNFNANQCvtz1+A8Bt7heJyBYAWwBgYGAgmZERUc8Lquxxs0njhM3hJ7W7uFNpBgDxeG7eqcqquh/AfgAYGhqyO3WZiFKVhYnPq+2DH9s0jl+KqP17L6qUcXFyCs3pmWnLpvV00tIMAG8AuK7t8bUAzqU0FiKKSCc996P4TFPAcZ4PunrsNo3j/t5eZwSHqRhKQpoB4BUA7xeR5QDqAH4NwL9IcTxEFIFuSyXDCgo4zmeu23vYmBISr3xESDZtHoBke/0ESW0RWFWnAHwawCEAPwLwjKqeSms8RBSNpE+2sj3kZfvGFSiXvGd6VYTa0OXF9vtlqWIo1X
0AqvqHAP4wzTEQUbSiLpUMWk+wCTjO73Dy8V66vUuxWXTOWsUQW0EQUaSiLJW0ab0QdJSiTQtmR5iqITev713uEyzuL89pPZ2V/D/AVhBEFLEo2x0ErSeMjNUxMTk1731B7RtMSl0sBmStzYMNBgAiilxU7Q780jvuxV9HtVLG7k2rAts3eJnW7irNs9TmwQZTQESUWab0jgLYeuC455W9yOWr8ZGxOvpCXNXXMrRAmwTeARBRZoXdzAXMtHq4/7/9Gf7ip/+ACxPza/FNsrZAmwQGACLKrPa8epgF2pdefTPwNSIz5Z/A/LSRlyzsbo4aU0BElGnDgzW8tGODZ++YTlXKJbSn+9+ZuuT7+m4PgskqBgAi6glRbaAqiYTuBmq72azXMAAQUU/YvnEFyn3d3QdUyiVjpY/tmcA2z/cKBgAi6gnDgzXs+9hNqFbKs88t7i9j3fuusnp/tVLGnntWG2v9/fYABG0261UMAETUM4YHazj+6J34m71344nNa9C/cAH+9NU38a6F8w90d6bzWrWCJzavwfFH78TwYM14B+C3ByBvB8E4WAVERKkLW2Hj3gR2cXIa5ZLgXQsX4K1G0/g7RsbqEHgcPAL/PQC9uMvXhmiXO9+SNDQ0pKOjo2kPg4gi8sjISTx19Oy8CTloMje1di6J4Esfv8k4MZveJwC+vHlNz0/oJiJyTFWH3M/zDoCooNKua39k5CS+fvSs58+a0zp7oIrXgTKmxddpVWw9cBxbDxxHLUTnUEV2TulKEtcAiAooqbr2kbE61u09jOU7DmLd3sNzfv/TL7/u88653CWXNouvYTqHFq0FhIMBgKiAkqhrDwoyYRuvtadu/A53aef+TnldzO0UAwBRAXVb1+53Ze+IOshI63OBmXTNAss9AfXxxuxY9x06g3tvrqFWrWS2R3+SuAZAVECm06uq/WWPV89le+i7KZjUxxtYtuNg6DErMOfErkbTv32D+zOdfz93rF7oSb8d7wCICsiUQvn521OzV9mmq3zbK/s4NklFsfM2Dy0cosI7AKICGh6sYfcLp2YrbRzNSzo7OZqu8m3P4L1w8Z3Ix90eVBb3l0O1e27X6y0cosI7AKKCeqvhPXmeG2/4XuXbnMG7/ZsnMBEiReNY7JOCci/WPvqRVVYLwV4WVcqBaxhFwABAVFB+E7nfVb7p8POJySks33EQDz1zAs3p8BtMa9XKzKTusbjr9PFpz9sPD9aw71dvmrOg+8DagXljcyv3CS5OTuWutXMnmAIiKiiv07acq2zTASxLq5V5bRGuLPeh0bw0m47p9FzdickpjL72Jrwa/5v6tHmdwTt0/VVzNritX7kER06fn308MTk1L3XUftB8kTAAEBVUUH8bU3Bw3js8WMPIWB3bDhyPZDwXJpqebSGcn3lVGnkJOph9uaECqYjrAgwARAVmmixtm5/tO3TGc8LulN/viuoq3VQC2+utnTvBAEBEntqDg9M3aNuB43PSKmHO6Y1CFFfpfqmvomEAICJfXhu/TE3c/LjbMJvaMvux2agWJK+tnTvBAECUM1F3+fQqCQ2rUi7h3ptrs4uxiyplXHynibCVos5GtW4n66B1gqJgACDKEZs2DSNjdTz27VOzlTDVShm7N60yToidpl2cK3x3W+aRsTq2P3si9OQPXN6oZjt5p93yOutSCQAisg/ARwBMAngVwL9S1fE0xkKUJ0FtGrx2/443mtj+7AkA3hU2pkVTPwLg/rUDeHx4NYDLbSXOjTfQJ9JxqSgQrmGdTc+iIktrI9h3AdygqjcC+EsAO1MaB1Gu+DVg2/n8yXmTv6O9BYSb18avIArgyOnzAOa3hQ6a/J1NXe2Hv7ezrdZJouV1r0slAKjqd1R1qvXwKIBr0xgHUZ6MjNXRZ9gxVRIJzOM7bZPdzeC2HTiOKxb0YXF/2WuPlpETjMKsIdSqFfxk7914accG7N60qqve/d22vC6CLKwB/DqAA2kPgqiXOVfZXlfXlX
LJegJ27hRGX3sTzx2rz75vvNGcTesc/MFPrZqwOVfqtumjPsGcyb3bah3W+weLLQCIyIsArvH40cOq+q3Wax4GMAXgKZ/fswXAFgAYGBiIYaREvc90lV0SwZ57VhtbO3hpNKfx9MuvzwsmCuDrR8/C5hwW50p9ZKxuXe5Z8vjF3VTrsN4/WGwBQFXv8Pu5iDwI4J8D+JCqOSmoqvsB7AeAoaGhKDcdEuWGKa1xSXV2At0aomWDX57+UsD/CwWYbdy2bu9h61r/5nS4Cp8grPcPllYV0F0AfhPA7ao6kcYYiNIUdXliULrD1P8/DtX+8ux3CZtvjzo/z3p/f2mtAXwVwBUAviszi1ZHVfUTKY2FKFHdlie2B49qfxmql3P07Vfb7nTH7k2r5qVE4jDeWh9wFqXDlHwurVZYu5+gVAKAqv7jND6XKAv8yhODJjp38GhfjFX4b75yPrfUmpQ7acVgw5nETYvSJpVyCetXLmHtfoJ4IAxRwropTwwqqXQm/5d2bJgz+Tt1+MBMfr9SLuH+teGKKkqmpvxt2s8TCHOnsbh/5sCXI6fPs3Y/QQwARAkLOlLRj02QcL9m9wunPCfVI6fPGw9acSuJ4EsfvynwdffeXAtVcVSrVvDE5jUY23UnhgdrrN1PGAMAUcK8dtbalifaBIk+kdnNXI+MmHf/1scbsM3Q3HfbdRgerPme2bu4v4znjtWtJ38B5t2pmDaysXY/HgwARAkbHqxhzz2r55xl6z7v1s3ZlWszuU6rYvuzJ/DIyEk81UHbZi9OTx/TQezlPoEqQqV92if1oI1srN2PRxZ2AhMVTpjyRPfCr43mJTUerxhWrW2idsbs1U3U72jIoAqloI1sXACOBwMAUcZ12o8/qgqf9SuXzHlsCl6m3L9TkeRX2mmzkY2ixwBAlBGm+ve0F0APvPI6hq6/KnAi9mu9EHTHw7496eAaAFEGuFsmO/XvI2P11CfB5rTioWdOYPmOg3O6hbp1srbh6GZhnDonPm14MmdoaEhHR0fTHgZRJNqv+E07Zp30SRI7eG1VyqVI8/LOf4f6eGN2k5p7Ixt1R0SOqeqQ+3neARClwPaQlHPjjTlX1gCsunHGKcqNWaZNapz8k8EAQJQC24Vdd2+cWrWC//jxNXMqc6LQXw43FUS1LsFTu9LFAECUApsJtFyS2d447rWB9SuXhD6mEcC8E70EwANrB9AIeUJ7VOsS3PmbLgYAohTYTKDT04qDP/ipsY3DnntWG8/NdVv3vqtQq1aguNzTp1at4Mub1+Dx4dWhJvQoF2e7aYtB3WMAIErB+pVLAs/XvQQYj150rpDfmbK7cv/TV9/0zbMHHfzujDVMZY8NVv+ki/sAiBI2MlbHc8fqXW3UWlQp46FnTli3W3a/yt1+2n16lnPOwFuNZqw9+XlqV7oYAIgS1unO3nZRnOzlzrOndXoWT+1KDwMAUcLiXuB0990xHfxS7S9j3d7DODfewKJKGSIzp3nxKrw4uAZAlLA4Fzidg17ad+Pev3ZgXp69XBL8/O2p2eqi8UYTFyaa83YhuzldSYN2BVNv4B0AUcLi2tnbJzAu0A5df9WcPPvFd6Z800heR1R2e5YxZQ8DAFEKriz3RR4A/I5sdOfZl+84GPj73Kmqbs4ypmxiCohyL0tpC+cq2lTe2Y3mJbXeQWuThnK/hpu28ocBgHLNr8tmGrzO5zWpVsqhd/vaTsbuHv9uXrX43LSVPwwAlGtx9Jrp9I5iZKweqnzzrUZzXnvlJzavwROb1xjTPbaT8ZHT540/M2324qat/OEaAOVaUNrCdAiLSacLoSNjdTz0zIlQY+8TwbYDx7G01bLB/ftNh6/YMP13cQ5q98JNW/nDAEC55nfSVCeTeZiF0PY+96ZafD/OLl+vcXU7GXd6Ahc3beULD4ShXPM6UN050MR0hm1JBJdUPSfV5TsOek7kAuAne+/2/dwgi/vLGJ9o+h4OY7o6D8vvvwsn+PwxHQjDOw
DqKWFTNn5XytsOHPd8j9+Vt+nKuX9haXZX7dJqBROTU6HLPPsXLsDYrjuNJZpRVtswnUOAZQAQkSsBfArAL2PmTvb7AL6mqm/HODaiOTrNv5vSFqbJvJ07vbN94wps/+YJNKfnXqFfnJzGxcnG7Lg64UzwSR2QznQO2VYB/XcAqwD8ZwBfBfCLAP5HXIMi8hJ1RU9QC2RH+5X38GAN71oYz42zM8Gz2oaSYvs3eYWq3tT2+IiIhCtpIOpS1BuR3GkQU+7dfeX9VoedOJ2F4P5yHyZcJ3C1T/BMz1BSbAPAmIisVdWjACAitwF4qdsPF5HPAtgHYImq/qzb30f5FkdqpD0NYloY9doQFTbNU3NN4kFrGUzPUBJsA8BtAP6liJxtPR4A8CMROQlAVfXGsB8sItcB+DCAs0GvJQK8m6hFmRqxvfJev3IJnjp61rqss7/cN696hxM8ZYFtALgrhs/+MoDPAfhWDL+bciiJ1EjQxNzJaV4TzUsYGat7/t6wVU3dvo+onW0AeL+qvtj+hIg8qKq/38mHisgmAHVVPSE+HQxbr90CYAsADAwMdPJxlCNxXDn7Tabun5nKO4M2epk2ivlVNZnGxbbMFBXbALBLRO4F8FkA7wbwuwDeAWAMACLyIoBrPH70MIDfAnCnzQer6n4A+4GZjWCW4yWy4jeZApj3M5Ogv5heC9VBVU2mcbEtM0XFNgDcDuAhAM7OmV2q+rTfG1T1Dq/nRWQ1gOUAnKv/awH8XxG5VVX/1nI8RJEwTaZbDxxHyVAV5GVxf9m3xbPXQrVfVZPfJM+2zBQV230AizGzEPwqZq78r5eg3I2Bqp5U1feq6jJVXQbgDQAf5ORPafCbNG0nf+d4RRMBPBeq/dor+03ybMtMUbENAEcB/JGq3gXgFgBLEUEZKFHaup00SyJY0CdoXvIOFgLg/rUDnqmZ9SuXwH0V5VQ1+U3y3ChGUbFNAd0B4HYR2aWqnxeRLwJYFsUAWncBlDNZr1LptlOnY1oVjab53V5tnJ3Pd1cTCYB7b768yG0qeeVGMYqKbQDYCeASgA0APg/gHwB8CTN3A0RzZL1KxT2+uCoLatVKqLbSissHtQRN8txHQFGw3gimqh8UkTEAUNULIrIwxnFRD8t6lYrX+KIWlJKxWcjlJE9xs10DaIpICa2LJRFZgpk7AqJ5sl6lksQ4gvrqcyGXssA2APwnAH8A4L0i8h8w0w76t2MbFfW0rE9ucY/DnfrxOkOYC7mUBVYBQFWfwkzbhj0AfgpgWFWfjXNg1LuyPrnZtoHuhPt7OusN9fEGFHPXQ9wHvvM0LkqadWNzVT0N4HSMY6GcyHqVSvv4grp6Lu4v4+4b/xGOnD7v+1oBPL+n33rISzs2ZOa/CRUTj4SkWGR9AbN9fMsMRzAKgLFdlzuWrNt72DMI+J3Vm/X1ECo22zUAoszzyrXbqFmuWXSS2sr6eggVGwMA5YIp1+4EAb/gYDuxDw/WQufts74eQsXGFBDlQqedNdtTQTZrFn6pLb/dz1ldD6FiE7VseJUFQ0NDOjo6mvYwKIOW7zjouaPXWZwNm7sPy3ScJCt7KAtE5JiqDrmf5x0A5YLfecG2C7Gd9C9q7ynklqXdz0ReuAZAudBpZ01H0BqCl/b3mLDah7KMAYB6XlBnTZuF2KA1BC82PYVY7UNZxgBAPc+vs+bIWB27Xzg15+fVSnlebr6Tev2gq3tW+1DWcQ2AMidsLt40EdfHG9j+7Il5h7VcnJx/epffGoKJ6T3AzAIzq30o63gHQJkSNhc/MlZHn+F00pJ4n9TVnNZ5qZ1O6vVN73li8xq2eaCewABAmRImF+8EC6+zeyvlku+Zvu67hk42eXXyHqIsYQqIMiVMLt60CFsSwZ57Vvs2e3NXALWnnEzHOHrJes8jIj+8A6BMCdM7xxQsLqlieLCG9SuXeP68TzCb2vFKOW07cByPjJzs7AsQ9RAGAMoUr7x6uU8wMTk1r49PULBwzt
d1W1Qpz2nR4FVB9NTRs9bN5Ih6FQMAZYo7r16tlAEBLkw05y0KBy3cmu4Qxieas382vUYB3z0ARHnANQDKnPa8+rq9hzHeaM75eaM5jce+fQr9Cxeg0ZxGSQTTqvNKL21KO/1KObmLl/KOdwCUaaZJ+MJEc3binladvfIfHqzNtn6ujzeM7SEcnfTyJ8oL3gFQpvldobcztX5WzLSFUJg3Z5X6BNOu/QLlPuEuXso93gFQpoU5wP3ceMO4qOu0fnZP/vsOnZk3+QPAu69cwPJOyj3eAVCmOZPw1gPHA18bpvVz0PPtC8VEecU7AMq84cGa8dxeR5jWz908T5QnDADUE7xSQc4Cb3sLhrA9fXhmLxVZaikgEfkMgE8DmAJwUFU/l9ZYKPtsz9YNewYvz+ylIkvlTGARWQ/gYQB3q+o7IvJeVf27oPfxTODs6uQ4RSJKRtbOBP4kgL2q+g4A2Ez+lF3uA9Gd3boAGASIMiytNYAPAPgVEXlZRP5ERG4xvVBEtojIqIiMnj/v3duF0tXJcYpElL7Y7gBE5EUA13j86OHW5y4GsBbALQCeEZFfUI98lKruB7AfmEkBxTVeCqc95WP6H4WtFIiyLbYAoKp3mH4mIp8E8Hxrwv9zEbkE4GoAvMTvAe6UjwlLKYmyLa01gBEAGwB8T0Q+AGAhgJ+lNBZqsV3IdR+y7oWtFIiyL60A8CSAJ0XkhwAmATzolf6h5Ngu5I6M1ed15/TkfUxv12NkpRFRdFJZBFbVSVV9QFVvUNUPqurhNMZBl9ku5Nou7HodvN6NsIfFE1Ew7gQmAPa9csIs7Ea5CMxKI6LoMQAQAPueOGEWdqNcBA7bzI2IgjEAEAD7njieZ/aWBOW+uUl/AYyHsneCTduIoscAkGPOyVjuw9S9uM/ibW+wFvS6fb96Ezbfet2cdV8F8NyxunWOPmisbNpGFL1UegF1ir2A7HnV6lfKJc9JPQrOEYxuzkEsUYyVVUBEnclaLyCKmd+iaRyTZjc5etuxth8WT0TdYwoop5JeNO0mR88FXqJ0MADklO2EHGadwE83OXou8BKlgwEgp2wm5Cg3V9kuInc6ViKKHtcAcsrmpKuo1wk6zdHzVC6idDAA5FjQhJyl3DsXeImSxxRQgTH3TlRsDAAFxtw7UbExBVRwVyzom10HWNxfxqMfWRUqFcPNWUS9iwGgoLx2377dvOT7evdED4CHwRP1MKaACipMe2VTuehj355/MhhbNBP1Dt4BFFSYCiBTsDAdC8kdvES9gQGgoJZWK57N27wqgMJO6J1UEXEtgSh5TAEVVJgKINOEXq2UI6ki4nGPROlgACioMK0bTMFi96ZVHbd/aMfjHonSwRRQgdnuvg1q1dBtqiZLO5KJioQBgKzE2aohzHoEEUWHKSBKHXckE6WDdwAJYqWLN3YDJUoHA0BC3DtvuWt2LnYDJUoeU0AJYaULEWUNA0BCWOlCRFnDAJAQ9t4noqzhGkBCtm9cMa/7ZqVcwvqVS7Bu7+FQi5+PjJzE0y+/jmlVlERw323X4fHh1XF/BSLKmVQCgIisAfBfAVwJYArAp1T1z9MYS1K8Kl3Wr1yC547VQy0MPzJyEl8/enb28bTq7GMGASIKI60U0BcAPKaqawDsaj3ONa8S0COnz4deGH765ddDPU9EZJJWCkgBvKf150UAzqU0jkSYSkA7aac8rRrq+aBxsfaeqLjSCgBbARwSkS9i5i7kn6Y0jkSYSkBLIp4Tt9/CsOk9JZFQY+K+BCKKLQUkIi+KyA89/vkogE8C2Kaq1wHYBuD3fH7PFhEZFZHR8+fPxzXcWJmu6KdVQ7dAuO+260I9b8J9CUQUWwBQ1TtU9QaPf74F4EEAz7de+iyAW31+z35VHVLVoSVLlsQ13FiZruid9slh2ik/PrwaD6wdmL3iL4nggbUDoReAuS+BiNJKAZ0DcDuA7wHYAOCvUhpHIkwloE7OPWzK5fHh1V1X/LADJxGlVQX0Gw
C+JCInAPw2gC0pjSMRYQ5fSQo7cBKRaAfVI2kZGhrS0dHRtIeRG6wCIioGETmmqkPu57kTuMDYgZOo2NgLiIiooBgAiIgKiimgAuMaAFGxMQAUFHcCExFTQAXFncBExABQUNwJTEQMAAXFE8qIiAGgoLgTmIi4CFxQXieUsQqIqFgYAAqMO4GJio0pICKigmIAICIqKAYAIqKCyv0aANsdEBF5y3UAYLsDIiKzXKeA2O6AiMgs1wGA7Q6IiMxyHQDY7oCIyCzXAYDtDoiIzHK9CMx2B0REZrkOAADbHRARmeQ6BURERGYMAEREBcUAQERUUAwAREQFxQBARFRQoqppj8GaiJwH8Fra4whwNYCfpT2IhBXtOxft+wL8zr3uelVd4n6ypwJALxCRUVUdSnscSSrady7a9wX4nfOKKSAiooJiACAiKigGgOjtT3sAKSjady7a9wX4nXOJawBERAXFOwAiooJiACAiKigGgBiIyBoROSoix0VkVERuTXtMcRORz4jIGRE5JSJfSHs8SRGRz4qIisjVaY8lbiKyT0ROi8gPROQPRKSa9pjiICJ3tf4u/1hEdqQ9njgxAMTjCwAeU9U1AHa1HueWiKwH8FEAN6rqKgBfTHlIiRCR6wB8GMDZtMeSkO8CuEFVbwTwlwB2pjyeyIlICcDvAPhnAH4JwH0i8kvpjio+DADxUADvaf15EYBzKY4lCZ8EsFdV3wEAVf27lMeTlC8D+Bxm/vfOPVX9jqpOtR4eBXBtmuOJya0Afqyqf62qkwC+gZmLm1xiAIjHVgD7ROR1zFwN5+5KyeUDAH5FRF4WkT8RkVvSHlDcRGQTgLqqnkh7LCn5dQB/lPYgYlAD8Hrb4zdaz+VS7k8Ei4uIvAjgGo8fPQzgQwC2qepzIvJxAL8H4I4kxxe1gO+7AMBiAGsB3ALgGRH5Be3xGuOA7/xbAO5MdkTx8/vOqvqt1mseBjAF4Kkkx5YQ8Xiup/8e++E+gBiIyFsAqqqqIiIA3lLV9wS9r1eJyB9jJgX0vdbjVwGsVdXzqQ4sJiKyGsD/AjDReupazKT5blXVv01tYAkQkQcBfALAh1R1Iuj1vUZE/gmA3aq6sfV4JwCo6p5UBxYTpoDicQ7A7ZLvJCwAAAIhSURBVK0/bwDwVymOJQkjmPmeEJEPAFiI/HRRnEdVT6rqe1V1maouw0ya4IMFmPzvAvCbADblcfJveQXA+0VkuYgsBPBrAF5IeUyxYQooHr8B4CsisgDA2wC2pDyeuD0J4EkR+SGASQAP9nr6hzx9FcAVAL47c2OLo6r6iXSHFC1VnRKRTwM4BKAE4ElVPZXysGLDFBARUUExBUREVFAMAEREBcUAQERUUAwAREQFxQBARFRQDABECRKRn6c9BiIHAwBRl1odJIl6DgMAkQ8RWdbqgf/7rT743xSRfhH5GxHZJSLfB/AxEXmfiPyxiBwTkf8jIitb718uIn8mIq+IyL9P+esQzcEAQBRsBYD9rT74fw/gU63n31bVX1bVb2DmAPHPqOrNAD4L4L+0XvMVAF9T1VsA5LpVBPUe7gQm8iEiywD8b1UdaD3eAODfAlgD4HZVfU1E3g3gPIAzbW+9QlV/UUT+H4BrVLUpIu8BcE5V353olyAyYC8gomDuqyTn8cXWv/sAjLdOgLN5P1EmMAVEFGyg1SYYAO4D8P32H6rq3wP4iYh8DABkxk2tH7+EmY6SAHB/EoMlssUAQBTsRwAeFJEfALgKwNc8XnM/gH8tIicAnMLlYwT/HYB/IyKvYOZ4UKLM4BoAkY/WGsD/VNUbUh4KUeR4B0BEVFC8AyAiKijeARARFRQDABFRQTEAEBEVFAMAEVFBMQAQERXU/wf7rLfcvyy+hwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.clf()\n",
"plt.scatter(res, test_sol)\n",
"plt.xlabel('pred')\n",
"plt.ylabel('exp')"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import r2_score"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9098691301661277\n"
]
}
],
"source": [
"print(r2_score(test_sol, res))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestRegressor\n",
"from rdkit import Chem\n",
"from rdkit.Chem import AllChem\n",
"from rdkit.Chem import DataStructs"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"train_fp = [AllChem.GetMorganFingerprintAsBitVect(mol,2) for mol in train_mols]\n",
"test_fp = [AllChem.GetMorganFingerprintAsBitVect(mol,2) for mol in test_mols]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"rfr = RandomForestRegressor()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/takayuki/anaconda3/envs/chemo37/lib/python3.7/site-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.\n",
" \"10 in version 0.20 to 100 in 0.22.\", FutureWarning)\n",
"/home/takayuki/anaconda3/envs/chemo37/lib/python3.7/site-packages/ipykernel_launcher.py:1: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
},
{
"data": {
"text/plain": [
"RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,\n",
" max_features='auto', max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=10,\n",
" n_jobs=None, oob_score=False, random_state=None,\n",
" verbose=0, warm_start=False)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rfr.fit(train_fp, train_sol)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"rfr_pred = rfr.predict(test_fp)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6846847631483519"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"r2_score(test_sol, rfr_pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment