@justusschock
Last active May 8, 2018 14:07
Gist showing examples of writing your own DataLoader and ImageFolder in PyTorch
import torch.utils.data as data
from torchvision import transforms

from image_folder import ImageFolder
from utility_functions import image_load_fkt_pillow_unaligned
# NOTE: the module name of the paired-data classes is not given in the gist; "paired_data" is assumed
from paired_data import AlignedPairedData, UnalignedPairedData


class DataLoader(object):
    """Base class wrapping a torch DataLoader and the corresponding paired-data iterator"""

    def __init__(self, options, data_path, load_fkt, shuffle, return_paths):
        self.pairedData = None
        self.initialize(options, load_fkt, data_path, shuffle, return_paths)

    def initialize(self, options, load_fkt, data_path, shuffle, return_paths):
        pass

    def load_data(self):
        """
        Function to load one data pair
        :return: paired data
        """
        return self.pairedData

    @staticmethod
    def name():
        """
        Function to get the class name
        :return: class name (string)
        """
        return 'BaseDataLoader'

    def __len__(self):
        pass


class AlignedDataLoader(DataLoader):
    def __init__(self, options, data_path, load_fkt, shuffle=True, return_paths=True):
        self.dataset = None
        # the base-class constructor already calls self.initialize()
        super(AlignedDataLoader, self).__init__(options, data_path, load_fkt, shuffle, return_paths)

    def initialize(self, options, load_fkt, data_path, shuffle, return_paths):
        if options.inputNc == 1:
            norm = transforms.Normalize([0.5], [0.5])
        else:
            norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        transform = transforms.Compose([transforms.ToTensor(), norm])
        dataset = ImageFolder(data_path, options, loader=load_fkt, transform=transform,
                              return_paths=return_paths)
        data_loader = data.DataLoader(dataset, batch_size=options.batchSize, shuffle=shuffle, num_workers=0)
        self.dataset = dataset
        self.pairedData = AlignedPairedData(data_loader, return_paths)

    @staticmethod
    def name():
        return 'AlignedDataLoader'

    def __len__(self):
        return len(self.dataset)


class UnalignedDataLoader(DataLoader):
    """Class to handle the data loading process of an unaligned dataset"""

    def __init__(self, options, data_path, load_fkt=image_load_fkt_pillow_unaligned, shuffle=True, return_paths=True):
        """
        Function to create class variables
        :param options: class containing options (args of BaseOptions or subclass)
        :param data_path: path containing the dataset
        :param load_fkt: function to load the data
        :param shuffle: True if a random item of the dataset should be loaded, False otherwise
        :param return_paths: True if paths should be returned alongside the data, False otherwise
        """
        self.datasetA = None
        self.datasetB = None
        # the base-class constructor already calls self.initialize()
        super(UnalignedDataLoader, self).__init__(options, data_path, load_fkt, shuffle, return_paths)

    def initialize(self, options, load_fkt, data_path, shuffle, return_paths):
        """
        Function to initialize class variables
        :param options: class containing options (args of BaseOptions or subclass)
        :param data_path: path containing the dataset
        :param load_fkt: function to load the data
        :param shuffle: True if a random item of the dataset should be loaded, False otherwise
        :param return_paths: True if paths should be returned alongside the data, False otherwise
        :return: None
        """
        if options.inputNc == 1:
            norm = transforms.Normalize([0.5], [0.5])
        else:
            norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        transform = transforms.Compose([transforms.ToTensor(), norm])
        datasetA = ImageFolder(data_path + "/A", options, loader=load_fkt, transform=transform,
                               return_paths=return_paths)
        datasetB = ImageFolder(data_path + "/B", options, loader=load_fkt, transform=transform,
                               return_paths=return_paths)
        data_loader_a = data.DataLoader(dataset=datasetA, batch_size=options.batchSize, shuffle=shuffle, num_workers=0)
        data_loader_b = data.DataLoader(dataset=datasetB, batch_size=options.batchSize, shuffle=shuffle, num_workers=0)
        self.datasetA = datasetA
        self.datasetB = datasetB
        self.pairedData = UnalignedPairedData(data_loader_a, data_loader_b, return_paths=return_paths)

    @staticmethod
    def name():
        """
        Function to get the class name
        :return: class name (string)
        """
        return 'UnalignedDataLoader'

    def __len__(self):
        """
        Function to get the maximum number of items in the two datasets
        :return: number of items
        """
        return max(len(self.datasetA), len(self.datasetB))
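For reference, a minimal usage sketch of the loaders above (not part of the original gist). The options namespace, its attribute names inputNc, batchSize and imageSize, and the dataset path are illustrative stand-ins for whatever BaseOptions and your data actually provide:

from argparse import Namespace

# hypothetical options object; BaseOptions (not shown in this gist) would normally supply these fields
options = Namespace(inputNc=3, batchSize=4, imageSize=256)

# UnalignedDataLoader expects the two image sets in the subfolders <data_path>/A and <data_path>/B
loader = UnalignedDataLoader(options, "./datasets/example", shuffle=True, return_paths=True)
paired = loader.load_data()

for batch in paired:
    # each batch is a dict with one image batch per set and, optionally, the file paths
    images_a, images_b = batch['A'], batch['B']
    print(images_a.shape, images_b.shape, batch['A_Path'][0], batch['B_Path'][0])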
import numpy as np
import random
import torch.utils.data as data
from PIL import Image
import os
import os.path
from torchvision import transforms
import torch
from utility_functions import *
class ImageFolder(data.Dataset):
    """Class for handling the image loading process and transformations"""

    def __init__(self, image_path, options, transform=None, return_paths=True,
                 loader=default_loader_unaligned):
        """
        Function to create the dataset and initialize the class variables
        :param image_path: path containing the image files
        :param options: class containing all options (args of BaseOptions or subclass)
        :param transform: transformation to apply to the image after loading it
        :param return_paths: True if paths should be returned alongside the images, False if only images
        :param loader: function to load and resize images
        """
        imgs = make_dataset(image_path)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in: " + image_path + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = image_path
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader
        self.options = options

    def __getitem__(self, index):
        """
        Function to get a certain item of the dataset
        :param index: index into the dataset list
        :return: item of the dataset with the given index
        """
        path = self.imgs[index]
        img = self.loader(path, self.options.imageSize, self.options.inputNc)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        """Function to get the number of items in the dataset"""
        return len(self.imgs)
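ImageFolder can also be used on its own. A short sketch, again with a hypothetical options namespace and dataset path:

from argparse import Namespace
from torchvision import transforms
import torch.utils.data as data

options = Namespace(imageSize=128, inputNc=1)  # hypothetical options object
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

dataset = ImageFolder("./datasets/example/A", options,
                      transform=transform, return_paths=True,
                      loader=image_load_fkt_pillow_unaligned)
loader = data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

for images, paths in loader:
    print(images.shape)  # e.g. torch.Size([8, 1, 128, 128])
    break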
import numpy as np
import random
import torch.utils.data as data
from PIL import Image
import os
import os.path
from torchvision import transforms
import torch
from utility_functions import *
from image_folder import *
class PairedData(object):
    """Base class for iterating over paired data"""

    def __init__(self, return_paths):
        self.return_paths = return_paths
        self.iter = None

    def __iter__(self):
        pass

    def __next__(self):
        pass


class AlignedPairedData(PairedData):
    def __init__(self, data_loader, return_paths):
        super(AlignedPairedData, self).__init__(return_paths)
        self.data_loader = data_loader
        self.data_loader_iter = None
        self.stop = False

    def __iter__(self):
        self.data_loader_iter = iter(self.data_loader)
        self.iter = 0
        return self

    def __next__(self):
        if self.return_paths:
            data, data_path = next(self.data_loader_iter)
            self.iter += 1
            return {'A': data[0], 'B': data[1], 'A_Path': data_path[0], 'B_Path': data_path[1]}
        else:
            data = next(self.data_loader_iter)
            self.iter += 1
            return {'A': data[0], 'B': data[1]}


class UnalignedPairedData(PairedData):
    """Class to combine items of two datasets"""

    def __init__(self, data_loader_a, data_loader_b, return_paths=True):
        """Function to initialize and create class variables"""
        super(UnalignedPairedData, self).__init__(return_paths)
        self.dataLoaderA = data_loader_a
        self.dataLoaderB = data_loader_b
        self.dataLoaderAIter = None
        self.dataLoaderBIter = None
        self.stopA = False
        self.stopB = False

    def __iter__(self):
        """
        Function to start iterating through both datasets
        :return: self
        """
        self.stopA = False
        self.stopB = False
        self.dataLoaderAIter = iter(self.dataLoaderA)
        self.dataLoaderBIter = iter(self.dataLoaderB)
        self.iter = 0
        return self

    def __next__(self):
        """
        Function to get the next items of both datasets
        :return: dictionary containing the items
        """
        if self.return_paths:
            a, a_path = None, None
            b, b_path = None, None
            try:
                a, a_path = next(self.dataLoaderAIter)
            except StopIteration:
                # loader A is exhausted: remember this and restart it
                if a is None or a_path is None:
                    self.stopA = True
                    self.dataLoaderAIter = iter(self.dataLoaderA)
                    a, a_path = next(self.dataLoaderAIter)
            try:
                b, b_path = next(self.dataLoaderBIter)
            except StopIteration:
                # loader B is exhausted: remember this and restart it
                if b is None or b_path is None:
                    self.stopB = True
                    self.dataLoaderBIter = iter(self.dataLoaderB)
                    b, b_path = next(self.dataLoaderBIter)

            if self.stopA and self.stopB:
                # both loaders have been exhausted at least once: end the epoch
                self.stopA = False
                self.stopB = False
                raise StopIteration()
            else:
                self.iter += 1
                return {'A': a, 'B': b, 'A_Path': a_path, 'B_Path': b_path}
        else:
            a = None
            b = None
            try:
                a = next(self.dataLoaderAIter)
            except StopIteration:
                if a is None:
                    self.stopA = True
                    self.dataLoaderAIter = iter(self.dataLoaderA)
                    a = next(self.dataLoaderAIter)
            try:
                b = next(self.dataLoaderBIter)
            except StopIteration:
                if b is None:
                    self.stopB = True
                    self.dataLoaderBIter = iter(self.dataLoaderB)
                    b = next(self.dataLoaderBIter)

            if self.stopA and self.stopB:
                self.stopA = False
                self.stopB = False
                raise StopIteration()
            else:
                self.iter += 1
                return {'A': a, 'B': b}
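To illustrate how UnalignedPairedData restarts the shorter loader until the longer one is exhausted, here is a small sketch with two plain PyTorch DataLoaders over dummy tensors (sizes chosen arbitrarily for illustration):

import torch
import torch.utils.data as data

# two dummy datasets of different length; the shorter one is restarted until the longer one runs out
loader_a = data.DataLoader(data.TensorDataset(torch.zeros(10, 3)), batch_size=2)
loader_b = data.DataLoader(data.TensorDataset(torch.ones(4, 3)), batch_size=2)

paired = UnalignedPairedData(loader_a, loader_b, return_paths=False)
print(sum(1 for _ in paired))  # prints 5: one full pass over the longer loader A (5 batches), B is recycled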
import numpy as np
import random
import torch.utils.data as data
from PIL import Image
import os
import os.path
from torchvision import transforms
import torch
def image_load_fkt_pillow_unaligned(filename, desired_size, n_channels=1):
    """
    Function to load and resize images with one or three channels
    :param filename: name of the file the image should be loaded from
    :param desired_size: height and width the image should be resized to (currently only equal width and height are supported)
    :param n_channels: number of color channels
    :return: image as numpy array
    """
    img = Image.open(filename).convert('RGB')
    if n_channels == 1:
        img = img.convert('L')
    img = img.resize(size=(desired_size, desired_size), resample=Image.BILINEAR)
    img = np.array(img)
    img = np.reshape(img, (img.shape[0], img.shape[1], n_channels))
    return img
def image_load_fkt_pillow_aligned_label_x(filename, desired_size, n_channels=1):
    """
    Function to load an image together with a roughly aligned image from the opposite set,
    matched via the x coordinate stored in the corresponding label file
    :param filename: name of the file the image should be loaded from
    :param desired_size: height and width the image should be resized to (currently only equal width and height are supported)
    :param n_channels: number of color channels
    :return: list containing the image and the aligned image as numpy arrays
    """
    img = Image.open(filename).convert('RGB')
    if n_channels == 1:
        img = img.convert('L')
    img = img.resize(size=(desired_size, desired_size), resample=Image.BILINEAR)
    img = np.array(img)
    img = np.reshape(img, (img.shape[0], img.shape[1], n_channels))

    # the label file shares the image's base name, with the extension replaced by ".txt";
    # each line holds an "x,y" coordinate pair
    base_name = str(filename).rsplit(".", maxsplit=1)[0]
    label_file = base_name + ".txt"
    with open(label_file, "r") as f:
        label_lines = f.readlines()
    for line in label_lines:
        label_x = int(line.split(",")[0])
        label_y = int(line.split(",")[1])

    # images of the two sets are distinguished by an "a"/"b" suffix in the base name
    if base_name.endswith("a"):
        endstring = "b"
    else:
        endstring = "a"
    opposite_data_set = [x for x in os.listdir(os.path.split(filename)[0])
                         if is_image_file(x) and x.rsplit(".", maxsplit=1)[0].endswith(endstring)]
    possible_files = []
    for file in opposite_data_set:
        file_path = os.path.join(os.path.split(filename)[0], file)
        with open(file_path.rsplit(".", maxsplit=1)[0] + ".txt", "r") as f:
            opposite_label_lines = f.readlines()
        for line in opposite_label_lines:
            opposite_label_x = int(line.split(",")[0])
            opposite_label_y = int(line.split(",")[1])
            # consider files whose label x coordinate lies within 5 pixels of the current one
            if -5 <= (label_x - opposite_label_x) <= 5:
                possible_files.append(file_path)

    aligned_img = Image.open(random.choice(possible_files)).convert('RGB')
    if n_channels == 1:
        aligned_img = aligned_img.convert('L')
    aligned_img = aligned_img.resize(size=(desired_size, desired_size), resample=Image.BILINEAR)
    aligned_img = np.array(aligned_img)
    aligned_img = np.reshape(aligned_img, (aligned_img.shape[0], aligned_img.shape[1], n_channels))
    return [img, aligned_img]
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """
    Helper function to determine whether a file is an image file or not
    :param filename: the filename containing a possible image
    :return: True if the file is an image file, False otherwise
    """
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    """
    Helper function to make a dataset containing all images in a certain directory
    :param dir: the directory containing the dataset
    :return: images: list of image paths
    """
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images
def default_loader_unaligned(path, desired_size=None, n_channels=3):
    """
    Helper function to load an image with PIL
    (desired_size and n_channels are accepted only so the signature matches the other load
    functions called by ImageFolder; they are ignored here)
    :param path: path of the image file
    :return: loaded image in RGB mode as PIL Image
    """
    return Image.open(path).convert('RGB')
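Finally, a quick sketch of the helper functions in action, assuming a hypothetical ./datasets/example/A directory that contains images:

paths = make_dataset("./datasets/example/A")   # recursively collects all image paths
print(len(paths), paths[:2])

print(is_image_file("photo_001a.jpg"))         # True
print(is_image_file("labels_001a.txt"))        # False

img = image_load_fkt_pillow_unaligned(paths[0], desired_size=256, n_channels=3)
print(img.shape)                               # (256, 256, 3) numpy array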