Skip to content

Instantly share code, notes, and snippets.

@dongchirua
Created February 15, 2022 18:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dongchirua/71c6c02c71d580664dcea4e512c05198 to your computer and use it in GitHub Desktop.
MUTAG classification with Lightning
from abc import ABC
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from pytorch_lightning import Trainer
from typing import Optional
from torch_geometric import seed_everything
from torch_geometric.nn import global_mean_pool, SAGEConv, Linear
from torch_geometric.datasets import TUDataset
from pytorch_lightning import LightningModule
from pytorch_lightning import LightningDataModule
from torch_geometric.data import Batch
from torchmetrics.functional import accuracy, f1
class TrainingModule(LightningModule, ABC):
    """LightningModule that wraps a graph-classification model.

    Trains with cross-entropy loss and logs loss / accuracy / F1 for
    every stage (train, val, test).
    """

    def __init__(self, model, in_channels: int, out_channels: int, hidden_channels: int = 64):
        super().__init__()
        self.save_hyperparameters()
        # `model` is a class, not an instance: build it from the dataset dims.
        self.model = model(input_channel=in_channels, n_class=out_channels, hidden_channels=hidden_channels)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x, edge_index, batch) -> torch.Tensor:
        """Delegate to the wrapped model (node features, edges, graph index)."""
        return self.model(x, edge_index, batch)

    def evaluate(self, batch, stage=None):
        """Shared step: forward pass, loss, metrics, optional logging.

        Returns the cross-entropy loss so the training step can backprop it.
        """
        y_hat = self(batch.x, batch.edge_index, batch.batch)
        loss = self.criterion(y_hat, batch.y)
        # argmax of the softmax equals argmax of the raw logits.
        preds = torch.argmax(y_hat.softmax(dim=1), dim=1)
        acc = accuracy(preds, batch.y)
        f1_score = f1(preds, batch.y)
        if stage:
            metrics = (
                (f"{stage}_loss", loss),
                (f"{stage}_acc", acc),
                (f"{stage}_f1", f1_score),
            )
            for metric_name, value in metrics:
                self.log(metric_name, value, on_step=True, on_epoch=True, logger=True)
        return loss

    def training_step(self, batch: Batch, batch_idx: int):
        return self.evaluate(batch, 'train')

    def validation_step(self, batch: Batch, batch_idx: int):
        self.evaluate(batch, 'val')

    def test_step(self, batch: Batch, batch_idx: int):
        self.evaluate(batch, 'test')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)
class DataModule(LightningDataModule, ABC):
    """LightningDataModule that splits a graph dataset into train/val/test.

    Args:
        dataset: a torch_geometric dataset supporting ``.shuffle()`` and slicing.
        batch_size: mini-batch size used by every dataloader.
        num_workers: worker processes per dataloader.
        train_size: number of graphs in the training split (default 130,
            preserving the original hard-coded boundary).
        val_size: number of graphs in the validation split (default 20,
            i.e. the original 130:150 slice); the remainder is the test split.
    """

    def __init__(self, dataset, batch_size=64, num_workers=1, train_size=130, val_size=20):
        super().__init__()
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_size = train_size
        self.val_size = val_size

    @property
    def num_features(self) -> int:
        return self.dataset.num_features

    @property
    def num_classes(self) -> int:
        return self.dataset.num_classes

    def prepare_data(self):
        # Dataset is downloaded/created by the caller; nothing to do here.
        pass

    def setup(self, stage: Optional[str] = None):
        """Shuffle once, then slice into contiguous train/val/test splits."""
        data_set = self.dataset.shuffle()
        val_end = self.train_size + self.val_size
        self.train_dataset = data_set[:self.train_size]
        self.val_dataset = data_set[self.train_size:val_end]
        self.test_dataset = data_set[val_end:]

    def train_dataloader(self) -> DataLoader:
        # shuffle=True: fix — the original iterated the same order every epoch.
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
class GNN(torch.nn.Module):
    """Three-layer GraphSAGE encoder with mean-pool readout and a linear head.

    forward() maps (node features, edge_index, batch vector) to per-graph
    class logits of shape [num_graphs, n_class].
    """

    def __init__(self, input_channel, hidden_channels, n_class):
        super().__init__()
        self.conv1 = SAGEConv(input_channel, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, n_class)

    def forward(self, x, edge_index, batch):
        # Node embeddings: ReLU after the first two conv layers only.
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        h = self.conv3(h, edge_index)
        # Graph-level readout, dropout (train-time only), class logits.
        h = global_mean_pool(h, batch)
        h = F.dropout(h, p=0.5, training=self.training)
        return self.lin(h)
if __name__ == "__main__":
    # Load (or download) MUTAG, then seed before any shuffling happens.
    dataset = TUDataset(root='data/TUDataset', name='MUTAG')
    seed_everything(12345)

    datamodule = DataModule(dataset, 10, num_workers=4)
    lit_model = TrainingModule(GNN, datamodule.num_features, datamodule.num_classes, 64)

    # fast_dev_run runs a single batch through train/val as a smoke test.
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(lit_model, datamodule=datamodule)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment