Skip to content

Instantly share code, notes, and snippets.

@davidbau
Last active Jan 8, 2019
Embed
What would you like to do?
import os
import torch.utils.data as data
from torchvision.datasets.folder import default_loader, is_image_file
from PIL import Image
def grayscale_loader(path):
    """
    Open the image file at ``path`` and return it converted to PIL
    mode 'L' (single-channel grayscale).
    """
    with open(path, 'rb') as handle:
        image = Image.open(handle)
        # convert() is called while the file handle is still open.
        return image.convert('L')
class FeatureFolder(data.Dataset):
    """
    A dataset of (source, target) image pairs discovered by looking for
    parallel filenames under two directory trees, for example:

        photo/park/004234.jpg
        photo/park/004236.jpg
        photo/park/004237.jpg
        feature/park/004234.png
        feature/park/004236.png
        feature/park/004237.png

    The pairing itself is done by make_feature_dataset.  By default the
    source is read with default_loader and the target with
    grayscale_loader; each side can be given its own optional transform.
    """

    def __init__(self, source_root, target_root,
                 source_transform=None, target_transform=None,
                 source_loader=default_loader, target_loader=grayscale_loader):
        # Build the list of (source path, target path) pairs up front.
        self.imagepairs = make_feature_dataset(source_root, target_root)
        if not self.imagepairs:
            raise RuntimeError("Found 0 images within: %s" % source_root)
        self.root, self.target_root = source_root, target_root
        self.source_loader, self.target_loader = source_loader, target_loader
        self.source_transform = source_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the (source, target) pair at ``index``."""
        source_path, feature_path = self.imagepairs[index]
        # Both files are loaded before either transform is applied.
        source = self.source_loader(source_path)
        target = self.target_loader(feature_path)
        if self.source_transform is not None:
            source = self.source_transform(source)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return source, target

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.imagepairs)
class FeatureAndClassFolder(data.Dataset):
    """
    A dataset that pairs each source image with both a class index and a
    target feature image, by looking for parallel image filenames:

        photo/park/004234.jpg
        photo/park/004236.jpg
        photo/park/004237.jpg
        feature/park/004234.png
        feature/park/004236.png
        feature/park/004237.png

    Classes are the subdirectories of source_root (see find_classes);
    the (path, class index, target path) triples come from make_triples.
    Items are returned as ``(source, (classidx, target))``.
    """

    def __init__(self, source_root, target_root,
                 source_transform=None, target_transform=None,
                 source_loader=default_loader, target_loader=grayscale_loader):
        classes, class_to_idx = find_classes(source_root)
        self.imagetriples = make_triples(
            source_root, target_root, class_to_idx)
        if not self.imagetriples:
            raise RuntimeError("Found 0 images within: %s" % source_root)
        self.root, self.target_root = source_root, target_root
        self.classes, self.class_to_idx = classes, class_to_idx
        self.source_loader, self.target_loader = source_loader, target_loader
        self.source_transform = source_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(source, (classidx, target))`` for the item at ``index``."""
        source_path, classidx, feature_path = self.imagetriples[index]
        # Both files are loaded before either transform is applied.
        source = self.source_loader(source_path)
        target = self.target_loader(feature_path)
        if self.source_transform is not None:
            source = self.source_transform(source)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return source, (classidx, target)

    def __len__(self):
        """Return the number of triples."""
        return len(self.imagetriples)
class CachedImageFolder(data.Dataset):
    """
    A version of torchvision.dataset.ImageFolder that takes advantage
    of cached filename lists (see walk_image_files), for a layout like:

        photo/park/004234.jpg
        photo/park/004236.jpg
        photo/park/004237.jpg

    Items are returned as ``(image, classidx)``.
    """

    def __init__(self, root,
                 transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        self.imgs = make_class_dataset(root, class_to_idx)
        if not self.imgs:
            raise RuntimeError("Found 0 images within: %s" % root)
        self.root = root
        self.classes, self.class_to_idx = classes, class_to_idx
        self.transform, self.loader = transform, loader

    def __getitem__(self, index):
        """Return ``(image, classidx)`` for the item at ``index``."""
        image_path, classidx = self.imgs[index]
        image = self.loader(image_path)
        if self.transform is None:
            return image, classidx
        return self.transform(image), classidx

    def __len__(self):
        """Return the number of images."""
        return len(self.imgs)
class StackFeatureChannels(object):
    """
    Transform that re-views a tensor whose feature channels are tiled
    along dimension 1 as a (channels, height, width) tensor.

    If ``channels`` is given, each tile is ``shape[1] // channels`` rows
    tall; otherwise tiles are taken to be square (height equal to
    ``shape[2]``) and the channel count is derived from that.  If
    ``keep_only`` is truthy, only that many leading channels are kept.

    NOTE(review): presumably the input is a (1, channels * height, width)
    image tensor -- confirm against callers.
    """

    def __init__(self, channels=None, keep_only=None):
        self.channels = channels
        self.keep_only = keep_only

    def __call__(self, tensor):
        width = tensor.shape[2]
        stacked_rows = tensor.shape[1]
        if self.channels:
            count = self.channels
            rows = stacked_rows // count
        else:
            # No channel count supplied: treat each tile as square.
            rows = width
            count = stacked_rows // rows
        unstacked = tensor.view(count, rows, width)
        if not self.keep_only:
            return unstacked
        return unstacked[:self.keep_only, ...]
class SoftExpScale(object):
    """
    Transform computing ``exp(tensor * (255.0 / alpha)) - 1``.

    The input tensor is not modified: the multiplication allocates a new
    tensor, and the in-place exp/subtract operate on that result.
    """

    def __init__(self, alpha=45.0):
        self.scale = 255.0 / alpha

    def __call__(self, tensor):
        scaled = tensor * self.scale
        scaled.exp_()
        scaled.sub_(1)
        return scaled
def is_npy_file(path):
    """Return True if ``path`` ends with '.npy' or '.NPY'."""
    return path.endswith(('.npy', '.NPY'))
def walk_image_files(rootdir):
    """
    Return a list of image (or .npy) file paths for ``rootdir``.

    If a file named ``<rootdir>.txt`` exists, it is used as a cached file
    list instead of scanning the directory: each non-blank line is
    stripped and joined onto the parent directory of ``rootdir``.  The
    resulting list is sorted and then shuffled with a fixed seed, so the
    order is the same on every call.

    Otherwise the directory tree is walked and every file accepted by
    is_image_file or is_npy_file is returned, directory by directory in
    sorted order.
    """
    if os.path.isfile('%s.txt' % rootdir):
        print('Loading file list from %s.txt instead of scanning dir' % rootdir)
        basedir = os.path.dirname(rootdir)
        with open('%s.txt' % rootdir) as f:
            names = [line.strip() for line in f]
        # Blank lines (e.g. a trailing newline at the end of the list)
        # would otherwise become a bogus "basedir/" entry.
        result = sorted(os.path.join(basedir, name) for name in names if name)
        import random
        # Fixed seed: the shuffle is deterministic across runs.
        random.Random(1).shuffle(result)
        return result
    result = []
    for dirname, _, fnames in sorted(os.walk(rootdir)):
        for fname in sorted(fnames):
            if is_image_file(fname) or is_npy_file(fname):
                result.append(os.path.join(dirname, fname))
    return result
def find_classes(dir):
    """
    List the class subdirectories directly under ``dir``.

    A leading ``~`` in ``dir`` is expanded, matching the make_* helpers
    in this file, so the same root string works for both.

    Returns a tuple ``(classes, class_to_idx)``: the sorted list of
    subdirectory names, and a dict mapping each name to its position in
    that list.  Non-directory entries are ignored.
    """
    dir = os.path.expanduser(dir)
    classes = sorted(d for d in os.listdir(dir)
                     if os.path.isdir(os.path.join(dir, d)))
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
def make_feature_dataset(source_root, target_root):
    """
    Pair every image under source_root with its counterpart under
    target_root.

    Files are matched by their path relative to the respective root with
    the file extension removed, so the two trees must share directory
    structure and base filenames but may use different extensions.

    Returns a list of ``(source_path, target_path)`` tuples.  Raises
    RuntimeError if any source image has no matching target.
    """
    source_root = os.path.expanduser(source_root)
    target_root = os.path.expanduser(target_root)
    # Index the targets by extension-less relative path; if several
    # targets share a key, the last one listed wins.
    targets_by_key = {
        os.path.splitext(os.path.relpath(target, target_root))[0]: target
        for target in walk_image_files(target_root)
    }
    pairs = []
    for source in walk_image_files(source_root):
        relative = os.path.relpath(source, source_root)
        key = os.path.splitext(relative)[0]
        if key not in targets_by_key:
            raise RuntimeError('%s has no matching target %s.*' %
                               (source, os.path.join(target_root, key)))
        pairs.append((source, targets_by_key[key]))
    return pairs
def make_triples(source_root, target_root, class_to_idx):
    """
    Build ``(source_path, class_index, target_path)`` triples.

    Sources and targets are matched by their path relative to the
    respective root with the file extension removed.  The class name of
    a source is the name of its immediate parent directory, looked up in
    ``class_to_idx``.

    Raises RuntimeError if a source image has no matching target.
    """
    source_root = os.path.expanduser(source_root)
    target_root = os.path.expanduser(target_root)
    # Index the targets by extension-less relative path; if several
    # targets share a key, the last one listed wins.
    targets_by_key = {
        os.path.splitext(os.path.relpath(target, target_root))[0]: target
        for target in walk_image_files(target_root)
    }
    triples = []
    for source in walk_image_files(source_root):
        relative = os.path.relpath(source, source_root)
        key = os.path.splitext(relative)[0]
        if key not in targets_by_key:
            raise RuntimeError('%s has no matching target %s.*' %
                               (source, os.path.join(target_root, key)))
        parent_dir = os.path.dirname(key)
        class_index = class_to_idx[os.path.basename(parent_dir)]
        triples.append((source, class_index, targets_by_key[key]))
    return triples
def make_class_dataset(source_root, class_to_idx):
    """
    Build ``(source_path, class_index)`` pairs for every file listed by
    walk_image_files under source_root.

    The class name of a file is the name of its immediate parent
    directory, looked up in ``class_to_idx``.
    """
    source_root = os.path.expanduser(source_root)
    return [
        (path, class_to_idx[os.path.basename(os.path.dirname(path))])
        for path in walk_image_files(source_root)
    ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment