
@andrewjong
Last active February 27, 2024 09:24
PyTorch Image File Paths With Dataset Dataloader
import torch
from torchvision import datasets, transforms

class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method that dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns: (image, label)
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# EXAMPLE USAGE:
# instantiate the dataset and dataloader
data_dir = "your/data_dir/here"
# ToTensor() turns the PIL images into tensors so the DataLoader can batch them
dataset = ImageFolderWithPaths(data_dir, transform=transforms.ToTensor())  # our custom dataset
dataloader = torch.utils.data.DataLoader(dataset)
# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
@shanej199127

Easily understandable.

@tcrapse

tcrapse commented May 29, 2019

Excellent contribution. Thank you.

@junkwhinger

Thank you

@jkanti

jkanti commented Aug 13, 2019

Hi,
This code normally works fine, but when I use it together with another package, https://github.com/ufoym/imbalanced-dataset-sampler, it throws an error.

Can you suggest a workaround so that your code snippet works with that package?
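One possible workaround, which swaps that package for PyTorch's built-in `torch.utils.data.WeightedRandomSampler` rather than fixing the incompatibility itself, is sketched below. The sampler only needs one weight per sample; for an ImageFolder (and therefore ImageFolderWithPaths) the labels are available as `dataset.targets`, so each sample can be weighted by the inverse frequency of its class. The helper name `make_balanced_sampler` is hypothetical.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(targets):
    """Build a sampler that draws each class roughly equally often.

    targets: list of integer class labels, e.g. dataset.targets for an ImageFolder.
    """
    counts = Counter(targets)
    # inverse class frequency: samples from rare classes get larger weights
    weights = torch.tensor([1.0 / counts[t] for t in targets], dtype=torch.double)
    # replacement=True lets rare-class samples be drawn more than once per epoch
    return WeightedRandomSampler(weights, num_samples=len(targets), replacement=True)


# Example with an imbalanced label list: 8 samples of class 0, 2 of class 1
targets = [0] * 8 + [1] * 2
sampler = make_balanced_sampler(targets)
drawn = list(sampler)  # 10 indices, drawn with replacement
```

With a real dataset this would be used as `torch.utils.data.DataLoader(dataset, sampler=make_balanced_sampler(dataset.targets))`; the loop still yields `(inputs, labels, paths)` because the sampler only chooses indices. Note that `sampler` cannot be combined with `shuffle=True`.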

@DonghoonPark12

Could you make a version like the one below?

class ImageFolderWithPaths(Dataset):
...
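A version built on the plain `torch.utils.data.Dataset` base class instead of ImageFolder could look like the sketch below. It is an assumption about what such a class might do: it collects image files from a single flat directory (no class subfolders, hence no labels), loads them with Pillow and returns `(image, path)`. The class name `ImageFilesWithPaths` is hypothetical.

```python
import os
import tempfile

from PIL import Image
from torch.utils.data import Dataset

# accepted extensions; this mirrors torchvision's own IMG_EXTENSIONS list
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


class ImageFilesWithPaths(Dataset):
    """Map-style dataset over the image files in one flat directory; yields (image, path)."""

    def __init__(self, root, transform=None):
        # sorted() gives a deterministic index -> file mapping
        self.paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        # open via a file handle so it gets closed; convert("RGB") forces 3 channels
        with open(path, "rb") as f:
            image = Image.open(f).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, path


# Usage: a throwaway directory with two images and one non-image file
demo_dir = tempfile.mkdtemp()
Image.new("RGB", (4, 4), "red").save(os.path.join(demo_dir, "b.png"))
Image.new("L", (4, 4)).save(os.path.join(demo_dir, "a.jpg"))
open(os.path.join(demo_dir, "notes.txt"), "w").close()

dataset = ImageFilesWithPaths(demo_dir)
image, path = dataset[0]  # the first file in sorted order, a.jpg
```

To batch it with a DataLoader, pass a transform that ends in `ToTensor()` (plus a `Resize` if the images differ in size).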

@ramcandrews

Thanks!

@flydragon2018

Which version of torch is this for? It doesn't work with my torch 1.3.1:
torch.utils.data.DataLoader()
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

@abhigarg

abhigarg commented Feb 29, 2020

import torchvision
from torch.utils.data import DataLoader

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    def __getitem__(self, index):
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        path = self.imgs[index][0]
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# data_dir and test_transformer (a transform that outputs tensors) are defined earlier
dataset = ImageFolderWithPaths(root=data_dir, transform=test_transformer)
data_set_data_dir = DataLoader(dataset=dataset)

for i, data in enumerate(data_set_data_dir):
    images, labels, paths = data
    print(paths[0])
    break

It works for me this way.

@tehreemnaqvi

tehreemnaqvi commented Jun 22, 2020

import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision.datasets import ImageFolder

##root = "/Users/Tehreem/Desktop/CV LAB DATA/SNN important codes/DSNN-master/DSNN-master/new_images"
# transform and sample_idx are defined earlier in the script
dataset = ImageFolder("/Users/Tehreem/Desktop/SpykeTorch-master/new_images", transform)  # adding transform to the dataset
plt.style.use('seaborn-white')  # this style is named 'seaborn-v0_8-white' in matplotlib >= 3.6
plt_idx = 0
sw = dataset[sample_idx][0]
for f in range(4):
    for t in range(5):
        plt_idx += 1
        ax = plt.subplot(5, 5, plt_idx)
        plt.setp(ax, xticklabels=[])
        plt.setp(ax, yticklabels=[])
        if t == 0:
            ax.set_ylabel('Feature ' + str(f))
        plt.imshow(sw[t, f].numpy(), cmap='gray')
        if f == 3:
            ax = plt.subplot(5, 5, plt_idx + 5)
            plt.setp(ax, xticklabels=[])
            plt.setp(ax, yticklabels=[])
            if t == 0:
                ax.set_ylabel('Sum')
            ax.set_xlabel('t = ' + str(t))
            plt.imshow(sw[t].sum(dim=0).numpy(), cmap='gray')
plt.show()

I got this error. Can anybody tell me what the issue is?
Thanks

@a7906375

I am getting this error message:

Traceback (most recent call last):
  File "file_location.py", line 22, in <module>
    dataset = ImageFolderWithPaths(data_dir) # our custom dataset
  File "/Users/nubstech/opt/anaconda3/envs/Cells_Counting/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 206, in __init__
    is_valid_file=is_valid_file)
  File "/Users/nubstech/opt/anaconda3/envs/Cells_Counting/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 98, in __init__
    "Supported extensions are: " + ",".join(extensions)))
RuntimeError: Found 0 files in subfolders of: Eddata/Healthy_curated
Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp

@tehreemnaqvi

I got the same error.

@a7906375

The issue is related to the image folder location, but I'm unable to figure it out.
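ImageFolder expects the layout `root/<class_name>/<image file>`: one subfolder per class, with the images inside those subfolders. If `data_dir` points at a folder that holds the images directly (for example, if `Eddata/Healthy_curated` itself contains the image files), this error is a likely result, and pointing `data_dir` at the parent folder so that `Healthy_curated` becomes a class folder may fix it. The helper below is a hypothetical, standard-library-only sketch that reports how many supported files each class subfolder contains; the extension list is copied from the error message above.

```python
import os
import tempfile

# the extensions listed in the RuntimeError above
SUPPORTED = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def count_images_per_class(root):
    """Return {class_folder_name: number of supported image files} for an ImageFolder root.

    Files lying directly in root are skipped, because ImageFolder only looks
    inside the class subfolders (and their nested subfolders).
    """
    with os.scandir(root) as it:
        class_dirs = sorted((e for e in it if e.is_dir()), key=lambda e: e.name)
    counts = {}
    for entry in class_dirs:
        n = 0
        for _dirpath, _dirnames, filenames in os.walk(entry.path):
            n += sum(name.lower().endswith(SUPPORTED) for name in filenames)
        counts[entry.name] = n
    return counts


# Usage on two hypothetical layouts
flat_root = tempfile.mkdtemp()  # image placed directly in root: not found
open(os.path.join(flat_root, "img1.png"), "wb").close()

nested_root = tempfile.mkdtemp()  # image inside a class folder: found
os.makedirs(os.path.join(nested_root, "Healthy_curated"))
open(os.path.join(nested_root, "Healthy_curated", "img1.png"), "wb").close()

flat_counts = count_images_per_class(flat_root)
nested_counts = count_images_per_class(nested_root)
```

An empty result means there are no class folders at all; a count of 0 means a class folder without any supported images.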

@soumendra

@flydragon2018 You need to add ToTensor() to your augmentation pipeline so that the DataLoader receives tensors instead of PIL images.

@soumendra

@a7906375 @tehreemnaqvi if your data_dir is a pathlib.Path, you need to apply str() before passing it to ImageFolderWithPaths

@soumendra

Here is a concise version that I can confirm works:

from torchvision.datasets import ImageFolder

class ImageFolderWithPaths(ImageFolder):
    def __getitem__(self, index):
        return super(ImageFolderWithPaths, self).__getitem__(index) + (self.imgs[index][0],)

@RizwanShaukat936

import torch
from torchvision import datasets, transforms

# named "transform" so it does not shadow the torchvision.transforms module
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

class ImageFolderWithPaths(datasets.ImageFolder):

    def __getitem__(self, index):
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

data_dir = "./sig_datasets/"
dataset = ImageFolderWithPaths(data_dir, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for i, data in enumerate(dataloader):
    images, labels, paths = data
    print(images)
    break

This code worked for me.

@jshtok

jshtok commented Mar 19, 2021

Works out of the box. Thanks!

@kimseunghyuck

You are my hero! thank you!

@eformx

eformx commented Dec 16, 2021

How would I modify this to select files with a wildcard? For example, if I wanted to select only the image files whose names start with vid_1234.
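ImageFolder (and so ImageFolderWithPaths) accepts an `is_valid_file` callable that receives each file path and decides whether to keep it. Below is a sketch that uses the standard-library `fnmatch` module to keep only files whose names match a wildcard such as `vid_1234*`. When `is_valid_file` is given, torchvision no longer applies its own extension check, so the predicate checks the extension itself. The helper name `make_name_filter` is hypothetical.

```python
import fnmatch
import os

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def make_name_filter(pattern):
    """Return a predicate for ImageFolder's is_valid_file argument.

    The predicate keeps a path only if its file name matches the shell-style
    wildcard `pattern` (case-sensitively) and has an image extension.
    """
    def is_valid_file(path):
        name = os.path.basename(path)
        return fnmatch.fnmatchcase(name, pattern) and name.lower().endswith(IMAGE_EXTENSIONS)

    return is_valid_file


keep = make_name_filter("vid_1234*")
```

It would be passed as `ImageFolderWithPaths(data_dir, transform=transforms.ToTensor(), is_valid_file=make_name_filter("vid_1234*"))`. Recent torchvision versions may raise an error if a class folder ends up with no matching files.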

@Lucylucy712

Wonderful! You saved my day!

@realliyifei

Thanks! Hard to imagine that ImageFolder doesn't have this as a built-in function or flag.

@andrea137

May I ask under what license this snippet is released?

@KennyKang7012

KennyKang7012 commented Feb 27, 2024

import torch
import torchvision
from torchvision import datasets, transforms

# named "transform" so it does not shadow the torchvision.transforms module
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """
    # override the __getitem__ method. this is the method that dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# EXAMPLE USAGE:
# instantiate the dataset and dataloader
data_dir = './dog_vs_cat/train/'
dataset = ImageFolderWithPaths(data_dir, transform=transform)  # our custom dataset
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)

This code worked for me.
