Last active
May 5, 2020 23:57
-
-
Save jirassimok/d58c2c752bdb8f7a121f55ec4d8d2868 to your computer and use it in GitHub Desktop.
Multiple slice assignment
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Simulate assignment to different-shaped, non-contiguous slices of a NumPy array. | |
Note that this is unnecessarily complex and inefficient for most use cases. | |
Created for this Stack Overflow answer: | |
https://stackoverflow.com/a/60803701 | |
""" | |
from itertools import chain | |
import numpy as np | |
def prepare_slices(arr, *slices):
    """Prepare multiple slices of an array for one combined assignment.

    Each element of *slices* is anything that can index *arr*: a slice, an
    integer, an index array, an Ellipsis, or a tuple of these.  The
    positions selected by every slice are concatenated, in argument order,
    into a single advanced index.

    Returns a tuple ``(indices, broadcast)``:

    * ``indices`` is a tuple of ``arr.ndim`` integer index arrays usable as
      ``arr[indices] = ...``.
    * ``broadcast(*vals)`` takes exactly one value per slice, broadcasts each
      value to the shape of its slice, and returns all the values flattened
      into one list ordered to match ``indices``.  It raises ``ValueError``
      if the number of values differs from the number of slices.

    If slices overlap, an element appears more than once in ``indices``;
    NumPy does not guarantee which value is kept when an advanced
    assignment repeats an index.
    """
    # Normalize every index to a tuple so single-axis and multi-axis indices
    # are handled the same way.
    slices = tuple(s if isinstance(s, tuple) else (s,) for s in slices)

    # Label every element with its flat (C-order) position.  Indexing this
    # label array with a slice yields the flat positions that slice selects,
    # shaped like arr[s] and in the same order as arr[s].flat.  This needs
    # one integer per element instead of the arr.ndim per element that
    # np.indices would allocate.
    n_elements = int(np.prod(arr.shape, dtype=np.intp))
    flat_positions = np.arange(n_elements, dtype=np.intp).reshape(arr.shape)
    selected = [flat_positions[s] for s in slices]

    # Record each slice's shape once, so broadcast() does not have to
    # re-index (and, for fancy indices, copy) arr on every call.
    shapes = [np.shape(sel) for sel in selected]

    # np.ravel also turns the scalar produced by an all-integer index into a
    # 1-element array.  np.concatenate rejects an empty list, so the
    # no-slices case gets an explicit empty array.
    if selected:
        raveled_indices = np.concatenate([np.ravel(sel) for sel in selected])
    else:
        raveled_indices = np.empty(0, dtype=np.intp)

    # Convert back to multidimensional indices as index arrays.
    slice_indices = np.unravel_index(raveled_indices, arr.shape)

    def broadcast(*vals):
        """Broadcast the values to the shapes of the corresponding slices."""
        if len(slices) != len(vals):
            raise ValueError("Wrong number of values for index broadcasts")
        return [v for shape, val in zip(shapes, vals)
                for v in np.broadcast_to(val, shape).flat]

    return slice_indices, broadcast
# Simplified, more-efficient version for 1-D arrays | |
def prepare_slices_1d(arr, *slices):
    """Prepare multiple slices of a 1-D array for one combined assignment.

    Simplified counterpart of ``prepare_slices`` for one-dimensional arrays.
    Each element of *slices* may be a slice, an integer, or an index array.

    Returns ``(selector, broadcast)``:

    * ``selector`` is a single integer index array holding the positions
      selected by every slice, concatenated in argument order.
    * ``broadcast(*vals)`` takes exactly one value per slice, broadcasts each
      to the shape of its slice, and returns the flattened values as a list
      ordered to match ``selector``.  It raises ``ValueError`` if the number
      of values differs from the number of slices.
    """
    positions = np.arange(len(arr))
    selected = [positions[s] for s in slices]
    # Shape of each slice (() for an integer index), captured once.
    shapes = [np.shape(sel) for sel in selected]

    def broadcast(*vals):
        """Broadcast the values to the shapes of the corresponding slices."""
        if len(slices) != len(vals):
            raise ValueError("Wrong number of values for index broadcast")
        # .flat (rather than iterating the array directly) also works for
        # the 0-d result that an integer index produces.
        return [v for shape, val in zip(shapes, vals)
                for v in np.broadcast_to(val, shape).flat]

    # np.ravel turns an integer index's scalar into a 1-element array, which
    # np.concatenate requires.  np.concatenate rejects an empty list, so the
    # no-slices case gets an empty array of the same dtype.
    if selected:
        selector = np.concatenate([np.ravel(sel) for sel in selected])
    else:
        selector = positions[:0]
    return selector, broadcast
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numpy import s_ | |
from prepare_slices import prepare_slices, prepare_slices_1d | |
# 2-D example: in a single assignment, fill the top-left 2x2 block with
# (-1 -2)(-3 -4), row 2 with 9, the bottom-right 2x2 block with (7 8)(7 8),
# and column 2 with 1..5.
# NOTE: row 2 and column 2 overlap at (2, 2).  The expected value there (3)
# assumes the later (column) value is kept; NumPy documents that the result
# of assigning to a repeated advanced index is not guaranteed.
arr = np.zeros((5, 5), dtype=int)
regions = (s_[0:2, 0:2], s_[2], s_[3:, -2:], s_[:, 2])
idx, cast = prepare_slices(arr, *regions)
arr[idx] = cast([[-1, -2], [-3, -4]], 9, [7, 8], range(1, 6))
expected_2d = [
    [-1, -2, 1, 0, 0],
    [-3, -4, 2, 0, 0],
    [9, 9, 3, 9, 9],
    [0, 0, 4, 7, 8],
    [0, 0, 5, 7, 8],
]
assert np.array_equal(arr, expected_2d)

# 1-D example: positions 1-2, positions 8, 7, 6 (a reversed slice), and the
# last two positions.
# NOTE: position 8 is in both of the last two slices; the expected -1 again
# assumes the later value is kept.
arr = np.zeros(10, dtype=int)
idx, cast = prepare_slices_1d(arr, s_[1:3], s_[8:5:-1], s_[-2:])
arr[idx] = cast([1, 2], [8], -1)
expected_1d = [0, 1, 2, 0, 0, 0, 8, 8, -1, -1]
assert np.array_equal(arr, expected_1d)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment