Skip to content

Instantly share code, notes, and snippets.

Created November 3, 2016 15:33
Show Gist options
  • Save jcjohnson/97c9f9d73c66ae87174a14e7ea7198fb to your computer and use it in GitHub Desktop.
Save jcjohnson/97c9f9d73c66ae87174a14e7ea7198fb to your computer and use it in GitHub Desktop.
import argparse, os
import numpy as np
from scipy.misc import imread, imresize
from skimage.filters import gaussian
import h5py
# Command-line options. Parsed at import time into the module-level ``args``,
# which handle_split() and the __main__ block read directly.
parser = argparse.ArgumentParser(
    description='Cut square image patches, make blurred and downsampled '
                'copies of them, and store both in an HDF5 file.')
parser.add_argument('--train_dir', default='data/yang-91',
                    help='directory of training images (used when '
                         '--train_list is not given)')
parser.add_argument('--val_dir', default='data/set5',
                    help='directory of validation images (used when '
                         '--val_list is not given)')
parser.add_argument('--max_train', default=-1, type=int,
                    help='use at most this many training images; <= 0 uses all')
parser.add_argument('--max_val', default=-1, type=int,
                    help='use at most this many validation images; <= 0 uses all')
parser.add_argument('--train_list', default=None,
                    help='text file with one training image path per line')
parser.add_argument('--val_list', default=None,
                    help='text file with one validation image path per line')
parser.add_argument('--output_h5', default='data/yang-91.h5',
                    help='HDF5 file to write (overwritten if it exists)')
parser.add_argument('--patch_size', default=96, type=int,
                    help='side length in pixels of the square full-size patches')
parser.add_argument('--patch_stride', default=7, type=int,
                    help='stride in pixels between training patches '
                         '(validation patches use --patch_size as stride)')
parser.add_argument('--sizes', default='2,3,4,8,16',
                    help='comma-separated downsampling factors; one x_<factor> '
                         'dataset is written per factor')
parser.add_argument('--sigma', default=1.0, type=float,
                    help='sigma of the Gaussian blur applied before downsampling')
args = parser.parse_args()
def _blur_to_uint8(patch, sigma):
    """Gaussian-blur ``patch`` and return the result as a uint8 image.

    ``gaussian`` returns a float image (skimage converts integer input to
    floats in [0, 1]). ``scipy.misc.imresize`` would min-max stretch a float
    array to the full 0-255 range, changing each patch's contrast, so the
    blurred patch is mapped back to uint8 before it is resized.
    NOTE(review): assumes ``patch`` is 8-bit -- confirm for other bit depths.
    """
    blurred = gaussian(patch, sigma)
    return np.clip(np.round(blurred * 255.0), 0, 255).astype(np.uint8)


def handle_split(split, file_list, h5_file):
    """Extract image patches from ``file_list`` and store them in ``h5_file``.

    Square patches of ``args.patch_size`` pixels are cut from every image on a
    regular grid. The stride is ``args.patch_stride``, or ``args.patch_size``
    (non-overlapping) when ``split == 'val'``. Each patch is blurred with a
    Gaussian of ``args.sigma`` and resized to ``args.patch_size // s`` for
    every factor ``s`` in ``args.sizes``.

    All arrays are shuffled with one shared permutation, so row ``i`` of every
    dataset comes from the same patch. The datasets written are:
      * ``<split>/y``     -- shape (N, 3, patch_size, patch_size)
      * ``<split>/x_<s>`` -- shape (N, 3, patch_size // s, patch_size // s)

    Args:
        split: HDF5 group name; ``'val'`` selects the non-overlapping stride.
        file_list: paths of the images to read.
        h5_file: open, writable ``h5py.File``.

    Raises:
        ValueError: if no patch could be extracted from ``file_list``.
    """
    # All patches are held in memory before writing; this should be easy to fit.
    # Duplicate factors are dropped so each x_<s> array lines up with y.
    sizes = sorted(set(int(x) for x in args.sizes.split(',')))
    # Floor division keeps the resize target integral on Python 3 as well.
    size_tuples = [(args.patch_size // s, args.patch_size // s) for s in sizes]
    patches = []
    small_patches = {s: [] for s in sizes}

    # For validation images, use stride = patch size to reduce size
    stride = args.patch_size if split == 'val' else args.patch_stride

    # Extract patches from all images
    for i, in_path in enumerate(file_list):
        print('Starting image %d / %d' % (i + 1, len(file_list)))
        img = imread(in_path)
        if img.ndim == 0:
            continue  # no pixel array was decoded; skip this file
        if img.ndim == 2:
            img = img[:, :, None][:, :, [0, 0, 0]]  # grayscale -> 3 channels
        elif img.ndim == 3 and img.shape[2] > 3:
            img = img[:, :, :3]  # drop alpha (e.g. RGBA) to keep 3 channels
        H, W = img.shape[0], img.shape[1]
        # "+ 1" keeps the patch that ends exactly on the right/bottom edge.
        for x0 in range(0, W - args.patch_size + 1, stride):
            x1 = x0 + args.patch_size
            for y0 in range(0, H - args.patch_size + 1, stride):
                y1 = y0 + args.patch_size
                patch = img[y0:y1, x0:x1]
                assert patch.shape == (args.patch_size, args.patch_size, 3), patch.shape
                # Leading axis so np.concatenate stacks patches along N.
                patches.append(patch[None])
                # The blur does not depend on the target size; do it once.
                blurred = _blur_to_uint8(patch, args.sigma)
                for size, size_tuple in zip(sizes, size_tuples):
                    small_patches[size].append(imresize(blurred, size_tuple)[None])

    if not patches:
        raise ValueError('no %s patches of size %d could be extracted'
                         % (split, args.patch_size))

    # Shuffle and concatenate all patches into (N, C, H, W) numpy arrays,
    # applying the same permutation to every resolution so rows stay paired.
    patches = np.concatenate(patches, axis=0).transpose(0, 3, 1, 2)
    order = np.random.permutation(patches.shape[0])
    patches = patches[order]
    small_patches = {
        k: np.concatenate(v, axis=0).transpose(0, 3, 1, 2)[order]
        for k, v in small_patches.items()
    }

    # Write patches to an HDF5 file
    print(patches.shape)
    h5_file.create_dataset('%s/y' % split, data=patches)
    for k, v in sorted(small_patches.items()):
        print(v.shape)
        h5_file.create_dataset('%s/x_%d' % (split, k), data=v)
def get_file_list(image_dir, image_list, max_files):
    """Return the image paths to process for one split.

    Args:
        image_dir: directory whose regular files are used when ``image_list``
            is None.
        image_list: optional path to a text file with one image path per line
            (surrounding whitespace stripped, blank lines ignored). When
            given, ``image_dir`` is not read.
        max_files: keep only the first ``max_files`` paths; values <= 0 keep
            every path.

    Returns:
        A list of path strings. Directory listings are sorted so that
        truncation by ``max_files`` selects a reproducible subset.
    """
    if image_list is None:
        # os.listdir order is arbitrary; sort so --max_* picks a stable subset.
        # Subdirectories are skipped since they cannot be read as images.
        file_list = sorted(
            os.path.join(image_dir, fn) for fn in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, fn)))
    else:
        with open(image_list, 'r') as f:
            file_list = [line.strip() for line in f if line.strip()]
    if max_files > 0:
        file_list = file_list[:max_files]
    return file_list
if __name__ == '__main__':
    # Resolve both file lists before opening the output: h5py mode 'w'
    # truncates an existing file, so a bad --*_dir / --*_list path now fails
    # without clobbering a previously written HDF5 file.
    val_list = get_file_list(args.val_dir, args.val_list, args.max_val)
    train_list = get_file_list(args.train_dir, args.train_list, args.max_train)
    with h5py.File(args.output_h5, 'w') as h5_file:
        handle_split('val', val_list, h5_file)
        handle_split('train', train_list, h5_file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment