Created
February 15, 2022 18:23
-
-
Save dongchirua/71c6c02c71d580664dcea4e512c05198 to your computer and use it in GitHub Desktop.
MUTAG classification with Lightning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC | |
import torch | |
import torch.nn.functional as F | |
from torch_geometric.loader import DataLoader | |
from pytorch_lightning import Trainer | |
from typing import Optional | |
from torch_geometric import seed_everything | |
from torch_geometric.nn import global_mean_pool, SAGEConv, Linear | |
from torch_geometric.datasets import TUDataset | |
from pytorch_lightning import LightningModule | |
from pytorch_lightning import LightningDataModule | |
from torch_geometric.data import Batch | |
from torchmetrics.functional import accuracy, f1 | |
class TrainingModule(LightningModule, ABC):
    """Lightning wrapper that trains a graph-classification model.

    Args:
        model: the model *class* (not an instance); it is instantiated here
            with ``input_channel``, ``n_class`` and ``hidden_channels``
            keyword arguments.
        in_channels: number of input node features.
        out_channels: number of target classes.
        hidden_channels: hidden layer width forwarded to the model.
    """

    def __init__(self, model, in_channels: int, out_channels: int, hidden_channels: int = 64):
        super().__init__()
        # Records the ctor arguments so checkpoints can rebuild the module.
        self.save_hyperparameters()
        self.model = model(input_channel=in_channels, n_class=out_channels, hidden_channels=hidden_channels)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x, edge_index, batch) -> torch.Tensor:
        """Return raw (unnormalized) class logits for each graph in the batch."""
        return self.model(x, edge_index, batch)

    def evaluate(self, batch, stage=None):
        """Compute loss, accuracy and F1 for one batch.

        Metrics are logged under a ``{stage}_`` prefix when ``stage`` is
        given. Returns the cross-entropy loss tensor.
        """
        y_hat = self(batch.x, batch.edge_index, batch.batch)
        loss = self.criterion(y_hat, batch.y)
        # argmax over logits equals argmax over softmax probabilities, so the
        # explicit softmax pass of the original code was redundant work.
        preds = y_hat.argmax(dim=1)
        acc = accuracy(preds, batch.y)
        f1_score = f1(preds, batch.y)
        if stage:
            self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, logger=True)
            self.log(f"{stage}_acc", acc, on_step=True, on_epoch=True, logger=True)
            self.log(f"{stage}_f1", f1_score, on_step=True, on_epoch=True, logger=True)
        return loss

    def training_step(self, batch: Batch, batch_idx: int):
        return self.evaluate(batch, 'train')

    def validation_step(self, batch: Batch, batch_idx: int):
        self.evaluate(batch, 'val')

    def test_step(self, batch: Batch, batch_idx: int):
        self.evaluate(batch, 'test')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)
class DataModule(LightningDataModule, ABC):
    """LightningDataModule that splits one PyG dataset into train/val/test.

    The split boundaries used to be hard-coded (130 / 150, sized for
    MUTAG's 188 graphs). They are now constructor parameters with the same
    defaults, so existing callers are unaffected.

    Args:
        dataset: a torch_geometric dataset supporting ``shuffle()`` and slicing.
        batch_size: mini-batch size for all three loaders.
        num_workers: worker processes per DataLoader.
        train_end: index where the training split ends (exclusive).
        val_end: index where the validation split ends (exclusive); the
            remainder becomes the test split.
    """

    def __init__(self, dataset, batch_size=64, num_workers=1,
                 train_end: int = 130, val_end: int = 150):
        super().__init__()
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_end = train_end
        self.val_end = val_end

    @property
    def num_features(self) -> int:
        return self.dataset.num_features

    @property
    def num_classes(self) -> int:
        return self.dataset.num_classes

    def prepare_data(self):
        # Dataset is expected to be downloaded/constructed by the caller.
        pass

    def setup(self, stage: Optional[str] = None):
        # NOTE(review): Lightning may call setup() once per stage; each call
        # reshuffles and re-splits. Harmless after seed_everything, but worth
        # confirming if splits must stay identical across stages.
        shuffled = self.dataset.shuffle()
        self.train_dataset = shuffled[:self.train_end]
        self.val_dataset = shuffled[self.train_end:self.val_end]
        self.test_dataset = shuffled[self.val_end:]

    def train_dataloader(self) -> DataLoader:
        # NOTE(review): no shuffle=True here — training batches keep a fixed
        # order after the one-time shuffle in setup(); confirm this is intended.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
class GNN(torch.nn.Module):
    """Three-layer GraphSAGE network for whole-graph classification.

    Node embeddings from three SAGE convolutions are mean-pooled per graph,
    regularized with dropout, and projected to class logits.
    """

    def __init__(self, input_channel, hidden_channels, n_class):
        super().__init__()
        self.conv1 = SAGEConv(input_channel, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, n_class)

    def forward(self, x, edge_index, batch):
        """Return per-graph class logits for the batched node features."""
        h = self.conv1(x, edge_index).relu()
        h = self.conv2(h, edge_index).relu()
        # No activation after the last convolution, matching the original.
        h = self.conv3(h, edge_index)
        # Collapse node embeddings into one vector per graph.
        h = global_mean_pool(h, batch)
        h = F.dropout(h, p=0.5, training=self.training)
        return self.lin(h)
def main():
    """Run a one-batch sanity check (fast_dev_run) of the MUTAG pipeline."""
    mutag = TUDataset(root='data/TUDataset', name='MUTAG')
    seed_everything(12345)
    dm = DataModule(mutag, 10, num_workers=4)
    trainer = Trainer(fast_dev_run=True)
    module = TrainingModule(GNN, dm.num_features, dm.num_classes, 64)
    trainer.fit(module, datamodule=dm)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment