This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

# Feature-normalisation transform; it is handed to the dataset below.
transform = T.NormalizeFeatures()
# NOTE: `path` is not defined in this snippet and must be supplied by the
# surrounding code before this line runs.
dataset = Planetoid(path, "Cora", transform=transform)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that builds MNIST train/validation/test datasets.

    Args:
        data_dir: Value passed as the first positional argument to ``MNIST``.
        batch_size: Batch size stored on the instance as ``self.batch_size``.
    """

    # `batch_size` needs a default because it follows the defaulted
    # `data_dir`; without one the signature is a SyntaxError.
    def __init__(self, data_dir: str = PATH, batch_size: int = 32):
        super().__init__()
        # setup() reads self.data_dir, so it has to be stored here.
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Create the test dataset and a 55000/5000 train/validation split.

        ``stage`` is accepted but not used by this implementation.
        """
        self.mnist_test = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nn import Sequential as Seq, Linear as Lin, ReLU | |
from torch_geometric.nn import MessagePassing | |
class EdgeConv(MessagePassing): | |
def __init__(self, F_in, F_out): | |
super(EdgeConv, self).__init__(aggr='max') # "Max" aggregation. | |
self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out)) | |
def forward(self, x, edge_index): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def instantiate_datamodule(args):
    """Build a ``CoraDataset`` from the attributes of ``args``.

    Args:
        args: Object exposing ``num_workers``, ``batch_size``, ``drop_last``,
            ``pin_memory`` and ``num_layers`` attributes.

    Returns:
        The newly constructed ``CoraDataset`` instance.
    """
    datamodule = CoraDataset(
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        drop_last=args.drop_last,
        pin_memory=args.pin_memory,
        num_layers=args.num_layers,
    )
    return datamodule
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from pytorch_lightning import LightningDataModule | |
from torch_geometric.datasets import Planetoid | |
import torch_geometric.transforms as T | |
class CoraDataset(LightningDataModule): | |
NAME = "cora" | |
def __init__(self): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch_geometric.transforms as T | |
class CoraDataset(LightningDataModule): | |
NAME = "cora" | |
... | |
def gather_data(self, batch, batch_nb): | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DNAConvNet(LightningModule):
    ...

    def _convert_to_jittable(self, module):
        """Recursively replace ``MessagePassing`` children with jittable ones.

        Walks ``module._modules`` and rewrites entries in place; nothing is
        returned.

        Args:
            module: Module whose registered sub-modules are inspected.
        """
        for key, m in module._modules.items():
            if m is None:
                # `_modules` may hold None placeholders; there is nothing to
                # convert and recursing into None would raise AttributeError.
                continue
            if isinstance(m, MessagePassing) and m.jittable is not None:
                # Pytorch Geometric MessagePassing implements a `.jittable` function
                # which converts the current module into its jittable version.
                module._modules[key] = m.jittable()
            else:
                self._convert_to_jittable(m)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Optional, NamedTuple

from torch import Tensor
# Explicit type aliases, used to make the model jittable.
OptTensor = Optional[Tensor]
ListTensor = List[Tensor]


class TensorBatch(NamedTuple):
    """Named tuple with the fields ``x``, ``edge_index`` and ``edge_attr``."""

    x: Tensor
    edge_index: ListTensor
    edge_attr: OptTensor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_single_batch(datamodule):
    """Return the gathered form of the first test batch.

    Args:
        datamodule: Object providing ``test_dataloader()`` and
            ``gather_data(batch, batch_nb)``.

    Returns:
        ``datamodule.gather_data(first_batch, 0)``, or ``None`` when the test
        dataloader yields no batches.
    """
    for batch in datamodule.test_dataloader():
        # Only the first batch is needed, so return immediately.
        return datamodule.gather_data(batch, 0)
    return None
def run(args):
    """Build the data module and model, then convert the model via ``jittable``.

    The model is printed once before and once after ``model.jittable()`` is
    called so the two representations can be compared.

    Args:
        args: Configuration object forwarded to ``instantiate_datamodule`` and
            ``instantiate_model``.
    """
    datamodule: LightningDataModule = instantiate_datamodule(args)
    model: LightningModule = instantiate_model(args, datamodule)
    print(model)
    model.jittable()
    print(model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
from pytorch_lightning import LightningModule | |
from pytorch_lightning.metrics import Accuracy | |
class DNAConvNet(LightningModule): | |
def __init__(self, | |
num_layers: int = 2, | |
hidden_channels: int = 128, |
OlderNewer