thomas chaton (tchaton), GitHub Gists
tchaton / blog_1.py
Last active December 6, 2020 19:02
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

path = "./data/Cora"  # placeholder download/cache directory
transform = T.NormalizeFeatures()  # row-normalise the bag-of-words node features
dataset = Planetoid(path, "Cora", transform=transform)
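Cora is a single-graph dataset, so the loaded `dataset` holds one `Data` object; the quick sanity check below is a sketch, not part of the gist.

data = dataset[0]                  # the whole Cora citation graph
print(data.num_nodes)              # 2708 papers
print(data.x.shape)                # [2708, 1433] bag-of-words features, row-normalised
print(data.edge_index.shape)       # [2, num_edges] citation links
print(int(data.train_mask.sum()))  # number of nodes in the training split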
tchaton / blog_2.py
Last active December 6, 2020 19:04
import pytorch_lightning as pl
from torch.utils.data import random_split
from torchvision.datasets import MNIST

PATH = "./data"  # placeholder dataset directory

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = PATH, batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.mnist_test = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
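The preview stops before the dataloader hooks; the following is a minimal sketch (not part of the gist) of the hooks that would normally be defined inside MNISTDataModule, plus how a Trainer consumes the DataModule.

from torch.utils.data import DataLoader

# These hooks would live inside MNISTDataModule, after setup():
def train_dataloader(self):
    return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

def val_dataloader(self):
    return DataLoader(self.mnist_val, batch_size=self.batch_size)

def test_dataloader(self):
    return DataLoader(self.mnist_test, batch_size=self.batch_size)

# Typical usage: the Trainer calls setup() and the hooks above for you.
# trainer = pl.Trainer(max_epochs=1)
# trainer.fit(model, datamodule=MNISTDataModule())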
import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, F_in, F_out):
        super(EdgeConv, self).__init__(aggr='max')  # "Max" aggregation.
        self.mlp = Seq(Lin(2 * F_in, F_out), ReLU(), Lin(F_out, F_out))

    def forward(self, x, edge_index):
        # x: [N, F_in] node features, edge_index: [2, E] edge list
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i: target-node features, x_j: source-node features, both [E, F_in]
        return self.mlp(torch.cat([x_i, x_j - x_i], dim=1))
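A quick shape check of the layer above on random data (a sketch, not from the gist):

x = torch.randn(4, 3)                      # 4 nodes, 3 features each
edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes
                           [1, 0, 3, 2]])  # target nodes
conv = EdgeConv(F_in=3, F_out=8)
out = conv(x, edge_index)
print(out.shape)                           # torch.Size([4, 8])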
def instantiate_datamodule(args):
    datamodule = CoraDataset(
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        drop_last=args.drop_last,
        pin_memory=args.pin_memory,
        num_layers=args.num_layers,
    )
    return datamodule
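For context, `args` is just a namespace of CLI options; a hypothetical parser with matching field names (the flags and defaults are assumptions, not taken from the blog) could look like this:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--drop_last", action="store_true")
parser.add_argument("--pin_memory", action="store_true")
parser.add_argument("--num_layers", type=int, default=2)
args = parser.parse_args()
datamodule = instantiate_datamodule(args)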
tchaton / blog_3.py
Last active December 6, 2020 20:08
import os

from pytorch_lightning import LightningDataModule
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

class CoraDataset(LightningDataModule):
    NAME = "cora"

    def __init__(self):
        ...

import torch_geometric.transforms as T

class CoraDataset(LightningDataModule):
    NAME = "cora"
    ...

    def gather_data(self, batch, batch_nb):
        """
class DNAConvNet(LightningModule):
    ...
    def _convert_to_jittable(self, module):
        for key, m in module._modules.items():
            if isinstance(m, MessagePassing) and m.jittable is not None:
                # PyTorch Geometric's MessagePassing implements a `.jittable` function
                # which converts the current module into its jittable version.
                module._modules[key] = m.jittable()
            else:
                # Recurse into child modules until a MessagePassing layer is found.
                self._convert_to_jittable(m)
from typing import List, Optional, NamedTuple

from torch import Tensor

# Type aliases used to make the model jittable: TorchScript needs explicit types.
OptTensor = Optional[Tensor]
ListTensor = List[Tensor]

class TensorBatch(NamedTuple):
    x: Tensor
    edge_index: ListTensor
    edge_attr: OptTensor
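As an illustration of the typed batch (a hypothetical example, not from the gist), `gather_data` can pack a PyTorch Geometric batch into a `TensorBatch` whose field types TorchScript can resolve at compile time:

import torch

batch = TensorBatch(
    x=torch.randn(4, 16),                         # node features
    edge_index=[torch.tensor([[0, 1], [1, 0]])],  # a list of edge_index tensors
    edge_attr=None,                               # optional edge features
)
print(batch.x.shape, len(batch.edge_index), batch.edge_attr)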
def get_single_batch(datamodule):
    for batch in datamodule.test_dataloader():
        return datamodule.gather_data(batch, 0)
def run(args):
    datamodule: LightningDataModule = instantiate_datamodule(args)
    model: LightningModule = instantiate_model(args, datamodule)
    print(model)
    model.jittable()  # swap every MessagePassing layer for its jittable version
    print(model)
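Once `model.jittable()` has run, the network can be compiled and saved with TorchScript; a minimal sketch of that last step (the file name is an assumption) follows.

import torch

scripted = torch.jit.script(model)  # works once every MessagePassing layer is jittable
torch.jit.save(scripted, "dna_conv_net.pt")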
tchaton / blog_7.py
Last active December 7, 2020 14:04
...
from pytorch_lightning import LightningModule
from pytorch_lightning.metrics import Accuracy

class DNAConvNet(LightningModule):
    def __init__(self,
                 num_layers: int = 2,
                 hidden_channels: int = 128,