Skip to content

Instantly share code, notes, and snippets.

@oeway
Last active October 1, 2020 17:30
Show Gist options
  • Star 30 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save oeway/2e3b989e0343f0884388ed7ed82eb3b0 to your computer and use it in GitHub Desktop.
Save oeway/2e3b989e0343f0884388ed7ed82eb3b0 to your computer and use it in GitHub Desktop.
Improved image transform functions for dense predictions (for pytorch, keras etc.)
import collections
import collections.abc
import numbers

import numpy as np
import scipy
import scipy.ndimage
from PIL import Image
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates
# Module metadata.
__status__ = "Development"
__version__ = "0.1.0"
__license__ = "GPL"
__author__ = "Wei OUYANG"
def center_crop(x, center_crop_size):
    """Crop the spatial centre of a 3-D array whose first axis is left untouched.

    Args:
        x: 3-D array; axes 1 and 2 are cropped, axis 0 is kept whole
            (presumably channels, i.e. a C x H x W layout -- confirm with callers).
        center_crop_size: sequence whose first two items are the crop extents
            along axis 1 and axis 2.

    Returns:
        A view of ``x`` of shape (x.shape[0], center_crop_size[0], center_crop_size[1]).
        NOTE(review): crop sizes larger than ``x`` are not validated here.
    """
    assert x.ndim == 3
    crop_h, crop_w = center_crop_size[0], center_crop_size[1]
    top = x.shape[1] // 2 - crop_h // 2
    left = x.shape[2] // 2 - crop_w // 2
    # Slice as start + size rather than centre +/- half so that odd crop sizes are
    # not shortened by one row/column.
    return x[:, top:top + crop_h, left:left + crop_w]
def to_tensor(x):
    """Return ``x`` as a float32 torch tensor with its last axis moved to the front.

    For an (H x W x C) numpy image this produces the (C x H x W) layout.
    ``torch`` is imported inside the function, so only callers of this helper need it.
    """
    import torch

    channels_first = x.transpose((2, 0, 1))
    as_tensor = torch.from_numpy(channels_first)
    return as_tensor.float()
def random_num_generator(config, random_state=np.random):
    """Draw one random number described by ``config``.

    Args:
        config: sequence ``(name, a, b)``. ``name`` is ``'uniform'`` (samples
            ``random_state.uniform(a, b)``) or ``'lognormal'`` (samples
            ``random_state.lognormal(a, b)``).
        random_state: object providing ``uniform``/``lognormal`` in the
            ``numpy.random`` style; defaults to the global ``np.random`` module.

    Returns:
        A single sampled value (numpy scalar).

    Raises:
        ValueError: if ``config[0]`` names an unsupported distribution. The error
            message includes the offending config (previously it was printed to stdout).
    """
    distribution = config[0]
    if distribution == 'uniform':
        sampler = random_state.uniform
    elif distribution == 'lognormal':
        sampler = random_state.lognormal
    else:
        raise ValueError('unsupported format: %r' % (config,))
    # size=1 then [0] is kept so the random stream matches the original implementation.
    return sampler(config[1], config[2], 1)[0]
def poisson_downsampling(image, peak, random_state=np.random):
    """Simulate photon-limited (Poisson) noise at a given peak count.

    Each channel (everything left after reducing axes 0 and 1) is rescaled so its
    maximum equals ``peak``. A Poisson sample is then drawn using that rescaled image
    as the rate. The result stays in count units: it is NOT scaled back to the input range.

    Args:
        image: numpy array or array-like (e.g. a PIL image), (H x W) or (H x W x C).
        peak: target maximum rate per channel.
        random_state: ``numpy.random``-style object providing ``poisson``.

    Returns:
        float32 array with the same shape as ``image``. If every channel is all zero,
        a float32 copy of the input is returned unchanged.
    """
    # np.array always copies, so the caller's data is never aliased by the result.
    img = np.array(image, dtype='float32')
    scale = img.max(axis=(0, 1)) / peak
    if np.all(scale == 0):
        return img
    # An all-zero channel would otherwise divide 0/0 -> NaN, which poisson() rejects;
    # with a divisor of 1 its rate stays 0, so it samples to zeros.
    safe_scale = np.where(scale == 0, 1, scale)
    noisy = random_state.poisson(lam=img / safe_scale)
    return noisy.astype('float32')
def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random):
    """Elastic deformation of image as described in [Simard2003]_.

    Two displacement fields are drawn, one per spatial axis. Each is uniform noise in
    [-1, 1), smoothed by a Gaussian of width ``sigma`` (zero padding) and scaled by
    ``alpha``. Every channel of the (H x W x C) ``image`` is then resampled at the
    displaced coordinates with spline order ``spline_order`` and boundary ``mode``.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    """
    assert image.ndim == 3
    height, width = image.shape[:2]

    def _smoothed_noise():
        # Uniform noise in [-1, 1), Gaussian-smoothed with zero padding, then scaled.
        noise = random_state.rand(height, width) * 2 - 1
        return gaussian_filter(noise, sigma, mode="constant", cval=0) * alpha

    # Row displacement is drawn before column displacement (fixed RNG order).
    shift_rows = _smoothed_noise()
    shift_cols = _smoothed_noise()
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    coords = np.stack([rows + shift_rows, cols + shift_cols])  # shape (2, H, W)

    warped = np.empty_like(image)
    for channel in range(image.shape[2]):
        warped[..., channel] = map_coordinates(
            image[..., channel], coords, order=spline_order, mode=mode)
    return warped
class Merge(object):
    """Merge a group of numpy arrays into one by concatenating along ``axis``.

    All arrays must have identical shapes except along the merge axis.
    """

    def __init__(self, axis=-1):
        self.axis = axis

    def __call__(self, images):
        """Concatenate ``images`` (a sequence of arrays, or an ndarray) along ``self.axis``.

        Raises:
            TypeError: if ``images`` is neither a sequence nor an ndarray.
            AssertionError: if an element is not an ndarray or shapes disagree
                outside the merge axis.
        """
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        if not isinstance(images, (collections.abc.Sequence, np.ndarray)):
            raise TypeError("obj is not a sequence (list, tuple, etc)")
        assert all(isinstance(i, np.ndarray)
                   for i in images), 'only numpy array is supported'
        shapes = [list(i.shape) for i in images]
        for s in shapes:
            # Blank out the merge axis so only the remaining dimensions are compared.
            s[self.axis] = None
        assert all(s == shapes[0] for s in shapes
                   ), 'shapes must be the same except the merge axis'
        return np.concatenate(images, axis=self.axis)
class Split(object):
    """Split an array into individual arrays along one axis.

    Each positional argument is either a ``slice`` or a sequence of ``slice``
    constructor arguments; e.g. ``Split([0, 3], [3, 6])`` yields indices 0-2 and 3-5.
    The keyword ``axis`` (default -1) selects the axis being sliced.
    """

    def __init__(self, *slices, **kwargs):
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        slices_ = [slice(*s) if isinstance(s, collections.abc.Sequence) else s
                   for s in slices]
        assert all(isinstance(s, slice) for s in slices_
                   ), 'slices must be consist of slice instances'
        self.slices = slices_
        self.axis = kwargs.get('axis', -1)

    def __call__(self, image):
        """Return a list with one sub-array (a view) of ``image`` per configured slice.

        Raises:
            TypeError: if ``image`` is not a numpy array.
        """
        if not isinstance(image, np.ndarray):
            raise TypeError("obj is not an numpy array")
        parts = []
        for s in self.slices:
            index = [slice(None)] * image.ndim
            index[self.axis] = s
            # A multidimensional index must be a tuple; list indices were removed in NumPy 1.23.
            parts.append(image[tuple(index)])
        return parts
class ElasticTransform(object):
    """Apply elastic transformation on a numpy.ndarray (H x W x C).

    ``alpha`` and ``sigma`` are either numbers, or sequences passed to
    ``random_num_generator`` (e.g. ``('uniform', lo, hi)``) and re-sampled on every call.
    ``random_state`` (default: the global ``np.random``) drives both that sampling and the
    deformation field, matching the other random transforms in this module.
    """

    def __init__(self, alpha, sigma, random_state=np.random):
        self.alpha = alpha
        self.sigma = sigma
        self.random_state = random_state

    def __call__(self, image):
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        if isinstance(self.alpha, collections.abc.Sequence):
            alpha = random_num_generator(self.alpha, random_state=self.random_state)
        else:
            alpha = self.alpha
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        return elastic_transform(image, alpha=alpha, sigma=sigma,
                                 random_state=self.random_state)
class PoissonSubsampling(object):
    """Poisson subsampling on a numpy.ndarray (H x W x C).

    ``peak`` is a number, or a sequence passed to ``random_num_generator`` and
    re-sampled on every call. The work is delegated to ``poisson_downsampling``.
    """

    def __init__(self, peak, random_state=np.random):
        self.peak = peak
        self.random_state = random_state

    def __call__(self, image):
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        if isinstance(self.peak, collections.abc.Sequence):
            peak = random_num_generator(self.peak, random_state=self.random_state)
        else:
            peak = self.peak
        return poisson_downsampling(image, peak, random_state=self.random_state)
class AddGaussianNoise(object):
    """Add gaussian noise to a numpy.ndarray (H x W x C).

    ``mean`` and ``sigma`` are numbers, or sequences passed to ``random_num_generator``
    and re-sampled on every call from ``random_state``.
    The noise is added IN PLACE: the input array is modified and returned, so it must
    have a dtype that can hold the float result.
    """

    def __init__(self, mean, sigma, random_state=np.random):
        self.sigma = sigma
        self.mean = mean
        self.random_state = random_state

    def __call__(self, image):
        # sigma is drawn before mean, preserving the original order of random draws.
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        if isinstance(self.mean, collections.abc.Sequence):
            # Draw from self.random_state, like sigma, so results are reproducible.
            mean = random_num_generator(self.mean, random_state=self.random_state)
        else:
            mean = self.mean
        image += self.random_state.normal(mean, sigma, image.shape)
        return image
class AddSpeckleNoise(object):
    """Add speckle (multiplicative) noise to a numpy.ndarray (H x W x C): image += image * n.

    ``n`` is Gaussian with the given ``mean`` and ``sigma``. Each may be a number, or a
    sequence passed to ``random_num_generator`` and re-sampled on every call.
    The noise is applied IN PLACE: the input array is modified and returned.
    """

    def __init__(self, mean, sigma, random_state=np.random):
        self.sigma = sigma
        self.mean = mean
        self.random_state = random_state

    def __call__(self, image):
        # sigma is drawn before mean, preserving the original order of random draws.
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        if isinstance(self.mean, collections.abc.Sequence):
            mean = random_num_generator(self.mean, random_state=self.random_state)
        else:
            mean = self.mean
        speckle = self.random_state.normal(mean, sigma, image.shape)
        image += image * speckle
        return image
class GaussianBlurring(object):
    """Apply gaussian blur to a numpy.ndarray (H x W x C).

    Only the first two (spatial) axes are blurred; any further axes (e.g. channels)
    get sigma 0. ``sigma`` is a number, or a sequence passed to
    ``random_num_generator`` and re-sampled on every call.
    """

    def __init__(self, sigma, random_state=np.random):
        self.sigma = sigma
        self.random_state = random_state

    def __call__(self, image):
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        # (sigma, sigma, 0) for H x W x C input; also works for plain H x W arrays.
        sigmas = (sigma, sigma) + (0,) * (image.ndim - 2)
        return gaussian_filter(image, sigma=sigmas)
class AddGaussianPoissonNoise(object):
    """Add poisson noise with gaussian blurred image to a numpy.ndarray (H x W x C).

    The image is blurred spatially with ``sigma`` and Poisson-sampled at ``peak`` via
    ``poisson_downsampling``. The result is added to the original image and a new
    array is returned. ``sigma`` and ``peak`` are numbers, or sequences passed to
    ``random_num_generator`` and re-sampled on every call.
    """

    def __init__(self, sigma, peak, random_state=np.random):
        self.sigma = sigma
        self.peak = peak
        self.random_state = random_state

    def __call__(self, image):
        # sigma is drawn before peak, preserving the original order of random draws.
        # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        if isinstance(self.peak, collections.abc.Sequence):
            peak = random_num_generator(self.peak, random_state=self.random_state)
        else:
            peak = self.peak
        # Blur only the two spatial axes (sigma 0 for any trailing axes).
        sigmas = (sigma, sigma) + (0,) * (image.ndim - 2)
        background = gaussian_filter(image, sigma=sigmas)
        # poisson_downsampling returns Poisson counts scaled so each channel peaks near `peak`.
        background = poisson_downsampling(
            background, peak=peak, random_state=self.random_state)
        return image + background
class MaxScaleNumpy(object):
    """Min-max scale each channel of the numpy array into [range_min, range_max], i.e.
    channel = range_min + (channel - min) * (range_max - range_min) / (max - min)

    Statistics are taken over axes 0 and 1, so each trailing channel is scaled on its own.
    A constant channel (max == min) is mapped to range_min instead of producing NaN/inf.
    """

    def __init__(self, range_min=0.0, range_max=1.0):
        self.scale = (range_min, range_max)

    def __call__(self, image):
        lo, hi = self.scale
        mn = image.min(axis=(0, 1))
        span = image.max(axis=(0, 1)) - mn
        # Avoid 0/0 for constant channels; (image - mn) is 0 there, so they land on lo.
        span = np.where(span == 0, 1, span)
        return lo + (image - mn) * (hi - lo) / span
class MedianScaleNumpy(object):
    """Scale each channel of the numpy array using its minimum and median, i.e.
    channel = range_min + (channel - min) * (range_max - range_min) / (median - min)

    The minimum maps to range_min and the median maps to range_max; values above the
    median land above range_max. Statistics are taken over axes 0 and 1. If a channel's
    median equals its minimum, the divisor is replaced by 1 (the channel is only
    shifted), mirroring NormalizeNumpy's handling of a zero std.
    """

    def __init__(self, range_min=0.0, range_max=1.0):
        self.scale = (range_min, range_max)

    def __call__(self, image):
        lo, hi = self.scale
        mn = image.min(axis=(0, 1))
        span = np.median(image, axis=(0, 1)) - mn
        span = np.where(span == 0, 1, span)
        return lo + (image - mn) * (hi - lo) / span
class NormalizeNumpy(object):
    """Normalize each channel of the numpy array i.e.
    channel = (channel - mean) / std

    Statistics are taken over axes 0 and 1. Channels with zero std are only
    mean-centred. A new array is returned: the input is left untouched, and integer
    inputs (e.g. uint8 images) are supported, which in-place arithmetic could not do.
    """

    def __call__(self, image):
        centred = image - image.mean(axis=(0, 1))
        std = centred.std(axis=(0, 1))
        # np.where works for both per-channel arrays and the scalar std of a 2-D input.
        std = np.where(std == 0, 1.0, std)
        return centred / std
class MutualExclude(object):
    """Zero out ``from_channel`` wherever ``exclude_channel`` is positive.

    Works in place on an (H x W x C) array and returns that same array.
    """

    def __init__(self, exclude_channel, from_channel):
        self.from_channel = from_channel
        self.exclude_channel = exclude_channel

    def __call__(self, image):
        # Basic indexing with an int/slice yields a view, so writing into it
        # updates `image` itself.
        target = image[:, :, self.from_channel]
        covered = image[:, :, self.exclude_channel] > 0
        target[covered] = 0
        return image
class RandomCropNumpy(object):
    """Crops the given numpy array at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)

    Axis 0 is treated as height and axis 1 as width; trailing axes are kept whole.
    """

    def __init__(self, size, random_state=np.random):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.random_state = random_state

    def __call__(self, img):
        """Return a random (target_height x target_width) window of ``img`` (a view).

        Raises:
            ValueError: if the requested size exceeds the image size.
        """
        h, w = img.shape[:2]
        th, tw = self.size
        if h == th and w == tw:
            return img
        if th > h or tw > w:
            raise ValueError('crop size %r is larger than image size %r'
                             % ((th, tw), (h, w)))
        # randint's upper bound is exclusive: +1 makes the last valid offset reachable
        # and avoids randint(0, 0) when one dimension already matches.
        top = self.random_state.randint(0, h - th + 1)
        left = self.random_state.randint(0, w - tw + 1)
        return img[top:top + th, left:left + tw]
class CenterCropNumpy(object):
    """Crops the given numpy array at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)

    Axis 0 is treated as height and axis 1 as width; trailing axes are kept whole.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """Return the centred (target_height x target_width) window of ``img`` (a view).

        Raises:
            ValueError: if the requested size exceeds the image size (previously this
                produced a silently wrong slice from negative offsets).
        """
        h, w = img.shape[:2]
        th, tw = self.size
        if th > h or tw > w:
            raise ValueError('crop size %r is larger than image size %r'
                             % ((th, tw), (h, w)))
        top = int(round((h - th) / 2.))
        left = int(round((w - tw) / 2.))
        return img[top:top + th, left:left + tw]
class RandomRotate(object):
    """Rotate a PIL.Image or numpy.ndarray (H x W x C) by a random angle.

    The angle (degrees) is drawn uniformly from ``angle_range`` on every call.
    For arrays, the rotation happens in the plane of ``axes``. The output keeps the
    input shape (``reshape=False``) and out-of-bounds samples follow ``mode``. The
    result is clipped to the input's original [min, max], presumably to remove spline
    overshoot. PIL images are rotated with ``Image.rotate``.
    """

    def __init__(self, angle_range=(0.0, 360.0), axes=(0, 1), mode='reflect', random_state=np.random):
        assert isinstance(angle_range, tuple)
        self.angle_range = angle_range
        self.random_state = random_state
        self.axes = axes
        self.mode = mode

    def __call__(self, image):
        """Return the rotated image.

        Raises:
            TypeError: for inputs that are neither numpy arrays nor PIL images.
        """
        angle = self.random_state.uniform(
            self.angle_range[0], self.angle_range[1])
        if isinstance(image, np.ndarray):
            mi, ma = image.min(), image.max()
            # scipy.ndimage.interpolation is a deprecated alias namespace; use ndimage.rotate.
            rotated = scipy.ndimage.rotate(
                image, angle, reshape=False, axes=self.axes, mode=self.mode)
            return np.clip(rotated, mi, ma)
        elif isinstance(image, Image.Image):
            return image.rotate(angle)
        else:
            raise TypeError('unsupported type')
class BilinearResize(object):
    """Resize a PIL.Image or numpy.ndarray (H x W x C) by a spatial zoom factor,
    using bilinear interpolation for both input types.

    The two spatial axes are scaled by ``zoom``; the channel axis is left unchanged.
    """

    def __init__(self, zoom):
        self.zoom = [zoom, zoom, 1]

    def __call__(self, image):
        """Return the resized image.

        Raises:
            TypeError: for inputs that are neither numpy arrays nor PIL images.
        """
        if isinstance(image, np.ndarray):
            # order=1 is linear spline interpolation (bilinear over H x W); scipy's
            # default order=3 is cubic. Truncating the factors also supports 2-D arrays.
            return scipy.ndimage.zoom(image, self.zoom[:image.ndim], order=1)
        elif isinstance(image, Image.Image):
            # The original referenced a never-assigned self.size; derive the target
            # (width, height) from the zoom factor, rounded to the nearest pixel.
            width, height = image.size
            factor = self.zoom[0]
            target = (int(round(width * factor)), int(round(height * factor)))
            return image.resize(target, Image.BILINEAR)
        else:
            raise TypeError('unsupported type')
class EnhancedCompose(object):
    """Compose several transforms, with support for per-element transforms on groups.

    Args:
        transforms: list whose entries are applied in order. Each entry is one of:
            * a callable -- applied to the current value;
            * a sequence -- the current value must be a sequence of the same length;
              element i is replaced by ``t[i](element)`` when ``t[i]`` is callable and
              kept as-is otherwise (e.g. ``None``). The result is a list;
            * ``None`` -- skipped.

    Example::

        EnhancedCompose([
            Merge(),
            RandomCropNumpy(size=(512, 512)),
            Split([0, 3], [3, 6]),
            [NormalizeNumpy(), None],
        ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Apply all transforms to ``img`` and return the result.

        Raises:
            AssertionError: if a group transform does not match the group size.
            TypeError: if an entry is neither callable, a sequence, nor None.
        """
        for t in self.transforms:
            # collections.abc.Sequence: the bare collections.Sequence alias was removed in Python 3.10.
            if isinstance(t, collections.abc.Sequence):
                assert isinstance(img, collections.abc.Sequence) and len(img) == len(
                    t), "size of image group and transform group does not fit"
                img = [f(im) if callable(f) else im for f, im in zip(t, img)]
            elif callable(t):
                img = t(img)
            elif t is None:
                continue
            else:
                raise TypeError('unexpected type')
        return img
if __name__ == '__main__':
    from torchvision.transforms import Lambda

    # NOTE(review): assumes both images are H x W x 3 and at least 512 x 512 -- confirm.
    input_channel = 3
    target_channel = 3
    # Define a transform pipeline. Input and target are merged along the last axis so
    # the random crop and rotation hit both identically; they are then split apart again
    # and post-processed separately.
    transform = EnhancedCompose([
        Merge(),
        RandomCropNumpy(size=(512, 512)),
        RandomRotate(),
        Split([0, input_channel], [input_channel, input_channel + target_channel]),
        [CenterCropNumpy(size=(256, 256)), CenterCropNumpy(size=(256, 256))],
        [NormalizeNumpy(), MaxScaleNumpy(0, 1.0)],
        # for non-pytorch usage, remove to_tensor conversion
        [Lambda(to_tensor), Lambda(to_tensor)]
    ])
    # Read input data for the test. The context managers close the image files once
    # np.array has loaded the pixels.
    with Image.open('input.jpg') as im:
        image_in = np.array(im)
    with Image.open('target.jpg') as im:
        image_target = np.array(im)
    # apply the transform
    x, y = transform([image_in, image_target])
@poptree
Copy link

poptree commented Mar 8, 2018

I think there is something wrong on line 179: you passed the keyword argument "random_state=self.random_state" to the `isinstance` function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment