# Gist richardjgowers/2671643287bfb17736ea, created March 31, 2015 13:57
import numpy as np


def blocks_of(a, n, m):
    """Extract a view of the (n, m) blocks along the diagonal.

    Arguments:
    a - starting 2D array
    n, m - number of rows and columns in each miniblock

    Returns:
    (nblocks, n, m) view of the original array,
    where nblocks is the number of times the miniblock fits in the original.
    n and m must divide a into the same integer number of blocks along
    both axes.

    Based on:
    http://stackoverflow.com/a/10862636
    but generalised to handle non-square blocks.

    The view is built from a.strides, so it works for any strided array,
    not only C-contiguous ones.
    Returns a view, so editing it modifies the original array.
    """
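    # Integer division: as_strided needs an integer shape, and both axes
    # must split evenly into the same number of blocks.
    nblocks, rem_rows = divmod(a.shape[0], n)
    nblocks2, rem_cols = divmod(a.shape[1], m)
    if rem_rows or rem_cols or nblocks != nblocks2:
        raise ValueError("Must divide into same number of blocks in both directions")
    new_shape = (nblocks, n, m)
    # Stepping to the next block moves n rows down and m columns across;
    # within a block, the original row and column strides are reused.
    new_strides = (n * a.strides[0] + m * a.strides[1],
                   a.strides[0], a.strides[1])
    return np.lib.stride_tricks.as_strided(a, new_shape, new_strides)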
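

# For comparison, a copy-based sketch of the same extraction using plain
# slicing. The name blocks_of_copy is hypothetical and not part of the
# original gist. Block k covers rows k*n:(k+1)*n and columns k*m:(k+1)*m,
# which is what the first stride in blocks_of encodes: one step jumps
# n rows and m columns at once. Unlike blocks_of, this returns a new array,
# so editing the result leaves the original untouched. It is mainly useful
# for sanity-checking the strided view.
import numpy as np


def blocks_of_copy(a, n, m):
    """Return a (nblocks, n, m) copy of the diagonal (n, m) blocks of a."""
    nblocks, rem_rows = divmod(a.shape[0], n)
    nblocks2, rem_cols = divmod(a.shape[1], m)
    if rem_rows or rem_cols or nblocks != nblocks2:
        raise ValueError("Must divide into same number of blocks in both directions")
    # Slice out each diagonal block and stack them along a new first axis.
    return np.stack([a[k * n:(k + 1) * n, k * m:(k + 1) * m]
                     for k in range(nblocks)])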


a = np.arange(9 * 12).reshape(9, 12)
print(a)
b = blocks_of(a, 3, 4)
print(b)
b += 100  # b is a view, so this adds 100 to the diagonal blocks of a
print(a)