Skip to content

Instantly share code, notes, and snippets.

View tchaton's full-sized avatar
👻
Always up for it !

thomas chaton tchaton

👻
Always up for it !
View GitHub Profile
@tchaton
tchaton / blog_1.py
Last active December 6, 2020 19:02
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
# Normalisation transform handed to the dataset
# (see torch_geometric.transforms.NormalizeFeatures).
transform = T.NormalizeFeatures()
# NOTE(review): `path` is not defined anywhere in this snippet; it must be
# bound to the dataset root directory before this line runs.
dataset = Planetoid(root=path, name="Cora", transform=transform)
@tchaton
tchaton / blog_2.py
Last active December 6, 2020 19:04
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule that builds MNIST train/validation/test splits.

    Args:
        data_dir: root directory passed to ``MNIST`` when loading the data.
        batch_size: batch size, stored as ``self.batch_size``.
    """

    # batch_size needs a default: Python forbids a non-default parameter
    # after data_dir's default, which made the original signature a SyntaxError.
    def __init__(self, data_dir: str = PATH, batch_size: int = 32):
        super().__init__()
        # setup() reads self.data_dir, so it must be stored here.
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Load the MNIST test set and split the training set into train/val.

        Args:
            stage: accepted for the Lightning hook signature; it is not
                inspected, so every split is built on each call.
        """
        self.mnist_test = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        # The split sizes must sum to len(mnist_full); this assumes the
        # standard 60000-sample MNIST training set.
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
@tchaton
tchaton / blog_3.py
Last active December 6, 2020 20:08
import os
from pytorch_lightning import LightningDataModule
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
class CoraDataset(LightningDataModule):
NAME = "cora"
def __init__(self):
@tchaton
tchaton / blog_4.py
Last active December 7, 2020 14:05
class CoraDataset(LightningDataModule):
NAME = "cora"
....
def create_neighbor_sampler(self, batch_size=2, stage=None):
# https://github.com/rusty1s/pytorch_geometric/tree/
# master/torch_geometric/data/sampler.py#L18
import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
def __init__(self, F_in, F_out):
super(EdgeConv, self).__init__(aggr='max') # "Max" aggregation.
self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out))
def forward(self, x, edge_index):
@tchaton
tchaton / blog_6.py
Last active December 7, 2020 14:05
from torch.nn import Module
from torch_geometric.nn.conv import DNAConv
class DNAConvNet(Module):
def __init__(self,
num_layers: int = 2,
hidden_channels: int = 128,
heads: int = 8,
groups: int = 16,
@tchaton
tchaton / blog_7.py
Last active December 7, 2020 14:04
...
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics import Accuracy
class DNAConvNet(LightningModule):
def __init__(self,
num_layers: int = 2,
hidden_channels: int = 128,
def instantiate_datamodule(args):
    """Create a ``CoraDataset`` configured from the attributes of *args*.

    *args* is presumably an ``argparse.Namespace`` (TODO confirm with the
    caller); any object exposing the attributes listed below works.

    Returns:
        The constructed ``CoraDataset`` instance.
    """
    # Each name is both the attribute read from args and the keyword
    # argument forwarded to CoraDataset, in this order.
    forwarded = (
        "num_workers",
        "batch_size",
        "drop_last",
        "pin_memory",
        "num_layers",
    )
    options = {name: getattr(args, name) for name in forwarded}
    return CoraDataset(**options)
import torch_geometric.transforms as T
class CoraDataset(LightningDataModule):
NAME = "cora"
...
def gather_data(self, batch, batch_nb):
"""
class DNAConvNet(LightningModule):
...
def _convert_to_jittable(self, module):
    """Recursively swap each ``MessagePassing`` child of *module* for its jittable version.

    The swap is done in place on ``module._modules``. Children that are not
    ``MessagePassing`` layers are walked recursively.

    Args:
        module: a module exposing ``_modules`` (presumably a ``torch.nn.Module``).
    """
    for key, m in module._modules.items():
        if m is None:
            # _modules may hold None placeholders (register_module(name, None)).
            # They have no children, and recursing would fail on None._modules.
            continue
        # Pytorch Geometric MessagePassing implements a `.jittable` function
        # which converts the current module into its jittable version.
        # getattr() tolerates MessagePassing classes that lack that attribute;
        # such layers are walked like any other module instead of crashing.
        if isinstance(m, MessagePassing) and getattr(m, "jittable", None) is not None:
            # Reassigning an existing key does not resize the dict, so this is
            # safe while iterating over _modules.items().
            module._modules[key] = m.jittable()
        else:
            self._convert_to_jittable(m)