Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save KennyKang7012/cd37a57dbf947e7fba4c31828af3477a to your computer and use it in GitHub Desktop.
Save KennyKang7012/cd37a57dbf947e7fba4c31828af3477a to your computer and use it in GitHub Desktop.
PyTorch Image File Paths With Dataset Dataloader
import torch
import torchvision
from torchvision import datasets, transforms
# Preprocessing pipeline applied to every image: resize to 64x64, then
# convert to a tensor.
# NOTE: this assignment rebinds the name `transforms`, so after this statement
# `transforms` refers to the composed pipeline and no longer to the
# torchvision.transforms module imported above.
transforms = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)
class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """Custom dataset that includes image file paths.

    Extends ``torchvision.datasets.ImageFolder`` so that every item also
    carries the path of the file it was loaded from.
    """

    # Override the __getitem__ method; this is the method the dataloader calls.
    def __getitem__(self, index):
        """Return the parent's item for ``index`` with the file path appended.

        Args:
            index: Position of the requested sample in the dataset.

        Returns:
            The tuple produced by ``ImageFolder.__getitem__`` extended with
            one extra trailing element: the image file path taken from
            ``self.imgs[index][0]``.
        """
        # This is what ImageFolder normally returns.
        original_tuple = super().__getitem__(index)
        # The first field of the matching ``imgs`` entry is the image file path.
        path = self.imgs[index][0]
        # Build a new tuple that includes the original fields and the path.
        return original_tuple + (path,)
# EXAMPLE USAGE:
# instantiate the dataset and dataloader
# e.g. data_dir = './dog_vs_cat/' or data_dir = './dog_vs_cat/train/'
data_dir = "your/data_dir/here"
# `transforms` here is the composed preprocessing pipeline defined earlier in
# this file (it rebinds the imported module name).
dataset = ImageFolderWithPaths(data_dir, transform=transforms)  # our custom dataset
# DataLoader lives in torch.utils.data; torch.utils has no DataLoader attribute.
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment