Skip to content

Instantly share code, notes, and snippets.

@tchaton
Last active December 7, 2020 14:05
Show Gist options
Save tchaton/d3361ee5868d784f832715dbe05b8624 to your computer and use it in GitHub Desktop.
class CoraDataset(LightningDataModule):
    """LightningDataModule that serves the Cora graph through ``NeighborSampler`` loaders.

    Only part of this class appears in the original excerpt. The methods below
    rely on attributes presumably initialised in the elided code (TODO confirm):

    * ``self.data`` -- a graph object exposing ``edge_index`` plus per-split
      node masks named ``train_mask`` / ``val_mask`` / ``test_mask``;
    * ``self._num_layers`` -- an int, the number of neighbourhood hops
      (message-passing layers) each sampled mini-batch must cover.
    """

    NAME = "cora"

    # NOTE(review): the original excerpt elided the rest of the class body here
    # (it was written as ``....``, which is not valid Python). ``...`` keeps the
    # class parseable until the real members are restored.
    ...

    def create_neighbor_sampler(self, batch_size=2, stage=None, **sampler_kwargs):
        """Return a ``NeighborSampler`` mini-batch loader for one data split.

        Args:
            batch_size: Number of seed nodes per mini-batch; forwarded to
                ``NeighborSampler``.
            stage: Split name. The seed nodes are taken from the
                ``f"{stage}_mask"`` attribute of ``self.data``
                (e.g. ``"train"``, ``"val"``, ``"test"``).
            **sampler_kwargs: Extra keyword arguments forwarded unchanged to
                ``NeighborSampler`` (e.g. ``shuffle``, ``num_workers``).

        Returns:
            The ``NeighborSampler`` instance.

        Raises:
            ValueError: If ``stage`` is None.
            AttributeError: Typically raised by ``getattr`` when ``self.data``
                has no ``f"{stage}_mask"`` attribute.
        """
        # Reference: https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/data/sampler.py
        # NeighborSampler builds, for each mini-batch of seed nodes, one
        # bipartite (source -> target) adjacency per layer, so a
        # ``self._num_layers``-layer graph convolution can run on just the
        # sampled neighbourhood instead of the full graph.
        if stage is None:
            raise ValueError(
                "stage must name a data split such as 'train', 'val' or 'test'"
            )
        # Plain attribute access (no default) so a missing mask raises instead
        # of possibly yielding None, which would make the sampler use every node.
        node_mask = getattr(self.data, f"{stage}_mask")
        return NeighborSampler(
            self.data.edge_index,
            # Boolean mask selecting the seed nodes sampled from.
            node_idx=node_mask,
            # -1 at every layer: include *all* neighbours for each of the
            # self._num_layers hops (no random sub-sampling takes place).
            sizes=self._num_layers * [-1],
            batch_size=batch_size,
            **sampler_kwargs,
        )

    def train_dataloader(self):
        """Loader over ``train_mask`` nodes; shuffled so batch order varies per epoch."""
        return self.create_neighbor_sampler(stage="train", shuffle=True)

    def val_dataloader(self):
        """Loader over ``val_mask`` nodes (the hook name Lightning looks up)."""
        return self.create_neighbor_sampler(stage="val")

    # Backward-compatible alias for the original, non-Lightning method name.
    validation_dataloader = val_dataloader

    def test_dataloader(self):
        """Loader over ``test_mask`` nodes."""
        return self.create_neighbor_sampler(stage="test")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment