Skip to content

Instantly share code, notes, and snippets.

@eickenberg
Created November 23, 2012 15:57
Show Gist options
  • Save eickenberg/4136238 to your computer and use it in GitHub Desktop.
Save eickenberg/4136238 to your computer and use it in GitHub Desktop.
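Patch extractor for NumPy arrays: returns a strided, zero-copy view of every (possibly overlapping) patch of an n-dimensional array.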
import numpy as np
from numpy.lib.stride_tricks import as_strided
import numbers
def make_patches(arr, patch_shape=2, extraction_step=1):
    """Return a strided view of all patches of ``arr`` on a regular grid.

    The result has ``2 * arr.ndim`` dimensions. The first ``arr.ndim`` axes
    index the patch position on the extraction grid. The last ``arr.ndim``
    axes index the elements inside one patch. No data is copied: the result
    is a view built with ``as_strided``, so overlapping patches share memory
    with each other and with ``arr``.

    Parameters
    ----------
    arr : array_like
        Input array of any dimensionality.
    patch_shape : int or sequence of int, default 2
        Shape of one patch. A scalar is used for every axis.
    extraction_step : int or sequence of int, default 1
        Offset between the origins of consecutive patches along each axis.
        A scalar is used for every axis.

    Returns
    -------
    patches : ndarray
        View of shape ``grid_shape + patch_shape``, where along axis ``k``
        ``grid_shape[k] = (arr.shape[k] - patch_shape[k]) // extraction_step[k] + 1``.

    Raises
    ------
    ValueError
        If ``patch_shape`` or ``extraction_step`` does not have one entry
        per axis, if any step is not positive, or if a patch does not fit
        inside ``arr``.
    """
    arr = np.asarray(arr)
    ndim = arr.ndim

    # Broadcast scalar parameters to one value per axis.
    if isinstance(patch_shape, numbers.Number):
        patch_shape = (patch_shape,) * ndim
    if isinstance(extraction_step, numbers.Number):
        extraction_step = (extraction_step,) * ndim
    patch_shape = tuple(patch_shape)
    extraction_step = tuple(extraction_step)

    if len(patch_shape) != ndim or len(extraction_step) != ndim:
        raise ValueError(
            "patch_shape %r and extraction_step %r must each have %d entries "
            "(one per axis)" % (patch_shape, extraction_step, ndim))
    if any(step < 1 for step in extraction_step):
        raise ValueError(
            "extraction_step entries must be positive, got %r"
            % (extraction_step,))
    if any(p < 0 or p > s for p, s in zip(patch_shape, arr.shape)):
        raise ValueError(
            "patch_shape %r does not fit in array of shape %r"
            % (patch_shape, arr.shape))

    # Moving one patch along axis k skips `step` elements, i.e. the element
    # stride multiplied by the step (same as arr[::step].strides).
    grid_strides = tuple(
        stride * step for stride, step in zip(arr.strides, extraction_step))
    # Floor division: number of patch origins that keep the patch in bounds.
    grid_shape = tuple(
        (size - p) // step + 1
        for size, p, step in zip(arr.shape, patch_shape, extraction_step))

    # Inside a patch, elements are laid out exactly as in `arr`.
    return as_strided(arr,
                      shape=grid_shape + patch_shape,
                      strides=grid_strides + arr.strides)
if __name__ == "__main__":
    # Demo: cut a 2-D grayscale image into overlapping patches and display
    # them in a subplot grid laid out like the patch grid itself.
    import pylab as pl

    # scipy.misc.lena() no longer exists in SciPy; use a synthetic
    # 512x512 image so the demo runs without any external data.
    rows, cols = np.mgrid[:512, :512]
    image = np.sin(rows / 25.0) * np.cos(cols / 40.0)

    patch_size = (150, 160)
    patch_step = (50, 60)
    p = make_patches(image, patch_size, patch_step)

    n_rows, n_cols = p.shape[:2]
    pl.figure()
    for i in range(n_rows * n_cols):
        # subplot numbers panels row-major, so panel i sits at grid
        # position (i // n_cols, i % n_cols).
        pl.subplot(n_rows, n_cols, i + 1)
        pl.imshow(p[i // n_cols, i % n_cols], cmap="gray")
    pl.show()
@eickenberg
Copy link
Author

if __name__ == "__main__":
    # Demo: cut a 2-D grayscale image into overlapping patches and display
    # them in a subplot grid laid out like the patch grid itself.
    import pylab as pl

    # scipy.misc.lena() no longer exists in SciPy; use a synthetic
    # 512x512 image so the demo runs without any external data.
    rows, cols = np.mgrid[:512, :512]
    image = np.sin(rows / 25.0) * np.cos(cols / 40.0)

    patch_size = (150, 160)
    patch_step = (50, 60)
    p = make_patches(image, patch_size, patch_step)

    n_rows, n_cols = p.shape[:2]
    pl.figure()
    for i in range(n_rows * n_cols):
        # subplot numbers panels row-major, so panel i sits at grid
        # position (i // n_cols, i % n_cols).
        pl.subplot(n_rows, n_cols, i + 1)
        pl.imshow(p[i // n_cols, i % n_cols], cmap="gray")
    pl.show()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment