Skip to content

Instantly share code, notes, and snippets.

@jaircastruita
Created March 5, 2021 06:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jaircastruita/92f4a582b559de422c537f578912f22e to your computer and use it in GitHub Desktop.
Save jaircastruita/92f4a582b559de422c537f578912f22e to your computer and use it in GitHub Desktop.
class PokemonDataset(Dataset):
    """Dataset that yields an (anchor, positive) pair built from one image.

    Both elements of the pair come from the same source picture: the
    *anchor* is resized and randomly augmented (rotation, brightness jitter,
    horizontal flip), while the *positive* is only resized. Each element is
    returned as the output of ``transforms.ToTensor()``.

    Images are read either from a pre-built ``.npy`` array (when ``imageset``
    is given) or lazily from individual files under ``root_dir``.
    """

    def __init__(self, images, root_dir, imageset=None):
        """Set up the image sources and the two transform pipelines.

        Args:
            images: Sequence of image file names, relative to ``root_dir``.
                Its length defines the dataset length in both modes.
            root_dir: Directory containing the files listed in ``images``.
            imageset: Optional path to a ``.npy`` file holding the images.
                When given, samples are read from this array instead of
                from the individual files.
        """
        super().__init__()
        # Memory-map the array so it is not loaded into RAM up front.
        # NOTE(review): "r+" maps the file writable although this class only
        # reads from it; kept unchanged to preserve existing behaviour.
        self.imageset = (
            np.load(imageset, mmap_mode="r+") if imageset is not None else None
        )
        self.root_dir = root_dir
        self.images = images
        self.anchor_transform = transforms.Compose([
            transforms.Resize([224, 224]),
            # p=1 means the augmentation group below is always applied; the
            # randomness comes from the individual transforms themselves.
            transforms.RandomApply([
                transforms.RandomRotation(35, fill=255),
                transforms.ColorJitter(
                    brightness=0.5, contrast=0.0, saturation=0.0, hue=0.0
                ),
                transforms.RandomHorizontalFlip(),
            ], p=1),
            transforms.ToTensor(),
        ])
        self.positive_transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.images)``."""
        # NOTE(review): the length is taken from ``images`` even when samples
        # are served from ``imageset`` -- presumably both have the same
        # number of entries; confirm against the caller.
        return len(self.images)

    def _load_image(self, idx):
        """Return sample ``idx`` as a PIL image from whichever source is set."""
        if self.imageset is not None:
            # NOTE(review): ToPILImage receives a tensor here, so the stored
            # array is presumably channel-first (C, H, W) -- TODO confirm.
            return transforms.ToPILImage()(torch.from_numpy(self.imageset[idx]))
        # ``convert`` produces an independent in-memory image, so the file
        # handle can be released as soon as the ``with`` block exits.
        with Image.open(os.path.join(self.root_dir, self.images[idx])) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return ``(anchor_img, positive_img)`` for the sample at ``idx``.

        Args:
            idx: Index into ``imageset`` (if set) or into ``images``.

        Returns:
            A 2-tuple ``(anchor_img, positive_img)``: the augmented and the
            plain resized version of the same source image.
        """
        img = self._load_image(idx)
        # Neither pipeline modifies ``img`` in place (Resize returns a new
        # image), so one decoded image can feed both of them.
        anchor_img = self.anchor_transform(img)
        positive_img = self.positive_transform(img)
        return anchor_img, positive_img
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment