Skip to content

Instantly share code, notes, and snippets.

@tchaton
Last active December 6, 2020 20:08
Show Gist options
  • Save tchaton/86aa4e14318be858d346168505317d4a to your computer and use it in GitHub Desktop.
Save tchaton/86aa4e14318be858d346168505317d4a to your computer and use it in GitHub Desktop.
import os
from pytorch_lightning import LightningDataModule
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
class CoraDataset(LightningDataModule):
    """LightningDataModule wrapping the Cora dataset from PyG's ``Planetoid``.

    After ``prepare_data`` / ``setup`` has run, ``self.dataset`` holds the
    ``Planetoid`` instance and ``self.data`` holds its first graph.

    Args:
        transform: Optional transform forwarded to ``Planetoid(transform=...)``,
            e.g. ``T.NormalizeFeatures()``. Defaults to ``None`` (no transform).
        data_dir: Root directory the dataset is downloaded to / read from.
            Defaults to the current working directory at load time.
    """

    NAME = "cora"
    # Fixed dataset sizes, exposed as constants so a model can be configured
    # (via `hyper_parameters`) before any data has been downloaded.
    NUM_FEATURES = 1433
    NUM_CLASSES = 7

    def __init__(self, transform=None, data_dir=None):
        super().__init__()
        # Must be assigned before any loading happens: `_load` reads it.
        self._transform = transform
        self._data_dir = data_dir
        self.dataset = None
        self.data = None

    @property
    def num_features(self):
        """Number of input features per node."""
        return self.NUM_FEATURES

    @property
    def num_classes(self):
        """Number of target classes."""
        return self.NUM_CLASSES

    @property
    def hyper_parameters(self):
        """Dataset specification used to inform the model of its input/output sizes."""
        return {"num_features": self.num_features, "num_classes": self.num_classes}

    def prepare_data(self):
        """Download the dataset if needed and load it.

        Lightning invokes this hook on a single process only, so state assigned
        here is not visible on other processes; ``setup`` covers those.
        """
        self._load()

    def setup(self, stage=None):
        """Ensure the dataset is loaded on this process (no-op if already loaded)."""
        if self.dataset is None:
            self._load()

    def _load(self):
        """Instantiate ``Planetoid`` (downloading on first use) and cache its first graph."""
        # Resolve the default lazily so the cwd at load time is used, as before.
        root = self._data_dir if self._data_dir is not None else os.getcwd()
        self.dataset = Planetoid(root, self.NAME, transform=self._transform)
        self.data = self.dataset[0]

    @staticmethod
    def add_argparse_args(parser):
        """Register this DataModule's command-line options on ``parser``.

        Returns:
            The same ``parser``, to allow chaining.
        """
        import argparse

        def str2bool(value):
            # Without a converter argparse returns the raw string, and any
            # non-empty string (including "False") is truthy.
            lowered = value.strip().lower()
            if lowered in ("true", "t", "yes", "y", "1"):
                return True
            if lowered in ("false", "f", "no", "n", "0"):
                return False
            raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

        parser.add_argument("--num_workers", type=int, default=1)
        parser.add_argument("--batch_size", type=int, default=2)
        parser.add_argument("--drop_last", type=str2bool, default=True)
        parser.add_argument("--pin_memory", type=str2bool, default=True)
        return parser
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment