Skip to content

Instantly share code, notes, and snippets.

@mayukh18
Last active January 4, 2022 02:51
Show Gist options
  • Save mayukh18/250a679cd3d2635e4906a8acc1b4deea to your computer and use it in GitHub Desktop.
Save mayukh18/250a679cd3d2635e4906a8acc1b4deea to your computer and use it in GitHub Desktop.
class TripletData(Dataset):
    """Dataset yielding random ``[positive, positive, negative]`` image triplets.

    Images are read from ``path/<category>/<file>``, where ``<category>`` is
    the category number written as a string, from ``"1"`` to ``str(cats)``.
    Every access draws fresh random files, so the same index does not return
    the same triplet twice.

    Args:
        path: Root directory holding one sub-directory per category.
        transforms: Callable applied to each loaded PIL image.
        split: Label for the split ("train" or "valid"). It is stored on the
            instance but not read by this class.
        cats: Number of categories (sub-directories ``1`` .. ``cats``).
        triplets_per_cat: How many triplets per category one pass over the
            dataset produces; only affects ``len(dataset)``.

    Raises:
        ValueError: If ``cats`` is smaller than 2, since no negative category
            could then be drawn.
    """

    def __init__(self, path, transforms, split="train", *, cats=102,
                 triplets_per_cat=8):
        if cats < 2:
            raise ValueError(
                "TripletData needs at least 2 categories to pick a negative, "
                "got cats=%r" % (cats,)
            )
        self.path = path
        self.split = split  # train or valid
        self.cats = cats  # number of categories
        self.triplets_per_cat = triplets_per_cat
        self.transforms = transforms

    def _load(self, category, filename):
        """Open one image, force it to 3-channel RGB and apply the transforms."""
        full_path = os.path.join(self.path, category, filename)
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it as soon as convert() has produced a loaded copy.
        # Converting to RGB keeps grayscale / palette / RGBA files compatible
        # with transforms that expect exactly three channels.
        with Image.open(full_path) as img:
            return self.transforms(img.convert("RGB"))

    def __getitem__(self, idx):
        """Return ``[im1, im2, im3]``: two images of one category, one of another.

        Args:
            idx: Integer index; it only selects the positive category via
                ``idx % cats + 1``.

        Raises:
            ValueError: If the positive category folder holds fewer than two
                files, so no positive pair can be formed.
        """
        # our positive class for the triplet
        positive_cat = str(idx % self.cats + 1)

        # choosing our pair of positive images (im1, im2)
        positives = os.listdir(os.path.join(self.path, positive_cat))
        if len(positives) < 2:
            raise ValueError(
                "category %r under %r needs at least 2 images to form a "
                "positive pair, found %d"
                % (positive_cat, self.path, len(positives))
            )
        im1, im2 = random.sample(positives, 2)

        # choosing a negative class and negative image (im3)
        negative_cat = random.choice(
            [str(c) for c in range(1, self.cats + 1) if str(c) != positive_cat]
        )
        negatives = os.listdir(os.path.join(self.path, negative_cat))
        im3 = random.choice(negatives)

        # Load in a fixed order (im1, im2, im3) so random transforms consume
        # their random state in the same sequence on every call.
        first = self._load(positive_cat, im1)
        second = self._load(positive_cat, im2)
        third = self._load(negative_cat, im3)
        return [first, second, third]

    def __len__(self):
        """Return the nominal number of triplets in one pass.

        The number of possible triplets is far too large to enumerate, so the
        length is simply a chosen multiple of the number of categories.
        """
        return self.cats * self.triplets_per_cat
# Transforms
# Both pipelines resize to the same square size and normalise with the same
# per-channel statistics; only the training pipeline adds a random flip.
# NOTE(review): these mean/std values appear to be the commonly quoted CIFAR-10
# statistics - confirm they are the intended ones for this dataset/model.
_TARGET_SIZE = (224, 224)
_CHANNEL_MEAN = (0.4914, 0.4822, 0.4465)
_CHANNEL_STD = (0.2023, 0.1994, 0.2010)

train_transforms = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.RandomHorizontalFlip(),  # augmentation: training only
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
# Datasets and Dataloaders
# Tag each dataset with the split it actually serves; leaving the default
# would label the validation dataset as "train".
train_data = TripletData(PATH_TRAIN, train_transforms, split="train")
val_data = TripletData(PATH_VALID, val_transforms, split="valid")

# Training batches are shuffled; validation batches keep index order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_data,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment