# Skip to content
#
# GitHub Gist: instantly share code, notes, and snippets.

# Gist by @khuangaf, created May 11, 2019 15:13.
# Gist ID: khuangaf/7f876c6ad4e4adcd36caea98b159b6f6 (1 star, 2 forks).
import torch
from torch_geometric.data import InMemoryDataset
class MyOwnDataset(InMemoryDataset):
    """Template for a custom in-memory PyTorch Geometric dataset.

    ``process`` builds a list of graph objects, optionally filters and
    pre-transforms them, collates them with ``self.collate`` and saves the
    resulting ``(data, slices)`` tuple to ``self.processed_paths[0]``;
    ``__init__`` loads that same file after the base-class constructor has
    run.  (The base class is presumably what decides when ``download`` and
    ``process`` are called — see the PyG ``Dataset`` docs.)

    Args:
        root: Root directory of the dataset, forwarded to the base class.
        transform: Optional callable, forwarded to the base class.
        pre_transform: Optional callable applied in ``process`` to every
            object before collation.
        pre_filter: Optional predicate used in ``process``; objects for which
            it returns a falsy value are dropped. Defaults to ``None`` (keep
            everything).
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None):
        # Forward pre_filter so callers can actually set the
        # ``self.pre_filter`` that process() reads.
        super().__init__(root, transform, pre_transform, pre_filter=pre_filter)
        # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True),
        # which may refuse to unpickle PyG Data objects. If loading fails,
        # consider weights_only=False (trusted files only) or, on PyG >= 2.4,
        # self.load(self.processed_paths[0]) — confirm installed versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the raw input files.

        TODO: replace these placeholder names with the real raw file names
        (the template's trailing ``...`` was not a valid file name).
        """
        return ['some_file_1', 'some_file_2']

    @property
    def processed_file_names(self):
        """Name of the single processed file written by ``process``."""
        return ['data.pt']

    def download(self):
        """Download the raw files to ``self.raw_dir``.

        TODO: implement the download. Left as a no-op so that raw files
        placed in ``self.raw_dir`` by hand can be used as-is.
        """
        pass

    def process(self):
        """Read, filter, pre-transform, collate and save the dataset.

        Objects for which ``self.pre_filter`` returns a falsy value are
        dropped, then ``self.pre_transform`` is applied to each remaining
        object. The collated ``(data, slices)`` tuple is written to
        ``self.processed_paths[0]``, the file ``__init__`` loads.
        """
        data_list = self._read_data_list()

        if self.pre_filter is not None:
            # Rebind the list to only the objects the filter accepts.
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def _read_data_list(self):
        """Read the raw files into a list of graph data objects.

        TODO: implement, e.g. by parsing each file listed in
        ``raw_file_names`` (presumably available as ``self.raw_paths``).

        Raises:
            NotImplementedError: until this template is filled in.
        """
        raise NotImplementedError(
            'MyOwnDataset._read_data_list(): read the raw files into a list '
            'of torch_geometric.data.Data objects.')
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.