from torch.utils.data import Dataset as _Dataset
from torch.utils.data import Subset

import h5pyd


class DistributedDataset(_Dataset):
    """Read a dataset from an h5 file. If key specifies a dataset, each row is
    an independent sample. If key specifies a group, each dataset is an
    independent sample.

    Parameters
    ----------
    path : str
        The name of the full h5 file with all groups and datasets
    key : str
        The key to the dataset or group, specifying all intermediate groups
    test : bool
        Set mode to testing. This saves all row or dataset IDs as an embedding
        to compare against
    dataset_group_name : str
        Name of the group to split datasets and data_splits. Default is 'datasets'
    file_mode : str
        Open the h5 file for reading and/or writing. Writing should only be
        used if creating data_splits
    """
    def __init__(self, path, key, test=False, dataset_group_name="datasets", file_mode="r"):
        self.path = path
        self.key = key
        self.test = test
        self.file_mode = file_mode
        self.f = h5pyd.File(path, file_mode, use_cache=False)
        self.data = self.f[key]
        self.embedding = None
        if not isinstance(self.data, h5pyd.Dataset):
            # key points to a group: each member dataset is one sample
            if dataset_group_name in self.data.keys():
                self.data = self.f[f'{key}/{dataset_group_name}']
            self.order = sorted(self.data.keys())
            if self.test:
                # In test mode, keep a label encoding of the sample IDs
                from sklearn import preprocessing
                self.embedding = preprocessing.LabelEncoder().fit(self.order)
        else:
            # key points to a dataset: each row is one sample
            self.order = list(range(len(self.data)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        ds = self.data[self.order[index]]
        return ds
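

# --- Usage sketch (not part of the original gist) ---------------------------
# A minimal sketch of how this class might be used. The HSDS domain path
# "/home/user/structures.h5" and the key "train" are hypothetical; only
# DistributedDataset, Subset, and DataLoader come from the code above or
# its imports.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = DistributedDataset("/home/user/structures.h5", "train")

    # Deterministic 90/10 split using the Subset import above; indices
    # refer to positions in dataset.order.
    n_val = max(1, len(dataset) // 10)
    train_set = Subset(dataset, range(len(dataset) - n_val))
    val_set = Subset(dataset, range(len(dataset) - n_val, len(dataset)))

    # DataLoader batching works directly when `key` names a single dataset
    # (each sample is a fixed-shape row); when `key` names a group,
    # __getitem__ returns an h5pyd.Dataset handle, so a custom collate_fn
    # would need to read each handle into memory first.
    loader = DataLoader(train_set, batch_size=4, shuffle=True)

    # Test mode keeps a LabelEncoder over the sample IDs:
    test_set = DistributedDataset("/home/user/structures.h5", "train", test=True)
    if test_set.embedding is not None:
        encoded_ids = test_set.embedding.transform(test_set.order)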