@chsasank
Forked from fmder/elastic_transform.py
Last active October 14, 2023 01:55
Elastic transformation of an image in Python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_transform(image, alpha, sigma, random_state=None):
    """Elastic deformation of images as described in [Simard2003]_.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    """
    # Only 2-D (grayscale) arrays are supported.
    assert len(image.shape) == 2

    if random_state is None:
        random_state = np.random.RandomState(None)

    shape = image.shape

    # Uniform noise in [-1, 1], smoothed by a Gaussian of standard deviation
    # sigma and scaled by alpha, gives the per-pixel displacement fields.
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha

    # x holds row indices (axis 0) and y holds column indices (axis 1).
    x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
    indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))

    # Sample the input at the displaced coordinates with linear interpolation.
    return map_coordinates(image, indices, order=1).reshape(shape)
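A minimal sanity check, offered as a sketch: apply the function to a small synthetic, non-square array with a fixed seed. The test image and the values alpha=34 and sigma=4 are arbitrary choices for illustration, not settings recommended by the gist. With alpha=0 every displacement is zero, so the output should reproduce the input exactly. The function body is repeated so the snippet runs on its own.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_transform(image, alpha, sigma, random_state=None):
    """Same logic as the gist's function, repeated so this snippet is standalone."""
    assert len(image.shape) == 2
    if random_state is None:
        random_state = np.random.RandomState(None)
    shape = image.shape
    dx = gaussian_filter(random_state.rand(*shape) * 2 - 1, sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter(random_state.rand(*shape) * 2 - 1, sigma, mode="constant", cval=0) * alpha
    x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
    return map_coordinates(image, indices, order=1).reshape(shape)


# Synthetic non-square grayscale "image": a bright vertical bar on black.
image = np.zeros((32, 48), dtype=float)
image[:, 20:28] = 1.0

# Fixed seeds make the random displacement fields reproducible.
warped = elastic_transform(image, alpha=34, sigma=4, random_state=np.random.RandomState(0))

# alpha=0 scales the displacement to zero, so sampling hits the original grid.
identity = elastic_transform(image, alpha=0, sigma=4, random_state=np.random.RandomState(0))

print(warped.shape)                  # (32, 48)
print(np.allclose(identity, image))  # True
```

Passing an explicit RandomState is what makes a given deformation repeatable, which helps when the same warp must be applied to an image and its label mask.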

koegl commented Jun 17, 2022

> For those who have trouble using this code:
> first, the image has to be square, like (x, x, 3);
> second, to show the transformed image, call e.g. transformed_image = elastic_transform(img, 40, 2) and then plt.imshow(transformed_image).

  1. The image does not have to be square.
  2. The image has to be grayscale, i.e. its shape has to be (a, b). This line enforces it: assert len(image.shape)==2
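If you want to deform a colour image of shape (H, W, C) instead of converting it to grayscale first, one possible approach (an assumption, not part of the gist) is to draw a single pair of displacement fields for the two spatial axes and reuse them for every channel, so the channels stay aligned. The helper name elastic_transform_multichannel is hypothetical.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_transform_multichannel(image, alpha, sigma, random_state=None):
    """Hypothetical helper: elastic deformation of an (H, W, C) array.

    One displacement field is drawn for the spatial axes and applied to
    every channel, so colour channels are warped identically.
    """
    image = np.asarray(image, dtype=float)
    assert image.ndim == 3
    if random_state is None:
        random_state = np.random.RandomState(None)
    h, w, c = image.shape
    dx = gaussian_filter(random_state.rand(h, w) * 2 - 1, sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter(random_state.rand(h, w) * 2 - 1, sigma, mode="constant", cval=0) * alpha
    x, y = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    # Flattened (row, column) sample coordinates shared by all channels.
    coords = [np.reshape(x + dx, (-1,)), np.reshape(y + dy, (-1,))]
    out = np.empty_like(image)
    for ch in range(c):
        out[..., ch] = map_coordinates(image[..., ch], coords, order=1).reshape(h, w)
    return out


# Non-square colour image whose three channels are identical copies.
gray = np.random.RandomState(1).rand(40, 60)
rgb = np.stack([gray, gray, gray], axis=-1)
out = elastic_transform_multichannel(rgb, alpha=34, sigma=4, random_state=np.random.RandomState(0))
print(out.shape)  # (40, 60, 3)
```

Because the channels share one displacement field, identical input channels come out identical; drawing a separate field per channel would instead smear the colours apart.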
