-
-
Save leelasd/ae511789e5f6bb528132736cfb92810c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from autogl.solver import AutoNodeClassifier\n", | |
"from autogl.solver import AutoGraphClassifier\n", | |
"from autogl.module.feature import BaseFeatureEngineer\n", | |
"from autogl.module.feature import BaseFeatureAtom\n", | |
"from autogl.datasets import utils" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"from rdkit import Chem\n", | |
"from rdkit.Chem import RDConfig\n", | |
"import molutil\n", | |
"from torch_geometric.data import Data, DataLoader, Dataset\n", | |
"from torch_geometric.data import InMemoryDataset\n", | |
"\n", | |
"class ChemDataset(InMemoryDataset):\n", | |
"    \"\"\"In-memory dataset wrapping an already-built list of graph objects.\n", | |
"\n", | |
"    Args:\n", | |
"        datalist: list of graph objects to collate (built elsewhere in this\n", | |
"            notebook via molutil.mol2vec).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, datalist) -> None:\n", | |
"        # No root directory is given to the parent constructor.\n", | |
"        super().__init__()\n", | |
"        # collate() returns the pair stored as (data, slices).\n", | |
"        self.data, self.slices = self.collate(datalist)\n", | |
"\n", | |
"\n", | |
"# Integer class id for each solubility class label string.\n", | |
"sol_cls_dict = {'(A) low': 0, '(B) medium': 1, '(C) high': 2}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"trainpath = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.train.sdf')\n", | |
"testpath = os.path.join(RDConfig.RDDocsDir, 'Book/data/solubility.test.sdf')\n", | |
"\n", | |
"\n", | |
"def load_mols(sdf_path):\n", | |
"    \"\"\"Return the molecules read from an SDF file, skipping None entries.\n", | |
"\n", | |
"    SDMolSupplier yields None for records RDKit cannot parse; dropping them\n", | |
"    here keeps None from reaching mol2vec / GetProp below.\n", | |
"    \"\"\"\n", | |
"    return [mol for mol in Chem.SDMolSupplier(sdf_path) if mol is not None]\n", | |
"\n", | |
"\n", | |
"def to_labeled_graphs(mols):\n", | |
"    \"\"\"Convert molecules to graphs and attach the class label as `y`.\n", | |
"\n", | |
"    Each label is looked up from the molecule's 'SOL_classification'\n", | |
"    property through sol_cls_dict and stored as a 1-element long tensor.\n", | |
"    \"\"\"\n", | |
"    graphs = []\n", | |
"    for mol in mols:\n", | |
"        # NOTE(review): mol2vec presumably returns a torch_geometric Data object - confirm.\n", | |
"        graph = molutil.mol2vec(mol)\n", | |
"        label = sol_cls_dict[mol.GetProp('SOL_classification')]\n", | |
"        graph.y = torch.tensor([label], dtype=torch.long)\n", | |
"        graphs.append(graph)\n", | |
"    return graphs\n", | |
"\n", | |
"\n", | |
"train_mols = load_mols(trainpath)\n", | |
"test_mols = load_mols(testpath)\n", | |
"\n", | |
"train_X = to_labeled_graphs(train_mols)\n", | |
"test_X = to_labeled_graphs(test_mols)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ChemDataset(257)" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Seed torch's global RNG so the random splits below are repeatable.\n", | |
"# NOTE(review): assumes graph_random_splits draws its permutation from torch's\n", | |
"# global RNG - confirm against the installed AutoGL version.\n", | |
"SPLIT_SEED = 42\n", | |
"torch.manual_seed(SPLIT_SEED)\n", | |
"\n", | |
"trainData = ChemDataset(train_X)\n", | |
"testData = ChemDataset(test_X)\n", | |
"\n", | |
"utils.graph_random_splits(trainData, train_ratio=0.4, val_ratio=0.4)\n", | |
"# With both ratios at 0.0 the whole set presumably lands in the test split\n", | |
"# (in shuffled order) - TODO confirm.\n", | |
"utils.graph_random_splits(testData, train_ratio=0.0, val_ratio=0.0)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Trainer hyper-parameter search space: one entry per tunable parameter.\n", | |
"trainer_space = [\n", | |
"    {'parameterName': 'max_epoch', 'type': 'INTEGER', 'maxValue': 20, 'minValue': 10, 'scalingType': 'LINEAR'},\n", | |
"    {'parameterName': 'batch_size', 'type': 'INTEGER', 'maxValue': 128, 'minValue': 32, 'scalingType': 'LOG'},\n", | |
"    {'parameterName': 'early_stopping_round', 'type': 'INTEGER', 'maxValue': 30, 'minValue': 10, 'scalingType': 'LINEAR'},\n", | |
"    {'parameterName': 'lr', 'type': 'DOUBLE', 'maxValue': 1e-3, 'minValue': 1e-4, 'scalingType': 'LOG'},\n", | |
"    {'parameterName': 'weight_decay', 'type': 'DOUBLE', 'maxValue': 5e-3, 'minValue': 5e-4, 'scalingType': 'LOG'},\n", | |
"]\n", | |
"\n", | |
"# Solver configuration: model, feature step, HPO, ensemble and trainer space.\n", | |
"config = dict(\n", | |
"    models={'gin': None},\n", | |
"    feature=[{'name': 'NxLargeCliqueSize'}],\n", | |
"    hpo={'name': 'anneal', 'max_evals': 10},\n", | |
"    ensemble={'name': 'voting', 'size': 2},\n", | |
"    trainer=trainer_space,\n", | |
")\n", | |
"\n", | |
"solver = AutoGraphClassifier.from_config(config)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<autogl.solver.classifier.graph_classifier.AutoGraphClassifier at 0x7f39c84b4fd0>" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Keyword options forwarded unchanged to solver.fit.\n", | |
"fit_options = dict(\n", | |
"    time_limit=720,\n", | |
"    train_split=0.9,\n", | |
"    val_split=0.1,\n", | |
"    cross_validation=True,\n", | |
"    cv_split=10,\n", | |
")\n", | |
"solver.fit(trainData, **fit_options)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"best single model:\n", | |
" <class 'torch.optim.adam.Adam'>-0.0008968685743489439-15-15-AutoGIN(\n", | |
" (model): GIN(\n", | |
" (convs): ModuleList(\n", | |
" (0): GINConv(nn=Sequential(\n", | |
" (0): Linear(in_features=75, out_features=25, bias=True)\n", | |
" (1): ELU(alpha=1.0)\n", | |
" (2): Linear(in_features=25, out_features=25, bias=True)\n", | |
" (3): ELU(alpha=1.0)\n", | |
" (4): Linear(in_features=25, out_features=25, bias=True)\n", | |
" ))\n", | |
" (1): GINConv(nn=Sequential(\n", | |
" (0): Linear(in_features=25, out_features=37, bias=True)\n", | |
" (1): ELU(alpha=1.0)\n", | |
" (2): Linear(in_features=37, out_features=37, bias=True)\n", | |
" (3): ELU(alpha=1.0)\n", | |
" (4): Linear(in_features=37, out_features=37, bias=True)\n", | |
" ))\n", | |
" (2): GINConv(nn=Sequential(\n", | |
" (0): Linear(in_features=37, out_features=52, bias=True)\n", | |
" (1): ELU(alpha=1.0)\n", | |
" (2): Linear(in_features=52, out_features=52, bias=True)\n", | |
" (3): ELU(alpha=1.0)\n", | |
" (4): Linear(in_features=52, out_features=52, bias=True)\n", | |
" ))\n", | |
" (3): GINConv(nn=Sequential(\n", | |
" (0): Linear(in_features=52, out_features=23, bias=True)\n", | |
" (1): ELU(alpha=1.0)\n", | |
" (2): Linear(in_features=23, out_features=23, bias=True)\n", | |
" (3): ELU(alpha=1.0)\n", | |
" (4): Linear(in_features=23, out_features=23, bias=True)\n", | |
" ))\n", | |
" )\n", | |
" (bns): ModuleList(\n", | |
" (0): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (1): BatchNorm1d(37, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (2): BatchNorm1d(52, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (3): BatchNorm1d(23, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" (fc1): Linear(in_features=24, out_features=8, bias=True)\n", | |
" (fc2): Linear(in_features=8, out_features=3, bias=True)\n", | |
" )\n", | |
")-cpu|num_layers-6-hidden-[25, 37, 52, 23, 8]-dropout-0.4706157382603818-act-elu-eps-False-mlp_layers-3_cv5_idx0\n" | |
] | |
} | |
], | |
"source": [ | |
"lb = solver.get_leaderboard()\n", | |
"\n", | |
"# Reuse the leaderboard fetched above rather than requesting it a second time.\n", | |
"print('best single model:\\n', lb.get_best_model(0))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" name acc\n", | |
"10 ensemble 0.766990\n", | |
"5 <class 'torch.optim.adam.Adam'>-0.000896868574... 0.757282\n", | |
"6 <class 'torch.optim.adam.Adam'>-0.000650776957... 0.737864\n", | |
"9 <class 'torch.optim.adam.Adam'>-0.000540590597... 0.689320\n", | |
"4 <class 'torch.optim.adam.Adam'>-0.000290357625... 0.669903\n", | |
"1 <class 'torch.optim.adam.Adam'>-0.000231124807... 0.650485\n", | |
"2 <class 'torch.optim.adam.Adam'>-0.000378524042... 0.650485\n", | |
"0 <class 'torch.optim.adam.Adam'>-0.000716295596... 0.631068\n", | |
"8 <class 'torch.optim.adam.Adam'>-0.000196322231... 0.601942\n", | |
"7 <class 'torch.optim.adam.Adam'>-0.000172943198... 0.582524\n", | |
"3 <class 'torch.optim.adam.Adam'>-0.000664233713... 0.563107\n" | |
] | |
} | |
], | |
"source": [ | |
"lb.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Keyword options forwarded unchanged to solver.predict.\n", | |
"# NOTE(review): `inplaced` and `inplace` look like two distinct AutoGL\n", | |
"# parameters rather than a typo - confirm against the installed version.\n", | |
"predict_options = dict(\n", | |
"    inplaced=False,\n", | |
"    inplace=False,\n", | |
"    use_ensemble=True,\n", | |
"    use_best=True,\n", | |
")\n", | |
"pred = solver.predict(testData, **predict_options)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.metrics import accuracy_score" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.3657587548638132" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# graph_random_splits shuffles the dataset, and predictions are presumably returned in\n", | |
"# test-split order, so gather the labels through test_index to line them up\n", | |
"# with `pred`. NOTE(review): mirrors AutoGL's graph-classification example - confirm.\n", | |
"accuracy_score(testData.data.y[testData.test_index].numpy(), pred)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment