Skip to content

Instantly share code, notes, and snippets.

@aliwaqas333
Last active June 19, 2020 09:24
Show Gist options
  • Save aliwaqas333/f9aa4209adb79bf840ebc8bb11f66c4b to your computer and use it in GitHub Desktop.
Save aliwaqas333/f9aa4209adb79bf840ebc8bb11f66c4b to your computer and use it in GitHub Desktop.
We use PyTorch's `Dataset` class to create our dataset.
import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, ToTensor
class MaskDataset(Dataset):
    """Masked-faces dataset.

    Labels:
        0 = 'no mask'
        1 = 'mask'

    Args:
        train_data: An indexable, sized collection of samples where
            ``sample[0]`` is the image and ``sample[1]`` is its label.
        transformations: Optional callable applied to each image. When
            omitted, defaults to ``Compose([ToTensor()])``, which converts the
            image to a tensor (scaled to [0, 1] for PIL images / uint8 arrays).
    """

    def __init__(self, train_data, transformations=None):
        self.train_data = train_data
        if transformations is None:
            transformations = Compose([
                ToTensor(),  # [0, 1]
            ])
        self.transformations = transformations

    def __getitem__(self, key):
        """Return ``[image, label_tensor]`` for the sample at ``key``.

        ``image`` is ``sample[0]`` passed through ``self.transformations``;
        ``label_tensor`` is ``torch.tensor(sample[1])``.

        Raises:
            NotImplementedError: if ``key`` is a slice.
        """
        if isinstance(key, slice):
            raise NotImplementedError('slicing is not supported')
        # Index train_data once: a lazy/disk-backed source would otherwise
        # load the same sample twice.
        sample = self.train_data[key]
        return [
            self.transformations(sample[0]),
            torch.tensor(sample[1]),  # pylint: disable=not-callable
        ]

    def __len__(self):
        """Return the number of samples in ``train_data``."""
        return len(self.train_data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment