Skip to content

Instantly share code, notes, and snippets.

@andrewjong
Last active August 3, 2024 08:02
Show Gist options
  • Save andrewjong/6b02ff237533b3b2c554701fb53d5c4d to your computer and use it in GitHub Desktop.
Save andrewjong/6b02ff237533b3b2c554701fb53d5c4d to your computer and use it in GitHub Desktop.
PyTorch Image File Paths With Dataset Dataloader
import torch
from torchvision import datasets
class ImageFolderWithPaths(datasets.ImageFolder):
"""Custom dataset that includes image file paths. Extends
torchvision.datasets.ImageFolder
"""
# override the __getitem__ method. this is the method that dataloader calls
def __getitem__(self, index):
# this is what ImageFolder normally returns
original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
# the image file path
path = self.imgs[index][0]
# make a new tuple that includes original and the path
tuple_with_path = (original_tuple + (path,))
return tuple_with_path
# EXAMPLE USAGE:
# instantiate the dataset and dataloader
data_dir = "your/data_dir/here"
dataset = ImageFolderWithPaths(data_dir) # our custom dataset
dataloader = torch.utils.DataLoader(dataset)
# iterate over data
for inputs, labels, paths in dataloader:
# use the above variables freely
print(inputs, labels, paths)
@Lucylucy712
Copy link

Wondeeful! You save my day!

@realliyifei
Copy link

Thanks; hard to imagine that ImageFolder doesn't have this function / flag

@andrea137
Copy link

May I ask under what license this snippet is released?

@KennyKang7012
Copy link

KennyKang7012 commented Feb 27, 2024

import torch
import torchvision
from torchvision import datasets, transforms

transforms = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor()
])

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
"""Custom dataset that includes image file paths. Extends
torchvision.datasets.ImageFolder
"""
#override the getitem method. this is the method that dataloader calls
def getitem(self, index):
# this is what ImageFolder normally returns
original_tuple = super(ImageFolderWithPaths, self).getitem(index)
# the image file path
path = self.imgs[index][0]
# make a new tuple that includes original and the path
tuple_with_path = (original_tuple + (path,))
return tuple_with_path

#EXAMPLE USAGE:
#instantiate the dataset and dataloader
data_dir = './dog_vs_cat/train/'
dataset = ImageFolderWithPaths(data_dir, transform=transforms) # our custom dataset
dataloader = torch.utils.data.DataLoader(dataset)

#iterate over data
for inputs, labels, paths in dataloader:
# use the above variables freely
print(inputs, labels, paths)

This code worked for me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment