Skip to content

Instantly share code, notes, and snippets.

@qubvel
Last active September 25, 2019 23:06
Show Gist options
  • Save qubvel/a024dc7fc5d4ee309c739a4c84b6590e to your computer and use it in GitHub Desktop.
Save qubvel/a024dc7fc5d4ee309c739a4c84b6590e to your computer and use it in GitHub Desktop.
class TestDataset(Dataset):
    """Map-style dataset that yields ``{'id': ..., 'image': ...}`` samples.

    Args:
        base_path: Stored on the instance; presumably the root directory that
            ``read_image`` loads files from (TODO confirm once implemented).
        ids: Indexable sequence of sample identifiers. Its length is the
            length of the dataset, and ``ids[i]`` becomes the sample's ``'id'``.
        transform: Optional callable invoked as ``transform(**sample)``. Its
            return value replaces the sample. ``None`` (the default) disables
            transformation.
    """

    def __init__(self, base_path, ids, transform=None):
        self.base_path = base_path
        self.ids = ids
        self.transform = transform

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.ids)``."""
        # A map-style dataset must define __len__ for the default sampler
        # (and len(dataset)) to work; otherwise DataLoader raises TypeError.
        return len(self.ids)

    def __getitem__(self, i):
        """Return sample ``i``, passed through ``transform`` when one is set."""
        sample = self.get_sample(i)
        if self.transform is not None:
            # The sample dict is unpacked into keyword arguments, so the
            # transform receives e.g. id=..., image=... by name.
            sample = self.transform(**sample)
        return sample

    def get_sample(self, i):
        """Build the raw (untransformed) sample dict for index ``i``.

        Subclasses extend this dict with extra keys.
        """
        return {
            'id': self.ids[i],
            'image': self.read_image(i),
        }

    def read_image(self, i):
        """Load the image for index ``i``.

        NOTE(review): placeholder; it currently returns ``Ellipsis``.
        TODO: implement the actual loading (presumably relative to
        ``self.base_path``).
        """
        image = ...
        return image
class TrainDataset(TestDataset):
    """Training variant of ``TestDataset`` whose samples also carry labels.

    Each sample dict has the parent's keys plus ``'mask'`` and ``'label'``.
    Constructor arguments are forwarded unchanged to ``TestDataset``.
    """

    def __init__(self, base_path, ids, transform=None):
        super().__init__(base_path, ids, transform)
        # TODO: the original sketch elided train-only constructor arguments
        # (written as a literal `...`, which is not valid Python in a
        # parameter list). Add them here explicitly once they are known.

    def get_sample(self, i):
        """Return the parent's sample for ``i`` extended with mask and label."""
        sample = super().get_sample(i)
        sample['mask'] = self.read_mask(i)
        sample['label'] = self.read_label(i)
        return sample

    def read_mask(self, i):
        """Load the mask for index ``i``.

        NOTE(review): placeholder; it currently returns ``Ellipsis``.
        TODO: implement.
        """
        mask = ...
        return mask

    def read_label(self, i):
        """Load the label for index ``i``.

        NOTE(review): placeholder; it currently returns ``Ellipsis``.
        TODO: implement.
        """
        label = ...
        return label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment