Skip to content

Instantly share code, notes, and snippets.

@pranshuj73
Created July 3, 2020 11:09
Show Gist options
  • Save pranshuj73/695e9bd0621c2cbf0593adebad05495e to your computer and use it in GitHub Desktop.
# Creating the class for our dataset for the FER
class FERDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for FER.

    Each entry of ``images`` is a string of whitespace-separated integer
    pixel values. It is decoded into a ``(48, 48, 1)`` ``uint8`` array, which
    is then passed through ``transforms`` (if given).

    Args:
        images: Indexable sequence of pixel strings (list, NumPy array or
            pandas Series). Indexed by position, aligned with ``labels``.
        labels: Indexable sequence of labels, aligned with ``images``.
        transforms: Callable applied to the decoded ``(48, 48, 1)`` uint8
            array (presumably a torchvision transform pipeline — confirm
            with the caller), or ``None`` to return the raw array.
    """

    # Images are square; each row must hold exactly _SIDE * _SIDE pixels.
    _SIDE = 48

    def __init__(self, images, labels, transforms=None):
        self.X = images
        self.y = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.X)

    @staticmethod
    def _at(seq, i):
        """Return the ``i``-th element of ``seq`` by position."""
        # A pandas Series indexes by *label* with ``[]``; after a shuffled
        # split its labels are no longer 0..n-1. ``.iloc`` keeps indexing
        # positional, consistent with ``__len__``. Lists/arrays have no
        # ``iloc`` and are indexed directly.
        return getattr(seq, "iloc", seq)[i]

    def __getitem__(self, i):
        """Decode sample ``i`` and return ``(transformed_image, label)``.

        Raises:
            ValueError: if the row is not exactly 48*48 integer pixels.
        """
        # ``split()`` with no argument tolerates repeated/trailing whitespace,
        # which ``split(' ')`` would turn into empty tokens (int('') fails).
        pixels = [int(m) for m in self._at(self.X, i).split()]
        expected = self._SIDE * self._SIDE
        if len(pixels) != expected:
            raise ValueError(
                f"expected {expected} pixel values at index {i}, "
                f"got {len(pixels)}"
            )
        data = np.asarray(pixels).astype(np.uint8).reshape(
            self._SIDE, self._SIDE, 1
        )
        if self.transforms is not None:
            data = self.transforms(data)
        label = self._at(self.y, i)
        return (data, label)
# Wrap each split in FERDataset together with its transform pipeline.
# Note: the validation dataset is built from the *test* split arrays.
train_data = FERDataset(
    images=train_images,
    labels=train_labels,
    transforms=train_trfm,
)
val_data = FERDataset(
    images=test_images,
    labels=test_labels,
    transforms=val_trfm,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment