This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch_geometric.datasets import Planetoid | |
import torch_geometric.transforms as T | |
# Feature-normalisation transform, applied by Planetoid to the loaded Cora graph.
transform = T.NormalizeFeatures()
# NOTE(review): `path` (the dataset root directory) must be defined earlier; it is not visible here.
dataset = Planetoid(root=path, name="Cora", transform=transform)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the torchvision MNIST dataset.

    Args:
        data_dir: Root directory handed to ``MNIST``; stored as ``self.data_dir``.
        batch_size: Stored as ``self.batch_size`` (presumably consumed by
            dataloader methods defined elsewhere).
        transform: Optional transform forwarded to ``MNIST``. ``None`` (the
            default) keeps the original behaviour of loading untransformed
            samples (PIL images), which the default collate cannot batch, so
            callers will usually want a tensor-producing transform here.
    """

    def __init__(self, data_dir: str = PATH, batch_size: int = 32, transform=None):
        super().__init__()
        # The original signature listed the required ``batch_size`` after the
        # defaulted ``data_dir`` (a SyntaxError). It now has a default so the
        # positional order (data_dir, batch_size) is kept.
        # ``data_dir`` must be stored because ``setup`` reads it.
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform

    def setup(self, stage=None):
        """Load the MNIST test set and split the training set into train/val.

        ``stage`` is accepted for the Lightning hook signature but not used:
        both splits are always loaded.
        """
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        # random_split requires the lengths to sum to len(mnist_full);
        # 55000 + 5000 assumes the standard 60000-sample MNIST training set.
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from pytorch_lightning import LightningDataModule | |
from torch_geometric.datasets import Planetoid | |
import torch_geometric.transforms as T | |
class CoraDataset(LightningDataModule): | |
NAME = "cora" | |
def __init__(self): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CoraDataset(LightningDataModule): | |
NAME = "cora" | |
.... | |
def create_neighbor_sampler(self, batch_size=2, stage=None): | |
# https://github.com/rusty1s/pytorch_geometric/tree/ | |
# master/torch_geometric/data/sampler.py#L18 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nn import Sequential as Seq, Linear as Lin, ReLU | |
from torch_geometric.nn import MessagePassing | |
class EdgeConv(MessagePassing): | |
def __init__(self, F_in, F_out): | |
super(EdgeConv, self).__init__(aggr='max') # "Max" aggregation. | |
self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out)) | |
def forward(self, x, edge_index): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.nn import Module | |
from torch_geometric.nn.conv import DNAConv | |
class DNAConvNet(Module): | |
def __init__(self, | |
num_layers: int = 2, | |
hidden_channels: int = 128, | |
heads: int = 8, | |
groups: int = 16, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
... | |
from pytorch_lightning import LightningModule | |
from pytorch_lightning.metrics import Accuracy | |
class DNAConvNet(LightningModule): | |
def __init__(self, | |
num_layers: int = 2, | |
hidden_channels: int = 128, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def instantiate_datamodule(args):
    """Build a ``CoraDataset`` configured from *args*.

    *args* is any object exposing the attributes below (for example an
    argparse ``Namespace``). Each one is forwarded unchanged to
    ``CoraDataset`` as a keyword argument of the same name.
    """
    forwarded = ("num_workers", "batch_size", "drop_last", "pin_memory", "num_layers")
    options = {name: getattr(args, name) for name in forwarded}
    return CoraDataset(**options)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch_geometric.transforms as T | |
class CoraDataset(LightningDataModule): | |
NAME = "cora" | |
... | |
def gather_data(self, batch, batch_nb): | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DNAConvNet(LightningModule): | |
... | |
def _convert_to_jittable(self, module): | |
for key, m in module._modules.items(): | |
if isinstance(m, MessagePassing) and m.jittable is not None: | |
# Pytorch Geometric MessagePassing implements a `.jittable` function | |
# which converts the current module into its jittable version. | |
module._modules[key] = m.jittable() | |
else: | |
self._convert_to_jittable(m) |
OlderNewer