Last active: January 4, 2022 02:51
-
-
Save mayukh18/250a679cd3d2635e4906a8acc1b4deea to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TripletData(Dataset):
    """Dataset that samples (anchor, positive, negative) image triplets on the fly.

    Images are read from ``path/<class_id>/<file>``, where class ids are the
    strings ``"1"`` .. ``str(cats)``. Each item draws two distinct images from
    one class (anchor and positive) and one image from a different, randomly
    chosen class (negative).

    Args:
        path: Root directory containing one sub-directory per class.
        transforms: Callable applied to every loaded PIL image.
        split: Informational tag ("train" or "valid"); not used for sampling.
        cats: Number of classes; sub-directories ``"1"``..``str(cats)`` must exist.
        samples_per_class: Triplets each class contributes per epoch, so
            ``len(dataset) == cats * samples_per_class``.

    Raises:
        ValueError: if ``cats < 2`` (no negative class could ever be chosen).
    """

    def __init__(self, path, transforms, split="train", cats=102, samples_per_class=8):
        if cats < 2:
            raise ValueError(f"TripletData needs at least 2 categories, got {cats}")
        self.path = path
        self.split = split  # train or valid
        self.cats = cats  # number of categories
        self.samples_per_class = samples_per_class
        self.transforms = transforms

    def _list_images(self, cat):
        """Return the sorted file names inside class directory ``cat``."""
        # os.listdir order is filesystem-dependent; sorting makes sampling
        # reproducible when the random module is seeded.
        return sorted(os.listdir(os.path.join(self.path, cat)))

    def _load(self, cat, name):
        """Open ``path/cat/name`` as RGB, close the file, then apply the transforms."""
        with Image.open(os.path.join(self.path, cat, name)) as img:
            # convert() loads the pixel data and yields exactly 3 channels, so
            # grayscale/RGBA files still fit 3-channel normalization and collation.
            rgb = img.convert("RGB")
        return self.transforms(rgb)

    def __getitem__(self, idx):
        """Return ``[anchor, positive, negative]`` as transformed images.

        ``idx`` only selects the positive class (``idx % cats + 1``); the
        concrete images are drawn at random on every call.

        Raises:
            ValueError: if the positive class directory holds fewer than 2 files.
        """
        # our positive class for the triplet
        pos_num = idx % self.cats + 1
        pos_cat = str(pos_num)

        # choosing our pair of positive images (im1, im2)
        positives = self._list_images(pos_cat)
        if len(positives) < 2:
            raise ValueError(
                f"class directory {os.path.join(self.path, pos_cat)!r} needs at "
                f"least 2 images, found {len(positives)}"
            )
        im1, im2 = random.sample(positives, 2)

        # choosing a negative class uniformly among the other cats - 1 classes:
        # draw from 1..cats-1 and shift draws >= pos_num up by one to skip it.
        neg_num = random.randrange(1, self.cats)
        if neg_num >= pos_num:
            neg_num += 1
        neg_cat = str(neg_num)
        im3 = random.choice(self._list_images(neg_cat))

        return [
            self._load(pos_cat, im1),
            self._load(pos_cat, im2),
            self._load(neg_cat, im3),
        ]

    # we'll put some value that we want since there can be far too many triplets possible
    # multiples of the number of images / number of categories is a good choice
    def __len__(self):
        """Return ``cats * samples_per_class`` (an epoch length, not a triplet count)."""
        return self.cats * self.samples_per_class
# Transforms
# Per-channel mean/std shared by the train and validation pipelines.
# NOTE(review): these values match the commonly quoted CIFAR-10 statistics;
# if the backbone is ImageNet-pretrained, ImageNet statistics may fit better — confirm.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training: fixed 224x224 resize, random horizontal flip as augmentation,
# then tensor conversion and per-channel normalization.
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Validation: same pipeline without the random flip (no augmentation).
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
# Datasets and Dataloaders
# The split tag is passed explicitly so each dataset reports the split it
# serves; the constructor's default is "train".
train_data = TripletData(PATH_TRAIN, train_transforms, split="train")
val_data = TripletData(PATH_VALID, val_transforms, split="valid")

# NOTE(review): TripletData draws its images with random.sample/random.choice
# on every __getitem__ call, so validation triplets differ between passes even
# with shuffle=False; seed or cache them if a stable validation metric is needed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data, batch_size=32, shuffle=True, num_workers=4
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_data, batch_size=32, shuffle=False, num_workers=4
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.