PyTorch Image File Paths With Dataset Dataloader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torchvision import datasets, transforms
class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths.

    Extends ``torchvision.datasets.ImageFolder``: every item is the usual
    ``(sample, target)`` tuple with the image's file path appended, i.e.
    ``(sample, target, path)``.
    """

    # Override __getitem__: this is the method the DataLoader calls for
    # each index it fetches.
    def __getitem__(self, index):
        """Return ``(sample, target, path)`` for the item at *index*."""
        # This is what ImageFolder normally returns: (sample, target).
        original_tuple = super().__getitem__(index)
        # Take the path from ``self.samples``, the list torchvision's
        # DatasetFolder.__getitem__ loads from, so the path always belongs
        # to the image just returned -- even if ``samples`` is reassigned
        # (e.g. filtered) after construction. ``self.imgs`` is only an
        # alias captured in ImageFolder.__init__ and can go stale.
        path = self.samples[index][0]
        # Extend the original tuple rather than rebuilding it, so any
        # extra fields a parent class might return are preserved.
        return original_tuple + (path,)
# EXAMPLE USAGE:
# instantiate the dataset and dataloader
data_dir = "your/data_dir/here"
# A tensor-producing transform is required: without one, ImageFolder yields
# PIL images, which the DataLoader's default collate function cannot batch.
dataset = ImageFolderWithPaths(data_dir, transform=transforms.ToTensor())  # our custom dataset
# DataLoader lives in torch.utils.data (there is no torch.utils.DataLoader).
dataloader = torch.utils.data.DataLoader(dataset)
# iterate over data: each batch is (inputs, labels, paths)
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
Works out of the box. Thanks!
You are my hero! thank you!
How would I modify this to isolate files with a wildcard? For example, if I wanted to isolate all image files that start with vid_1234.
Wonderful! You saved my day!
Thanks; hard to imagine that ImageFolder doesn't have this function / flag
May I ask under what license this snippet is released?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
import torch
from torchvision import datasets
from torchvision import transforms as T

# Image preprocessing pipeline: resize every image to 64x64, then convert
# it to a tensor so the DataLoader can collate batches.
# The pipeline keeps the name ``transforms`` because the code below passes
# ``transform=transforms``; the torchvision module itself is imported as
# ``T`` so this assignment no longer shadows it.
transforms = T.Compose([
    T.Resize((64, 64)),
    T.ToTensor(),
])
class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder whose items are ``(sample, target, path)`` tuples."""

    # Must be the dunder ``__getitem__``: the DataLoader indexes the dataset
    # via ``dataset[index]``, so a plain ``getitem`` method is never called
    # and items would stay 2-tuples.
    def __getitem__(self, index):
        """Return ``(sample, target, path)`` for the item at *index*."""
        # this is what ImageFolder normally returns: (sample, target)
        original_tuple = super().__getitem__(index)
        # the image file path, read from ``self.samples`` (the list the
        # parent __getitem__ loads from) so it always matches the image
        path = self.samples[index][0]
        return original_tuple + (path,)
data_dir = "./sig_datasets/"
dataset = ImageFolderWithPaths(data_dir, transform=transforms)
dataloader = torch.utils.data.DataLoader(dataset)
# iterate over data; each batch unpacks into (images, labels, paths).
# Only the first batch is printed, then the loop stops.
for images, labels, paths in dataloader:
    print(images)
    break
This code worked for me.