Skip to content

Instantly share code, notes, and snippets.

@mckelvin
Created February 18, 2014 02:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mckelvin/9063510 to your computer and use it in GitHub Desktop.
Save mckelvin/9063510 to your computer and use it in GitHub Desktop.
# coding: utf-8
# code snippet from http://www.johnvinyard.com/blog/?p=268
import numpy as np
from numpy.lib.stride_tricks import as_strided as ast
def norm_shape(shape):
'''
Normalize numpy array shapes so they're always expressed as a tuple,
even for one-dimensional shapes.
Parameters
shape - an int, or a tuple of ints
Returns
a shape tuple
'''
try:
i = int(shape)
return (i,)
except TypeError:
# shape was not a number
pass
try:
t = tuple(shape)
return t
except TypeError:
# shape was not iterable
pass
raise TypeError('shape must be an int, or a tuple of ints')
def sliding_window(a, ws, ss=None, flatten=True):
'''
Return a sliding window over a in any number of dimensions
Parameters:
a - an n-dimensional numpy array
ws - an int (a is 1D) or tuple (a is 2D or greater) representing the
size of each dimension of the window
ss - an int (a is 1D) or tuple (a is 2D or greater) representing the
amount to slide the window in each dimension. If not specified, it
defaults to ws.
flatten - if True, all slices are flattened, otherwise, there is an
extra dimension for each dimension of the input.
Returns
an array containing each n-dimensional window from a
'''
if None is ss:
# ss was not provided. the windows will not overlap in any direction.
ss = ws
ws = norm_shape(ws)
ss = norm_shape(ss)
# convert ws, ss, and a.shape to numpy arrays so that we can do math in
# every dimension at once.
ws = np.array(ws)
ss = np.array(ss)
shape = np.array(a.shape)
# ensure that ws, ss, and a.shape all have the same number of dimensions
ls = [len(shape), len(ws), len(ss)]
if 1 != len(set(ls)):
raise ValueError('a.shape, ws and ss must all have the same length. '
'They were %s' % str(ls))
# ensure that ws is smaller than a in every dimension
if np.any(ws > shape):
raise ValueError('ws cannot be larger than a in any dimension. '
'a.shape was %s and ws was %s' % (
str(a.shape), str(ws)))
# how many slices will there be in each dimension?
newshape = norm_shape(((shape - ws) // ss) + 1)
# the shape of the strided array will be the number of slices in each
# dimension plus the shape of the window (tuple addition)
newshape += norm_shape(ws)
# the strides tuple will be the array's strides multiplied by step size,
# plus the array's strides (tuple addition)
newstrides = norm_shape(np.array(a.strides) * ss) + a.strides
strided = ast(a, shape=newshape, strides=newstrides)
if not flatten:
return strided
# Collapse strided so that it has one more dimension than the window.
# I.e., the new array is a flat list of slices.
meat = len(ws) if ws.shape else 0
firstdim = (np.product(newshape[:-meat]),) if ws.shape else ()
dim = firstdim + (newshape[-meat:])
# remove any dimensions with size 1
dim = filter(lambda i: i != 1, dim)
return strided.reshape(dim)
def main():
arr = np.arange(300).reshape((100, 3))
print sliding_window(arr, (10, 3), (10, 3))
if __name__ == '__main__':
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment