Skip to content

Instantly share code, notes, and snippets.

@iwatobipen
Created March 28, 2023 12:51
Show Gist options
  • Save iwatobipen/5b3101e24102457d08aa138d97aa69b5 to your computer and use it in GitHub Desktop.
Save iwatobipen/5b3101e24102457d08aa138d97aa69b5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "509df0a0",
"metadata": {
"scrolled": true
},
"source": [
"!wget https://raw.githubusercontent.com/rdkit/rdkit/master/Docs/Book/data/solubility.train.sdf\n",
"!wget https://raw.githubusercontent.com/rdkit/rdkit/master/Docs/Book/data/solubility.test.sdf"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7c763163",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"from rdkit import Chem\n",
"from torch_geometric.utils import smiles as pygsmi\n",
"import torch\n",
"import torch_geometric\n",
"import torch.nn.functional as F\n",
"from torch.nn import Linear\n",
"from torch.nn import BatchNorm1d\n",
"from torch.utils.data import Dataset\n",
"from torch_geometric.nn import GCNConv\n",
"from torch_geometric.nn import ChebConv\n",
"from torch_geometric.nn import global_add_pool, global_mean_pool\n",
"from torch_geometric.loader import DataLoader\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ce901cce",
"metadata": {},
"outputs": [],
"source": [
"train_mols = [m for m in Chem.SDMolSupplier('solubility.train.sdf')]\n",
"test_mols = [m for m in Chem.SDMolSupplier('solubility.test.sdf')]\n",
"sol_cls_dict = {'(A) low':0, '(B) medium':1, '(C) high':2}\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "849925b3",
"metadata": {},
"outputs": [],
"source": [
"n_features = 9\n",
"# define the network\n",
"class Net(torch.nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
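" self.conv1 = GCNConv(n_features, 128, cached=False) # cached=True would reuse the normalized adjacency computed on the first call; that is only valid when every forward pass sees the same graph, so keep it False for mini-batches of different molecules\n",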
" self.bn1 = BatchNorm1d(128)\n",
" self.conv2 = GCNConv(128, 64, cached=False)\n",
" self.bn2 = BatchNorm1d(64)\n",
" self.fc1 = Linear(64, 64)\n",
" self.bn3 = BatchNorm1d(64)\n",
" self.fc2 = Linear(64, 64)\n",
" self.fc3 = Linear(64, 3)\n",
" \n",
" def forward(self, data):\n",
" x, edge_index = data.x, data.edge_index\n",
" x = F.relu(self.conv1(x, edge_index))\n",
" x = self.bn1(x)\n",
" x = F.relu(self.conv2(x, edge_index))\n",
" x = self.bn2(x)\n",
" x = global_add_pool(x, data.batch)\n",
" x = F.relu(self.fc1(x))\n",
" x = self.bn3(x)\n",
" x = F.relu(self.fc2(x))\n",
" x = F.dropout(x, p=0.2, training=self.training)\n",
" x = self.fc3(x)\n",
" x = F.log_softmax(x, dim=1)\n",
" return x "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8aba9d3e",
"metadata": {},
"outputs": [],
"source": [
"train_X = [pygsmi.from_smiles(Chem.MolToSmiles(m)) for m in train_mols]\n",
"for i, data in enumerate(train_X):\n",
" y = sol_cls_dict[train_mols[i].GetProp('SOL_classification')]\n",
" data.y = torch.tensor([y], dtype=torch.long)\n",
" data.x = data.x.float()\n",
"\n",
"test_X = [pygsmi.from_smiles(Chem.MolToSmiles(m)) for m in test_mols]\n",
"for i, data in enumerate(test_X):\n",
" y = sol_cls_dict[test_mols[i].GetProp('SOL_classification')]\n",
" data.y = torch.tensor([y], dtype=torch.long)\n",
" data.x = data.x.float()\n",
"train_loader = DataLoader(train_X, batch_size=64, shuffle=True, drop_last=True)\n",
"test_loader = DataLoader(test_X, batch_size=64, shuffle=True, drop_last=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f815e91",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5ef858b0",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model = Net().to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9ef2dcd9",
"metadata": {},
"outputs": [],
"source": [
"model = torch_geometric.compile(model)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "0b3acac8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1, Train loss: 0.867, Train_acc: 0.522, Test_acc: 0.537\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2023-03-28 21:48:02,019] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'forward' (/tmp/iwatobipen_pyg/tmp_e83a8op.py:217)\n",
" reasons: self.training == True\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:48:02,730] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'forward' (/tmp/iwatobipen_pyg/tmpp4kgewmy.py:217)\n",
" reasons: self.training == True\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:48:04,160] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'gcn_norm' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py:43)\n",
" reasons: num_nodes == 863\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:48:04,164] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'add_remaining_self_loops' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/utils/loop.py:300)\n",
" reasons: num_nodes == 863\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 2, Train loss: 0.734, Train_acc: 0.498, Test_acc: 0.482\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2023-03-28 21:48:11,312] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'global_add_pool' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/nn/pool/glob.py:8)\n",
" reasons: tensor 'x' size mismatch at index 0. expected 863, actual 879\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:48:23,685] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'forward' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/nn/dense/linear.py:127)\n",
" reasons: tensor 'x' size mismatch at index 0. expected 774, actual 780\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:48:27,829] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'scatter' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/utils/scatter.py:23)\n",
" reasons: dim == 0\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 3, Train loss: 0.608, Train_acc: 0.54, Test_acc: 0.537\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2023-03-28 21:48:38,694] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'propagate' (/tmp/iwatobipen_pyg/tmp_e83a8op.py:178)\n",
" reasons: tensor 'x' size mismatch at index 0. expected 774, actual 815\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:48:39,004] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'propagate' (/tmp/iwatobipen_pyg/tmpp4kgewmy.py:178)\n",
" reasons: tensor 'x' size mismatch at index 0. expected 774, actual 815\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:48:46,226] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'broadcast' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/utils/scatter.py:18)\n",
" reasons: tensor 'ref' size mismatch at index 0. expected 766, actual 922\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 4, Train loss: 0.569, Train_acc: 0.484, Test_acc: 0.463\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2023-03-28 21:48:59,980] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'aggregate' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/nn/conv/message_passing.py:595)\n",
" reasons: dim_size == 815\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:49:00,516] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'message' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py:240)\n",
" reasons: tensor 'x_j' size mismatch at index 0. expected 2481, actual 2369\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 5, Train loss: 0.527, Train_acc: 0.51, Test_acc: 0.49\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2023-03-28 21:49:14,022] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: '_collect' (/tmp/iwatobipen_pyg/tmp_e83a8op.py:115)\n",
" reasons: self.training == False\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:49:14,124] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: '__call__' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/nn/aggr/base.py:86)\n",
" reasons: self.training == False\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:49:14,244] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: '_collect' (/tmp/iwatobipen_pyg/tmpp4kgewmy.py:115)\n",
" reasons: self.training == False\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 6, Train loss: 0.557, Train_acc: 0.585, Test_acc: 0.553\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2023-03-28 21:49:26,379] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'forward' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/nn/aggr/basic.py:18)\n",
" reasons: dim_size == 853\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 7, Train loss: 0.508, Train_acc: 0.726, Test_acc: 0.743\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2023-03-28 21:49:37,808] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: 'reduce' (/home/iwatobipen/miniconda3/envs/torch2/lib/python3.10/site-packages/torch_geometric/nn/aggr/base.py:146)\n",
" reasons: dim_size == 809\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:49:38,043] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: '_lift' (/tmp/iwatobipen_pyg/tmp_e83a8op.py:87)\n",
" reasons: tensor 'src' size mismatch at index 0. expected 822, actual 836\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n",
"[2023-03-28 21:49:38,045] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (64)\n",
" function: '_lift' (/tmp/iwatobipen_pyg/tmpp4kgewmy.py:87)\n",
" reasons: tensor 'src' size mismatch at index 0. expected 822, actual 836\n",
"to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 8, Train loss: 0.563, Train_acc: 0.519, Test_acc: 0.494\n",
"Epoch: 9, Train loss: 0.526, Train_acc: 0.647, Test_acc: 0.634\n",
"Epoch: 10, Train loss: 0.485, Train_acc: 0.809, Test_acc: 0.805\n",
"Epoch: 11, Train loss: 0.513, Train_acc: 0.542, Test_acc: 0.545\n",
"Epoch: 12, Train loss: 0.487, Train_acc: 0.698, Test_acc: 0.689\n",
"Epoch: 13, Train loss: 0.466, Train_acc: 0.628, Test_acc: 0.634\n",
"Epoch: 14, Train loss: 0.468, Train_acc: 0.532, Test_acc: 0.549\n",
"Epoch: 15, Train loss: 0.495, Train_acc: 0.617, Test_acc: 0.599\n",
"Epoch: 16, Train loss: 0.437, Train_acc: 0.53, Test_acc: 0.525\n",
"Epoch: 17, Train loss: 0.463, Train_acc: 0.669, Test_acc: 0.654\n",
"Epoch: 18, Train loss: 0.505, Train_acc: 0.687, Test_acc: 0.634\n",
"Epoch: 19, Train loss: 0.478, Train_acc: 0.715, Test_acc: 0.759\n",
"Epoch: 20, Train loss: 0.449, Train_acc: 0.645, Test_acc: 0.65\n",
"Epoch: 21, Train loss: 0.477, Train_acc: 0.803, Test_acc: 0.786\n",
"Epoch: 22, Train loss: 0.438, Train_acc: 0.734, Test_acc: 0.712\n",
"Epoch: 23, Train loss: 0.406, Train_acc: 0.773, Test_acc: 0.759\n",
"Epoch: 24, Train loss: 0.415, Train_acc: 0.793, Test_acc: 0.759\n",
"Epoch: 25, Train loss: 0.481, Train_acc: 0.778, Test_acc: 0.735\n",
"Epoch: 26, Train loss: 0.476, Train_acc: 0.58, Test_acc: 0.584\n",
"Epoch: 27, Train loss: 0.464, Train_acc: 0.573, Test_acc: 0.654\n",
"Epoch: 28, Train loss: 0.474, Train_acc: 0.74, Test_acc: 0.724\n",
"Epoch: 29, Train loss: 0.424, Train_acc: 0.77, Test_acc: 0.755\n",
"Epoch: 30, Train loss: 0.422, Train_acc: 0.611, Test_acc: 0.595\n",
"Epoch: 31, Train loss: 0.457, Train_acc: 0.751, Test_acc: 0.743\n",
"Epoch: 32, Train loss: 0.408, Train_acc: 0.592, Test_acc: 0.564\n",
"Epoch: 33, Train loss: 0.412, Train_acc: 0.746, Test_acc: 0.759\n",
"Epoch: 34, Train loss: 0.446, Train_acc: 0.674, Test_acc: 0.65\n",
"Epoch: 35, Train loss: 0.435, Train_acc: 0.674, Test_acc: 0.63\n",
"Epoch: 36, Train loss: 0.444, Train_acc: 0.714, Test_acc: 0.712\n",
"Epoch: 37, Train loss: 0.438, Train_acc: 0.651, Test_acc: 0.677\n",
"Epoch: 38, Train loss: 0.446, Train_acc: 0.635, Test_acc: 0.658\n",
"Epoch: 39, Train loss: 0.395, Train_acc: 0.789, Test_acc: 0.751\n",
"Epoch: 40, Train loss: 0.423, Train_acc: 0.507, Test_acc: 0.553\n",
"Epoch: 41, Train loss: 0.467, Train_acc: 0.642, Test_acc: 0.646\n",
"Epoch: 42, Train loss: 0.417, Train_acc: 0.719, Test_acc: 0.732\n",
"Epoch: 43, Train loss: 0.374, Train_acc: 0.836, Test_acc: 0.805\n",
"Epoch: 44, Train loss: 0.382, Train_acc: 0.826, Test_acc: 0.79\n",
"Epoch: 45, Train loss: 0.435, Train_acc: 0.745, Test_acc: 0.739\n",
"Epoch: 46, Train loss: 0.401, Train_acc: 0.782, Test_acc: 0.759\n",
"Epoch: 47, Train loss: 0.397, Train_acc: 0.566, Test_acc: 0.568\n",
"Epoch: 48, Train loss: 0.391, Train_acc: 0.762, Test_acc: 0.728\n",
"Epoch: 49, Train loss: 0.415, Train_acc: 0.739, Test_acc: 0.704\n",
"Epoch: 50, Train loss: 0.378, Train_acc: 0.784, Test_acc: 0.763\n",
"Epoch: 51, Train loss: 0.431, Train_acc: 0.818, Test_acc: 0.774\n",
"Epoch: 52, Train loss: 0.412, Train_acc: 0.647, Test_acc: 0.65\n",
"Epoch: 53, Train loss: 0.394, Train_acc: 0.639, Test_acc: 0.638\n",
"Epoch: 54, Train loss: 0.385, Train_acc: 0.784, Test_acc: 0.759\n",
"Epoch: 55, Train loss: 0.406, Train_acc: 0.751, Test_acc: 0.786\n",
"Epoch: 56, Train loss: 0.36, Train_acc: 0.675, Test_acc: 0.681\n",
"Epoch: 57, Train loss: 0.404, Train_acc: 0.821, Test_acc: 0.751\n",
"Epoch: 58, Train loss: 0.378, Train_acc: 0.792, Test_acc: 0.778\n",
"Epoch: 59, Train loss: 0.416, Train_acc: 0.508, Test_acc: 0.506\n",
"Epoch: 60, Train loss: 0.385, Train_acc: 0.759, Test_acc: 0.743\n",
"Epoch: 61, Train loss: 0.396, Train_acc: 0.778, Test_acc: 0.735\n",
"Epoch: 62, Train loss: 0.388, Train_acc: 0.832, Test_acc: 0.786\n",
"Epoch: 63, Train loss: 0.37, Train_acc: 0.85, Test_acc: 0.802\n",
"Epoch: 64, Train loss: 0.395, Train_acc: 0.587, Test_acc: 0.591\n",
"Epoch: 65, Train loss: 0.395, Train_acc: 0.817, Test_acc: 0.782\n",
"Epoch: 66, Train loss: 0.378, Train_acc: 0.78, Test_acc: 0.747\n",
"Epoch: 67, Train loss: 0.363, Train_acc: 0.72, Test_acc: 0.708\n",
"Epoch: 68, Train loss: 0.331, Train_acc: 0.786, Test_acc: 0.763\n",
"Epoch: 69, Train loss: 0.345, Train_acc: 0.728, Test_acc: 0.716\n",
"Epoch: 70, Train loss: 0.336, Train_acc: 0.609, Test_acc: 0.591\n",
"Epoch: 71, Train loss: 0.34, Train_acc: 0.684, Test_acc: 0.646\n",
"Epoch: 72, Train loss: 0.384, Train_acc: 0.778, Test_acc: 0.782\n",
"Epoch: 73, Train loss: 0.367, Train_acc: 0.706, Test_acc: 0.704\n",
"Epoch: 74, Train loss: 0.414, Train_acc: 0.764, Test_acc: 0.728\n",
"Epoch: 75, Train loss: 0.397, Train_acc: 0.798, Test_acc: 0.759\n",
"Epoch: 76, Train loss: 0.378, Train_acc: 0.784, Test_acc: 0.72\n",
"Epoch: 77, Train loss: 0.355, Train_acc: 0.767, Test_acc: 0.728\n",
"Epoch: 78, Train loss: 0.405, Train_acc: 0.66, Test_acc: 0.661\n",
"Epoch: 79, Train loss: 0.376, Train_acc: 0.555, Test_acc: 0.541\n",
"Epoch: 80, Train loss: 0.383, Train_acc: 0.635, Test_acc: 0.619\n",
"Epoch: 81, Train loss: 0.39, Train_acc: 0.686, Test_acc: 0.665\n",
"Epoch: 82, Train loss: 0.398, Train_acc: 0.55, Test_acc: 0.545\n",
"Epoch: 83, Train loss: 0.381, Train_acc: 0.675, Test_acc: 0.704\n",
"Epoch: 84, Train loss: 0.38, Train_acc: 0.704, Test_acc: 0.704\n",
"Epoch: 85, Train loss: 0.374, Train_acc: 0.788, Test_acc: 0.743\n",
"Epoch: 86, Train loss: 0.391, Train_acc: 0.802, Test_acc: 0.774\n",
"Epoch: 87, Train loss: 0.368, Train_acc: 0.732, Test_acc: 0.732\n",
"Epoch: 88, Train loss: 0.382, Train_acc: 0.78, Test_acc: 0.743\n",
"Epoch: 89, Train loss: 0.374, Train_acc: 0.57, Test_acc: 0.556\n",
"Epoch: 90, Train loss: 0.35, Train_acc: 0.766, Test_acc: 0.747\n",
"Epoch: 91, Train loss: 0.359, Train_acc: 0.78, Test_acc: 0.763\n",
"Epoch: 92, Train loss: 0.327, Train_acc: 0.805, Test_acc: 0.755\n",
"Epoch: 93, Train loss: 0.312, Train_acc: 0.853, Test_acc: 0.782\n",
"Epoch: 94, Train loss: 0.389, Train_acc: 0.643, Test_acc: 0.638\n",
"Epoch: 95, Train loss: 0.411, Train_acc: 0.829, Test_acc: 0.774\n",
"Epoch: 96, Train loss: 0.383, Train_acc: 0.652, Test_acc: 0.642\n",
"Epoch: 97, Train loss: 0.342, Train_acc: 0.741, Test_acc: 0.724\n",
"Epoch: 98, Train loss: 0.373, Train_acc: 0.856, Test_acc: 0.798\n",
"Epoch: 99, Train loss: 0.357, Train_acc: 0.736, Test_acc: 0.747\n",
"Epoch: 100, Train loss: 0.385, Train_acc: 0.552, Test_acc: 0.514\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fb6fb144f10>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def train(epoch):\n",
" model.train()\n",
" loss_all = 0\n",
" for data in train_loader:\n",
" data = data.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data)\n",
" loss = F.nll_loss(output, data.y)\n",
" loss.backward()\n",
" loss_all += loss.item() * data.num_graphs\n",
" optimizer.step()\n",
" return loss_all / len(train_X)\n",
"def test(loader):\n",
" model.eval()\n",
" correct = 0\n",
" for data in loader:\n",
" data = data.to(device)\n",
" output = model(data)\n",
" pred = output.max(dim=1)[1]\n",
" correct += pred.eq(data.y).sum().item()\n",
" return correct / len(loader.dataset)\n",
"hist = {\"loss\":[], \"acc\":[], \"test_acc\":[]}\n",
"for epoch in range(1, 101):\n",
" train_loss = train(epoch)\n",
" train_acc = test(train_loader)\n",
" test_acc = test(test_loader)\n",
" hist[\"loss\"].append(train_loss)\n",
" hist[\"acc\"].append(train_acc)\n",
" hist[\"test_acc\"].append(test_acc)\n",
" print(f'Epoch: {epoch}, Train loss: {train_loss:.3}, Train_acc: {train_acc:.3}, Test_acc: {test_acc:.3}')\n",
"ax = plt.subplot(1,1,1)\n",
"ax.plot([e for e in range(1,101)], hist[\"loss\"], label=\"train_loss\")\n",
"ax.plot([e for e in range(1,101)], hist[\"acc\"], label=\"train_acc\")\n",
"ax.plot([e for e in range(1,101)], hist[\"test_acc\"], label=\"test_acc\")\n",
"plt.xlabel(\"epoch\")\n",
"ax.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "465cde87",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment