@tchaton
Created December 6, 2020 20:10
from collections import namedtuple

import torch_geometric.transforms as T
from pytorch_lightning import LightningDataModule


class CoraDataset(LightningDataModule):
    NAME = "cora"
    ...

    def gather_data(self, batch, batch_nb):
        """
        Select the features of the sampled nodes using node_idx
        and pack them into a namedtuple.
        """
        usual_keys = ["x", "edge_index", "edge_attr", "batch"]
        Batch: TensorBatch = namedtuple("Batch", usual_keys)
        return (
            Batch(
                self.data.x[batch[1]],
                [e.edge_index for e in batch[2]],
                None,  # not used: the edges have no attributes
                None,  # tensor mapping each node to its original batch index
            ),
            self.data.y[batch[1]],
        )
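
To make the indexing in `gather_data` concrete, here is a minimal standalone sketch on a toy graph. It assumes that `batch` is a tuple laid out as `(batch_size, node_idx, adjs)`, as a neighbor sampler might yield. The snippet above does not state this layout, so treat it as an assumption. `Adj`, the free function `gather_data(data, batch)`, and the toy tensors are hypothetical stand-ins and not part of the original class.

```python
from collections import namedtuple
from types import SimpleNamespace

import torch

Batch = namedtuple("Batch", ["x", "edge_index", "edge_attr", "batch"])
# Hypothetical stand-in for the per-layer adjacency objects found in batch[2].
Adj = namedtuple("Adj", ["edge_index"])


def gather_data(data, batch):
    """Standalone variant of the method above (no `self`, no `batch_nb`)."""
    node_idx = batch[1]  # assumed: indices of the sampled nodes
    return (
        Batch(
            data.x[node_idx],                      # features of sampled nodes
            [adj.edge_index for adj in batch[2]],  # one edge_index per layer
            None,                                  # no edge attributes
            None,                                  # no batch vector
        ),
        data.y[node_idx],                          # labels of sampled nodes
    )


# Toy graph: 5 nodes with 3 features each (made-up values).
data = SimpleNamespace(
    x=torch.arange(15, dtype=torch.float).reshape(5, 3),
    y=torch.tensor([0, 1, 0, 1, 1]),
)
node_idx = torch.tensor([0, 2, 4])
adjs = [
    Adj(torch.tensor([[0, 1], [1, 2]])),
    Adj(torch.tensor([[0], [1]])),
]
batch = (3, node_idx, adjs)

inputs, targets = gather_data(data, batch)
print(inputs.x.shape)  # torch.Size([3, 3])
print(targets)         # tensor([0, 0, 1])
```

Returning a namedtuple lets downstream code read fields by name (`inputs.x`, `inputs.edge_index`) instead of by position, while still unpacking like an ordinary tuple.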