Skip to content

Instantly share code, notes, and snippets.

@chenyaofo
Created August 6, 2018 08:30
Show Gist options
  • Save chenyaofo/d6148b2ff6e85fdf40f983109f64326d to your computer and use it in GitHub Desktop.
Save chenyaofo/d6148b2ff6e85fdf40f983109f64326d to your computer and use it in GitHub Desktop.
MIODataset
import copy
import numbers

from torch.utils.data import Dataset

from torchlearning.mio import MIO, Split
class MioDataset(Dataset):
    """Map-style dataset over a MIO store, yielding one item per collection.

    Each item is a ``(data, target)`` pair. ``data`` is the object(s) fetched
    from a collection as chosen by ``sampler``. ``target`` is that collection's
    metadata.

    Args:
        root: Location handed to ``MIO`` to open the store.
        sampler: Callable that receives a collection's size and returns either
            a single integer object id (fetched with ``MIO.fetchone``) or any
            other value, which is passed unchanged to ``MIO.fetchmany``.
        transform: Optional callable applied to the fetched data.
        target_transform: Optional callable applied to the collection metadata.
    """

    def __init__(self, root, sampler, transform=None, target_transform=None):
        self.root = root
        self.sampler = sampler
        self.transform = transform
        self.target_transform = target_transform
        self.mio = MIO(self.root)
        # Optional indexable of collection ids that restricts/reorders this
        # view of the store; None exposes every collection directly.
        self.split = None

    def to_split(self, split: "Split"):
        """Return a shallow copy of this dataset restricted to ``split``.

        The copy shares the ``MIO`` handle, sampler and transforms with
        ``self``; only its ``split`` attribute is set, to ``split.items``.
        Any split already on ``self`` is replaced, not composed.
        """
        dataset = copy.copy(self)
        dataset.split = split.items
        return dataset

    def __getitem__(self, id_):
        """Fetch the sampled data and metadata for item ``id_``.

        When a split is active, ``id_`` is a position in the split and is
        mapped to the underlying collection id first.
        """
        if self.split is not None:
            id_ = self.split[id_]
        size = self.mio.get_collection_size(id_)
        selected_samples = self.sampler(size)
        # Check numbers.Integral rather than int: samplers built on NumPy
        # return np.integer scalars, which are not `int` instances and would
        # otherwise be misrouted to fetchmany as if they were a sequence.
        if isinstance(selected_samples, numbers.Integral):
            # Normalize to a plain Python int before handing it to MIO.
            data = self.mio.fetchone(id_, int(selected_samples))
        else:
            data = self.mio.fetchmany(id_, selected_samples)
        target = self.mio.get_collection_metadata(id_)
        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return data, target

    def __len__(self):
        """Return the split length, or ``MIO.size`` when no split is set."""
        if self.split is None:
            # NOTE(review): presumably the number of collections in the store.
            return self.mio.size
        return len(self.split)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment