Last active: December 6, 2020 20:08
-
-
Save tchaton/86aa4e14318be858d346168505317d4a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from pytorch_lightning import LightningDataModule | |
from torch_geometric.datasets import Planetoid | |
import torch_geometric.transforms as T | |
class CoraDataset(LightningDataModule):
    """LightningDataModule exposing the Cora graph loaded via ``Planetoid``.

    The raw/processed files live under the current working directory
    (``os.getcwd()``). After loading, ``self.dataset`` is the ``Planetoid``
    dataset and ``self.data`` is its first (and only) graph.

    Args:
        transform: Optional transform forwarded to ``Planetoid`` (for example
            ``T.NormalizeFeatures()``). Defaults to ``None`` (no transform).
    """

    NAME = "cora"

    def __init__(self, transform=None):
        super().__init__()
        # ``prepare_data`` reads ``self._transform``. It used to be read
        # without ever being assigned, which raised AttributeError.
        self._transform = transform
        self.dataset = None
        self.data = None

    @property
    def num_features(self):
        """Number of input node features (hard-coded so it is known before loading)."""
        return 1433

    @property
    def num_classes(self):
        """Number of target classes (hard-coded so it is known before loading)."""
        return 7

    @property
    def hyper_parameters(self):
        """Dataset specifications, used to inform the model how to size itself."""
        return {"num_features": self.num_features, "num_classes": self.num_classes}

    def _load_planetoid(self):
        """Instantiate ``Planetoid`` (downloading if needed) and cache its graph."""
        self.dataset = Planetoid(os.getcwd(), self.NAME, transform=self._transform)
        self.data = self.dataset[0]

    def prepare_data(self):
        """Download (if needed) and load Cora into ``self.dataset`` / ``self.data``."""
        # NOTE: Lightning calls prepare_data on a single process only, so the
        # attributes set here are not visible on other processes; ``setup``
        # loads them wherever they are still missing.
        self._load_planetoid()

    def setup(self, stage=None):
        """Make sure the graph is loaded on the current process.

        Args:
            stage: Lightning stage name (e.g. ``"fit"``, ``"test"``); unused.
        """
        if self.data is None:
            self._load_planetoid()

    @staticmethod
    def add_argparse_args(parser):
        """Register the data-loading command-line options on ``parser``.

        Args:
            parser: An ``argparse.ArgumentParser`` (or argument group).

        Returns:
            The same ``parser``, to allow chaining.
        """
        import argparse

        def str_to_bool(value):
            # argparse does not parse booleans itself: without a ``type``,
            # "--drop_last False" would yield the truthy string "False".
            lowered = value.strip().lower()
            if lowered in ("1", "true", "t", "yes", "y"):
                return True
            if lowered in ("0", "false", "f", "no", "n"):
                return False
            raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

        parser.add_argument("--num_workers", type=int, default=1)
        parser.add_argument("--batch_size", type=int, default=2)
        # argparse only applies ``type`` to string defaults, so these still
        # default to the real boolean True.
        parser.add_argument("--drop_last", type=str_to_bool, default=True)
        parser.add_argument("--pin_memory", type=str_to_bool, default=True)
        return parser
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.