import torch
from torchvision import datasets

class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method that dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes the original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# EXAMPLE USAGE:
# instantiate the dataset and dataloader
data_dir = "your/data_dir/here"
dataset = ImageFolderWithPaths(data_dir)  # our custom dataset
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
Hi,
This code normally works fine, but when I use it with another package, https://github.com/ufoym/imbalanced-dataset-sampler, it throws an error.
Can you suggest a workaround for making your snippet work with that package?
Can you make a version like the one below?
class ImageFolderWithPaths(Dataset):
...
Thanks!
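(If someone needs that shape, here is a minimal sketch of what a plain torch.utils.data.Dataset version could look like. This is not the gist author's code; the class name, the one-class-per-subfolder layout, and the default extension list are my assumptions.)

import os
from PIL import Image
from torch.utils.data import Dataset

class ImageFolderWithPathsDataset(Dataset):
    """Minimal Dataset that mirrors ImageFolder's (image, label) output
    and additionally returns the file path. Assumes one class per subfolder."""

    def __init__(self, root, transform=None,
                 extensions=(".jpg", ".jpeg", ".png", ".bmp")):
        self.transform = transform
        # class names are the subfolder names, sorted for a stable label order
        self.classes = sorted(
            d for d in os.listdir(root)
            if os.path.isdir(os.path.join(root, d)))
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        # collect (path, label) pairs for every image file in every class folder
        self.samples = []
        for c in self.classes:
            class_dir = os.path.join(root, c)
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(extensions):
                    self.samples.append(
                        (os.path.join(class_dir, fname), self.class_to_idx[c]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        image = Image.open(path).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label, path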
Which version of torch does this need? It is not working for me with torch 1.3.1:
torch.utils.data.DataLoader()
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    def __getitem__(self, index):
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        path = self.imgs[index][0]
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

dataset = ImageFolderWithPaths(root=data_dir, transform=test_transformer)
data_set_data_dir = DataLoader(dataset=dataset)
for i, data in enumerate(data_set_data_dir):
    images, labels, paths = data
    print(paths[0])
    break
It works for me this way.
import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision.datasets import ImageFolder

##root = "/Users/Tehreem/Desktop/CV LAB DATA/SNN important codes/DSNN-master/DSNN-master/new_images"
dataset = ImageFolder("/Users/Tehreem/Desktop/SpykeTorch-master/new_images", transform)  # adding transform to the dataset

plt.style.use('seaborn-white')
plt_idx = 0
sw = dataset[sample_idx][0]
for f in range(4):
    for t in range(5):
        plt_idx += 1
        ax = plt.subplot(5, 5, plt_idx)
        plt.setp(ax, xticklabels=[])
        plt.setp(ax, yticklabels=[])
        if t == 0:
            ax.set_ylabel('Feature ' + str(f))
        plt.imshow(sw[t, f].numpy(), cmap='gray')
        if f == 3:
            ax = plt.subplot(5, 5, plt_idx + 5)
            plt.setp(ax, xticklabels=[])
            plt.setp(ax, yticklabels=[])
            if t == 0:
                ax.set_ylabel('Sum')
            ax.set_xlabel('t = ' + str(t))
            plt.imshow(sw[t].sum(dim=0).numpy(), cmap='gray')
plt.show()
I got this error. Can anybody tell me what the issue is?
Thanks
I am getting this error message:
Traceback (most recent call last):
  File "file_location.py", line 22, in <module>
    dataset = ImageFolderWithPaths(data_dir) # our custom dataset
  File "/Users/nubstech/opt/anaconda3/envs/Cells_Counting/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 206, in __init__
    is_valid_file=is_valid_file)
  File "/Users/nubstech/opt/anaconda3/envs/Cells_Counting/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 98, in __init__
    "Supported extensions are: " + ",".join(extensions)))
RuntimeError: Found 0 files in subfolders of: Eddata/Healthy_curated
Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp
I got the same error.
The issue is related to the image folder location, but I'm unable to figure it out.
@flydragon2018 you need to add ToTensor() to your augmentation pipeline.
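(For reference, a minimal sketch of what that looks like, reusing the ImageFolderWithPaths class and data_dir from the gist above; the 64x64 resize and the batch size are arbitrary example choices, not something the gist prescribes.)

import torch
from torchvision import transforms

# ToTensor() converts the PIL image into a tensor, which is what the
# DataLoader's default_collate expects; without it you get the
# "batch must contain tensors ... found <class 'PIL.Image.Image'>" error above
transform = transforms.Compose([
    transforms.Resize((64, 64)),   # example size, pick whatever your model needs
    transforms.ToTensor(),
])

dataset = ImageFolderWithPaths(data_dir, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)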
@a7906375 @tehreemnaqvi if your data_dir is a pathlib.Path, you need to apply str() before passing it to ImageFolderWithPaths.
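(For example, something along these lines; the path is just the one from the traceback above, and the conversion matters mainly on older torchvision versions that do not accept Path objects directly.)

from pathlib import Path

data_dir = Path("Eddata/Healthy_curated")   # example path, as in the traceback above

# convert the Path to a plain string before handing it to the dataset
dataset = ImageFolderWithPaths(str(data_dir), transform=transform)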
Here is a concise version that I can confirm works
class ImageFolderWithPaths(ImageFolder):
    def __getitem__(self, index):
        return super(ImageFolderWithPaths, self).__getitem__(index) + (self.imgs[index][0],)
import torch
from torchvision import datasets, transforms

transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

class ImageFolderWithPaths(datasets.ImageFolder):
    def __getitem__(self, index):
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

data_dir = "./sig_datasets/"
dataset = ImageFolderWithPaths(data_dir, transform=transforms)
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for i, data in enumerate(dataloader):
    images, labels, paths = data
    print(images)
    break
This code worked for me.
Works out of the box. Thanks!
You are my hero! thank you!
How would I modify this to isolate files with a wildcard? For example, if I wanted to isolate all image files that start with vid_1234.
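(One way to do that, not part of the original gist, is ImageFolder's is_valid_file argument, which lets you accept or reject each candidate file by its path. A minimal sketch, where the vid_1234 prefix and the extension list are assumptions taken from the question, and data_dir/transform are the same as in the examples above.)

import os

def keep_vid_1234(path):
    """Accept only image files whose filename starts with 'vid_1234'."""
    name = os.path.basename(path)
    return name.startswith("vid_1234") and name.lower().endswith((".jpg", ".jpeg", ".png"))

# is_valid_file replaces ImageFolder's default extension check, so the
# callable has to validate the extension itself (done above)
dataset = ImageFolderWithPaths(data_dir, transform=transform,
                               is_valid_file=keep_vid_1234)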
Wonderful! You saved my day!
Thanks; hard to imagine that ImageFolder doesn't have this function / flag
May I ask under what license this snippet is released?
import torch
import torchvision
from torchvision import datasets, transforms

transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method that dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes the original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# EXAMPLE USAGE:
# instantiate the dataset and dataloader
data_dir = './dog_vs_cat/train/'
dataset = ImageFolderWithPaths(data_dir, transform=transforms)  # our custom dataset
dataloader = torch.utils.data.DataLoader(dataset)

# iterate over data
for inputs, labels, paths in dataloader:
    # use the above variables freely
    print(inputs, labels, paths)
This code worked for me.
Thank you