@mvoelk
Forked from ernestum/elastic_transform.py
Last active November 12, 2020 15:59
Elastic transformation of an image in Python
import numpy as np
from scipy.ndimage import map_coordinates, gaussian_filter


def elastic_transform(image, alpha, sigma, random_state=None):
    """Elastic deformation of images as described in [Simard2003].

    image: array of shape (h, w) or (h, w, c)
    alpha: scaling factor applied to the smoothed displacement field
    sigma: standard deviation of the Gaussian used to smooth the field

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    """
    if random_state is None:
        random_state = np.random.RandomState(None)

    h, w = image.shape[:2]
    # Pixel grid: x holds column indices, y holds row indices, both of shape (h, w)
    x, y = np.meshgrid(np.arange(w), np.arange(h))

    # Uniform noise in [-1, 1], smoothed by a Gaussian and scaled by alpha
    dx = gaussian_filter(random_state.rand(h, w) * 2 - 1, sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter(random_state.rand(h, w) * 2 - 1, sigma, mode="constant", cval=0) * alpha

    # Displaced (row, col) sampling coordinates for every output pixel, as column vectors
    indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1))

    if image.ndim > 2:
        # Bilinear interpolation (order=1), one channel at a time
        c = image.shape[2]
        distorted_image = [map_coordinates(image[:, :, i], indices, order=1, mode='reflect')
                           for i in range(c)]
        distorted_image = np.concatenate(distorted_image, axis=1)
    else:
        distorted_image = map_coordinates(image, indices, order=1, mode='reflect')

    return distorted_image.reshape(image.shape)
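The deformation is driven by two displacement fields, each built from uniform noise in [-1, 1] that is smoothed by a Gaussian of standard deviation `sigma` and then scaled by `alpha`. The sketch below isolates that step in a hypothetical helper, `random_displacement_field` (not part of the gist), so the effect of the two parameters can be inspected on its own. Because `gaussian_filter` uses a normalized kernel and the zero padding (`mode="constant", cval=0`) can only pull values toward 0, the smoothed noise stays within [-1, 1], so no displacement exceeds `alpha` pixels. A larger `sigma` averages more noise samples and gives smaller, smoother displacements for the same `alpha`. The parameter values in the demo loop are illustrative only.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def random_displacement_field(shape, alpha, sigma, random_state=None):
    """Hypothetical helper: one smoothed random displacement component.

    Mirrors the dx/dy computation in elastic_transform: uniform noise in
    [-1, 1], Gaussian-smoothed with standard deviation sigma (zero padding
    at the borders), then scaled by alpha.
    """
    if random_state is None:
        random_state = np.random.RandomState(None)
    noise = random_state.rand(*shape) * 2 - 1
    return gaussian_filter(noise, sigma, mode="constant", cval=0) * alpha


# Compare displacement statistics for a few smoothing widths (illustrative values)
for sigma in (1, 4, 16):
    d = random_displacement_field((64, 64), alpha=34, sigma=sigma,
                                  random_state=np.random.RandomState(0))
    # Wider smoothing averages more noise samples, so displacements shrink
    print(f"sigma={sigma:2d}  max |d| = {np.abs(d).max():.2f}  std = {d.std():.3f}")
```

In practice this means `alpha` and `sigma` interact: increasing `sigma` for smoother warps usually calls for a larger `alpha` to keep the deformation visible.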
@mvoelk
Author

mvoelk commented Jun 30, 2020

Handles both RGB and grayscale images. Interpolation is done channel-wise, so it works for an arbitrary number of channels.
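The per-channel loop could also be replaced by a single `map_coordinates` call on the full (h, w, c) array, with the channel index passed as a third coordinate. The sketch below is an assumption, not the gist's code: a hypothetical helper, `warp_channels_single_call`, that takes the same row and column sampling grids the gist builds (`y + dy` and `x + dx`, before flattening). Because every channel coordinate is an exact integer, bilinear interpolation (`order=1`) gives that channel a weight of 1, so channels are never blended and the result should match the channel-wise loop up to floating-point rounding.

```python
import numpy as np
from scipy.ndimage import map_coordinates


def warp_channels_single_call(image, rows, cols):
    """Hypothetical alternative to the per-channel loop in elastic_transform.

    image: (h, w, c) array; rows, cols: (h, w) arrays of sampling coordinates.
    The channel coordinate is the exact integer index of each channel, so
    bilinear interpolation (order=1) puts all weight on that channel and no
    values are mixed across channels.
    """
    h, w, c = image.shape
    # Broadcast the 2-D sampling grid across the channel axis -> (h, w, c) each
    r = np.broadcast_to(rows[:, :, None], (h, w, c))
    q = np.broadcast_to(cols[:, :, None], (h, w, c))
    ch = np.broadcast_to(np.arange(c)[None, None, :], (h, w, c))
    # Coordinate array of shape (3, h, w, c); output takes shape (h, w, c)
    coords = np.stack([r, q, ch]).astype(float)
    return map_coordinates(image, coords, order=1, mode="reflect")
```

The trade-off is memory: the coordinate array holds 3·h·w·c floats instead of 2·h·w, so the channel-wise loop in the gist stays the leaner choice for large images or many channels.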
