Skip to content

Instantly share code, notes, and snippets.

@iwatobipen
Created January 4, 2021 13:22
Show Gist options
  • Save iwatobipen/827d3921826607663dd50018be903ee7 to your computer and use it in GitHub Desktop.
Save iwatobipen/827d3921826607663dd50018be903ee7 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from autogl.solver import AutoNodeClassifier\n",
"from autogl.solver import AutoGraphClassifier\n",
"from autogl.module.feature import BaseFeatureEngineer\n",
"from autogl.module.feature import BaseFeatureAtom\n",
"from autogl.datasets import utils"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from rdkit import Chem\n",
"from rdkit.Chem import RDConfig\n",
"import molutil\n",
"from torch_geometric.data import Data, DataLoader, Dataset\n",
"from torch_geometric.data import InMemoryDataset\n",
"\n",
"class ChemDataset(InMemoryDataset):\n",
" def __init__(self, datalist) -> None:\n",
" super().__init__()\n",
" self.data, self.slices = self.collate(datalist)\n",
" \n",
"sol_cls_dict = {'(A) low':0, '(B) medium':1, '(C) high':2}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"trainpath = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.train.sdf')\n",
"testpath = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.test.sdf')\n",
"\n",
"train_mols = [m for m in Chem.SDMolSupplier(trainpath)]\n",
"test_mols = [m for m in Chem.SDMolSupplier(testpath)]\n",
"\n",
"train_X = [molutil.mol2vec(m) for m in train_mols]\n",
"for i, data in enumerate(train_X):\n",
" y = sol_cls_dict[train_mols[i].GetProp('SOL_classification')]\n",
" data.y = torch.tensor([y], dtype=torch.long)\n",
" \n",
"test_X = [molutil.mol2vec(m) for m in test_mols]\n",
"for i, data in enumerate(test_X):\n",
" y = sol_cls_dict[test_mols[i].GetProp('SOL_classification')]\n",
" data.y = torch.tensor([y], dtype=torch.long)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ChemDataset(257)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainData = ChemDataset(train_X)\n",
"testData = ChemDataset(test_X)\n",
"\n",
"utils.graph_random_splits(trainData, train_ratio=0.4, val_ratio=0.4)\n",
"utils.graph_random_splits(testData, train_ratio=0.0, val_ratio=0.0)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"config = {\n",
" 'models':{'gin': None},\n",
" 'feature': [{'name': 'NxLargeCliqueSize'}], \n",
" 'hpo': {'name': 'anneal', 'max_evals': 10},\n",
" 'ensemble': {'name': 'voting', 'size': 2},\n",
" 'trainer' : [\n",
" # trainer hp space\n",
" {'parameterName': 'max_epoch', 'type': 'INTEGER', 'maxValue': 20, 'minValue': 10, 'scalingType': 'LINEAR'},\n",
" {'parameterName': 'batch_size', 'type': 'INTEGER', 'maxValue': 128, 'minValue': 32, 'scalingType': 'LOG'},\n",
" {'parameterName': 'early_stopping_round', 'type': 'INTEGER', 'maxValue': 30, 'minValue': 10, 'scalingType': 'LINEAR'},\n",
" {'parameterName': 'lr', 'type': 'DOUBLE', 'maxValue': 1e-3, 'minValue': 1e-4, 'scalingType': 'LOG'},\n",
" {'parameterName': 'weight_decay', 'type': 'DOUBLE', 'maxValue': 5e-3, 'minValue': 5e-4, 'scalingType': 'LOG'},\n",
" ]\n",
"}\n",
"\n",
"solver = AutoGraphClassifier.from_config(config)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<autogl.solver.classifier.graph_classifier.AutoGraphClassifier at 0x7f39c84b4fd0>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"solver.fit(trainData,\n",
" time_limit=720, \n",
" train_split=0.9, \n",
" val_split=0.1, \n",
" cross_validation=True,\n",
" cv_split=10, \n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"best single model:\n",
" <class 'torch.optim.adam.Adam'>-0.0008968685743489439-15-15-AutoGIN(\n",
" (model): GIN(\n",
" (convs): ModuleList(\n",
" (0): GINConv(nn=Sequential(\n",
" (0): Linear(in_features=75, out_features=25, bias=True)\n",
" (1): ELU(alpha=1.0)\n",
" (2): Linear(in_features=25, out_features=25, bias=True)\n",
" (3): ELU(alpha=1.0)\n",
" (4): Linear(in_features=25, out_features=25, bias=True)\n",
" ))\n",
" (1): GINConv(nn=Sequential(\n",
" (0): Linear(in_features=25, out_features=37, bias=True)\n",
" (1): ELU(alpha=1.0)\n",
" (2): Linear(in_features=37, out_features=37, bias=True)\n",
" (3): ELU(alpha=1.0)\n",
" (4): Linear(in_features=37, out_features=37, bias=True)\n",
" ))\n",
" (2): GINConv(nn=Sequential(\n",
" (0): Linear(in_features=37, out_features=52, bias=True)\n",
" (1): ELU(alpha=1.0)\n",
" (2): Linear(in_features=52, out_features=52, bias=True)\n",
" (3): ELU(alpha=1.0)\n",
" (4): Linear(in_features=52, out_features=52, bias=True)\n",
" ))\n",
" (3): GINConv(nn=Sequential(\n",
" (0): Linear(in_features=52, out_features=23, bias=True)\n",
" (1): ELU(alpha=1.0)\n",
" (2): Linear(in_features=23, out_features=23, bias=True)\n",
" (3): ELU(alpha=1.0)\n",
" (4): Linear(in_features=23, out_features=23, bias=True)\n",
" ))\n",
" )\n",
" (bns): ModuleList(\n",
" (0): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): BatchNorm1d(37, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (2): BatchNorm1d(52, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): BatchNorm1d(23, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" (fc1): Linear(in_features=24, out_features=8, bias=True)\n",
" (fc2): Linear(in_features=8, out_features=3, bias=True)\n",
" )\n",
")-cpu|num_layers-6-hidden-[25, 37, 52, 23, 8]-dropout-0.4706157382603818-act-elu-eps-False-mlp_layers-3_cv5_idx0\n"
]
}
],
"source": [
"lb = solver.get_leaderboard()\n",
"\n",
"print('best single model:\\n', solver.get_leaderboard().get_best_model(0))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" name acc\n",
"10 ensemble 0.766990\n",
"5 <class 'torch.optim.adam.Adam'>-0.000896868574... 0.757282\n",
"6 <class 'torch.optim.adam.Adam'>-0.000650776957... 0.737864\n",
"9 <class 'torch.optim.adam.Adam'>-0.000540590597... 0.689320\n",
"4 <class 'torch.optim.adam.Adam'>-0.000290357625... 0.669903\n",
"1 <class 'torch.optim.adam.Adam'>-0.000231124807... 0.650485\n",
"2 <class 'torch.optim.adam.Adam'>-0.000378524042... 0.650485\n",
"0 <class 'torch.optim.adam.Adam'>-0.000716295596... 0.631068\n",
"8 <class 'torch.optim.adam.Adam'>-0.000196322231... 0.601942\n",
"7 <class 'torch.optim.adam.Adam'>-0.000172943198... 0.582524\n",
"3 <class 'torch.optim.adam.Adam'>-0.000664233713... 0.563107\n"
]
}
],
"source": [
"lb.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"pred = solver.predict(testData,\n",
" inplaced=False, \n",
" inplace=False,\n",
" use_ensemble=True, \n",
" use_best=True\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.3657587548638132"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_score(testData.data.y.numpy(), pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment