
@andrewjong
Last active August 27, 2023 18:43
PyTorch Image File Paths With Dataset Dataloader
import torch
from torchvision import datasets, transforms


class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method that dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns: (image, label)
        original_tuple = super().__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = original_tuple + (path,)
        return tuple_with_path


# EXAMPLE USAGE:
# instantiate the dataset and dataloader
data_dir = "your/data_dir/here"
# ToTensor() is needed so the default collate function can batch the images
dataset = ImageFolderWithPaths(data_dir, transform=transforms.ToTensor())  # our custom dataset
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
@monajalal

Fixed it. I figured out that you are looking for folders within the initial path :)

@maverickdas

Works really well ^_^

@shanej199127

Easily understandable.

@tcrapse

tcrapse commented May 29, 2019

excellent contribution. thank you.

@junkwhinger

Thank you

@jkanti

jkanti commented Aug 13, 2019

Hi,
This code normally works fine, but when I use it with another package, https://github.com/ufoym/imbalanced-dataset-sampler, it throws an error.

Can you suggest a workaround for using your code snippet with that package?

@DonghunP

Could you make a version like the one below?

class ImageFolderWithPaths(Dataset):
...
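
A rough sketch of what a plain Dataset-based version might look like follows. It is not the gist author's code: it assumes the same root/class_name/image layout that ImageFolder uses, loads images with PIL, and the extension list and attribute names (classes, class_to_idx, samples) are chosen here only to mirror ImageFolder.

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

# Assumed list of image extensions; adjust to your data.
IMG_EXTENSIONS = {".jpg", ".jpeg", ".png", ".ppm", ".bmp",
                  ".pgm", ".tif", ".tiff", ".webp"}


class ImageFolderWithPaths(Dataset):
    """Map-style dataset returning (image, label, path) from root/<class>/<file>."""

    def __init__(self, root, transform=None):
        self.root = Path(root)
        self.transform = transform
        # one class per subfolder of root, indexed in sorted order
        self.classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}
        # list of (path, class_index) pairs
        self.samples = [
            (str(p), self.class_to_idx[name])
            for name in self.classes
            for p in sorted((self.root / name).rglob("*"))
            if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        with Image.open(path) as f:
            image = f.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label, path
```

As with the subclass above, a transform that ends in a tensor conversion is needed before the default DataLoader collate function can batch the images.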

@ramcandrews

Thanks!

@flydragon2018

Which version of torch is this for? It does not work with my torch 1.3.1:
torch.utils.data.DataLoader()
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

@abhigarg

abhigarg commented Feb 29, 2020

import torchvision
from torch.utils.data import DataLoader

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    def __getitem__(self, index):
        # (image, label) from ImageFolder, plus the file path
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        path = self.imgs[index][0]
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# data_dir and test_transformer are defined elsewhere
dataset = ImageFolderWithPaths(root=data_dir, transform=test_transformer)
data_set_data_dir = DataLoader(dataset=dataset)

for i, data in enumerate(data_set_data_dir):
    images, labels, paths = data
    print(paths[0])
    break

It works for me this way.

@tehreemnaqvi

tehreemnaqvi commented Jun 22, 2020

import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision.datasets import ImageFolder

##root = "/Users/Tehreem/Desktop/CV LAB DATA/SNN important codes/DSNN-master/DSNN-master/new_images"
# `transform` and `sample_idx` are defined elsewhere
dataset = ImageFolder("/Users/Tehreem/Desktop/SpykeTorch-master/new_images", transform)  # adding transform to the dataset
plt.style.use('seaborn-white')
plt_idx = 0
sw = dataset[sample_idx][0]
for f in range(4):
    for t in range(5):
        plt_idx += 1
        ax = plt.subplot(5, 5, plt_idx)
        plt.setp(ax, xticklabels=[])
        plt.setp(ax, yticklabels=[])
        if t == 0:
            ax.set_ylabel('Feature ' + str(f))
        plt.imshow(sw[t, f].numpy(), cmap='gray')
        if f == 3:
            ax = plt.subplot(5, 5, plt_idx + 5)
            plt.setp(ax, xticklabels=[])
            plt.setp(ax, yticklabels=[])
            if t == 0:
                ax.set_ylabel('Sum')
            ax.set_xlabel('t = ' + str(t))
            plt.imshow(sw[t].sum(dim=0).numpy(), cmap='gray')
plt.show()

I got an error with this code. Can anybody tell me what the issue is?
Thanks

@a7906375

I am getting this error message:

Traceback (most recent call last):
  File "file_location.py", line 22, in <module>
    dataset = ImageFolderWithPaths(data_dir) # our custom dataset
  File "/Users/nubstech/opt/anaconda3/envs/Cells_Counting/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 206, in __init__
    is_valid_file=is_valid_file)
  File "/Users/nubstech/opt/anaconda3/envs/Cells_Counting/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 98, in __init__
    "Supported extensions are: " + ",".join(extensions)))
RuntimeError: Found 0 files in subfolders of: Eddata/Healthy_curated
Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp

@tehreemnaqvi

I got the same error.

@a7906375

The issue is related to the images folder location, but I'm unable to figure it out.
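
One way to inspect the folder layout behind a "Found 0 files" error is sketched below. It assumes the usual ImageFolder convention of one subfolder per class under the root (root/class_name/image.png), where images sitting directly in the root are not picked up. The helper names summarize_image_root and loose_images are made up for this illustration, and the code uses only the standard library.

```python
from pathlib import Path

# Extensions taken from the error message above.
IMG_EXTENSIONS = {".jpg", ".jpeg", ".png", ".ppm", ".bmp",
                  ".pgm", ".tif", ".tiff", ".webp"}


def summarize_image_root(root):
    """Return {class_folder_name: image_count} for each subfolder of root."""
    counts = {}
    for entry in sorted(Path(root).iterdir()):
        if entry.is_dir():
            counts[entry.name] = sum(
                1 for p in entry.rglob("*")
                if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
            )
    return counts


def loose_images(root):
    """Return names of images placed directly in root, outside any class folder."""
    return sorted(
        p.name for p in Path(root).iterdir()
        if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
    )
```

If summarize_image_root(data_dir) comes back empty while loose_images(data_dir) lists files, the images probably need to be moved into per-class subfolders, or data_dir should point one level higher.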

@soumendra

@flydragon2018 you need to add ToTensor() to your augmentation pipeline.

@soumendra

@a7906375 @tehreemnaqvi if your data_dir is a pathlib.Path, you need to apply str() before passing it to ImageFolderWithPaths

@soumendra

Here is a concise version that I can confirm works:

from torchvision.datasets import ImageFolder

class ImageFolderWithPaths(ImageFolder):
    def __getitem__(self, index):
        # (image, label) from ImageFolder, plus the file path
        return super(ImageFolderWithPaths, self).__getitem__(index) + (self.imgs[index][0],)

@RizwanShaukat936

import torch
from torchvision import datasets, transforms

# resize every image and convert it to a tensor so it can be batched
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])


class ImageFolderWithPaths(datasets.ImageFolder):

    def __getitem__(self, index):
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path


data_dir = "./sig_datasets/"
dataset = ImageFolderWithPaths(data_dir, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for i, data in enumerate(dataloader):
    images, labels, paths = data
    print(images)
    break

This code worked for me.

@jshtok

jshtok commented Mar 19, 2021

Works out of the box. Thanks!

@kimseunghyuck

You are my hero! thank you!

@eformx

eformx commented Dec 16, 2021

How would I modify this to select only files that match a wildcard? For example, if I wanted to isolate all image files that start with vid_1234.
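
One possible approach, sketched under assumptions: ImageFolder accepts an is_valid_file callable (it appears in the traceback earlier in this thread), which is called with each file path and decides whether the file is kept. The helper make_name_filter below is a made-up name for this illustration. Because is_valid_file takes the place of the built-in extension check, the predicate also checks the extension itself.

```python
import fnmatch
import os

# Assumed list of image extensions; adjust to your data.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp",
                  ".pgm", ".tif", ".tiff", ".webp")


def make_name_filter(pattern):
    """Build a predicate that keeps image files whose name matches a shell-style pattern."""
    def is_valid_file(path):
        name = os.path.basename(path)
        return name.lower().endswith(IMG_EXTENSIONS) and fnmatch.fnmatch(name, pattern)
    return is_valid_file


keep_vid_1234 = make_name_filter("vid_1234*")

# Hypothetical usage with the class from this gist (data_dir and transform are yours):
# dataset = ImageFolderWithPaths(data_dir, transform=transform,
#                                is_valid_file=keep_vid_1234)
```

Depending on the torchvision version, a class folder that ends up with no matching files may raise an error rather than being silently skipped, so it is worth checking the behavior on your install.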

@Lucylucy712

Wonderful! You saved my day!

@realliyifei

Thanks; hard to imagine that ImageFolder doesn't have this function / flag

@andrea137

May I ask under what license this snippet is released?
