Skip to content

Instantly share code, notes, and snippets.

@jirassimok
Last active May 5, 2020 23:57
Show Gist options
  • Save jirassimok/d58c2c752bdb8f7a121f55ec4d8d2868 to your computer and use it in GitHub Desktop.
Save jirassimok/d58c2c752bdb8f7a121f55ec4d8d2868 to your computer and use it in GitHub Desktop.
Multiple slice assignment
"""
Simulate assignment to different-shaped, non-contiguous slices of a NumPy array.
Note that this is unnecessarily complex and inefficient for most use cases.
Created for this Stack Overflow answer:
https://stackoverflow.com/a/60803701
"""
from itertools import chain
import numpy as np
def prepare_slices(arr, *slices):
    """Prepare multiple slices of an array for one combined assignment.

    Each element of *slices* is anything that can index *arr*: an int, a
    slice, an index/boolean array, or a tuple of those (e.g. built with
    ``np.s_``). Non-tuple indices are treated as a 1-tuple.

    Returns a tuple ``(slice_indices, broadcast)``:

    * ``slice_indices`` is a tuple of index arrays, one per dimension of
      *arr* (as produced by ``np.unravel_index``), selecting every element
      of every slice, slice by slice, each in C order.
    * ``broadcast(*vals)`` takes one value per slice, broadcasts each to the
      shape of its slice, and returns all the elements as a flat list in the
      same order as ``slice_indices``. It raises ``ValueError`` if the number
      of values differs from the number of slices.

    Typical use::

        idx, cast = prepare_slices(arr, np.s_[0:2, 0:2], np.s_[2])
        arr[idx] = cast([[1, 2], [3, 4]], 9)
    """
    # Make all slices into tuples
    slices = tuple(s if isinstance(s, tuple) else (s,) for s in slices)

    # Index an array of flat (C-order) positions with each slice. The result
    # always has exactly the shape of ``arr[s]``. Indexing an ``np.indices``
    # stack behind a leading ``:`` does not: when ``s`` has advanced indices
    # separated by a slice, NumPy moves their broadcast dimension in front of
    # the leading ``:`` axis, which breaks ``ravel_multi_index``.
    flat_positions = np.arange(arr.size, dtype=np.intp).reshape(arr.shape)
    # asarray: a full integer index yields a NumPy scalar, not an ndarray.
    selected = [np.asarray(flat_positions[s]) for s in slices]
    # Shapes of the slices (identical to ``arr[s].shape``), reused by broadcast.
    shapes = [sel.shape for sel in selected]

    if selected:
        raveled_indices = np.concatenate([sel.ravel() for sel in selected])
    else:
        # np.concatenate rejects an empty list; no slices selects nothing.
        raveled_indices = np.empty(0, dtype=np.intp)

    # Convert back to multidimensional indices as index arrays
    slice_indices = np.unravel_index(raveled_indices, arr.shape)

    def broadcast(*vals):
        """Broadcast the values to the shapes of the corresponding slices."""
        if len(slices) != len(vals):
            raise ValueError("Wrong number of values for index broadcasts")
        # C-order .flat matches the C-order ravel used for the indices above.
        return [v for shape, val in zip(shapes, vals)
                for v in np.broadcast_to(val, shape).flat]

    return slice_indices, broadcast
# Simplified, more-efficient version for 1-D arrays
def prepare_slices_1d(arr, *slices):
    """Prepare multiple slices of a 1-D array for one combined assignment.

    Simplified counterpart of ``prepare_slices`` for 1-D arrays. Each element
    of *slices* indexes the first axis of *arr*: an int, a slice, or an
    index/boolean array (e.g. built with ``np.s_``).

    Returns a tuple ``(selector, broadcast)``:

    * ``selector`` is a single index array covering every slice in order.
    * ``broadcast(*vals)`` takes one value per slice, broadcasts each to the
      shape of its slice, and returns the elements as a list in the same
      order as ``selector``. It raises ``ValueError`` if the number of values
      differs from the number of slices.
    """
    indices = np.arange(len(arr))
    # atleast_1d lets a plain integer index (whose selection is 0-d) be
    # concatenated and iterated like a length-1 slice.
    positions = [np.atleast_1d(indices[s]) for s in slices]
    # Shape each value is broadcast to: the selected positions along axis 0
    # followed by any trailing dimensions of arr (none for a 1-D array).
    # For every non-integer index this equals arr[s].shape.
    shapes = [pos.shape + arr.shape[1:] for pos in positions]

    if positions:
        selector = np.concatenate(positions)
    else:
        # np.concatenate rejects an empty list; no slices selects nothing.
        selector = np.empty(0, dtype=indices.dtype)

    def broadcast(*vals):
        """Broadcast the values to the shapes of the corresponding slices."""
        if len(slices) != len(vals):
            raise ValueError("Wrong number of values for index broadcast")
        return [v for shape, val in zip(shapes, vals)
                for v in np.broadcast_to(val, shape)]

    return selector, broadcast
import numpy as np
from numpy import s_
from prepare_slices import prepare_slices, prepare_slices_1d
# 2-D example. Expected result: the top-left 2x2 block becomes (-1 -2)(-3 -4),
# the middle row becomes 9, the middle column becomes 1 2 3 4 5, and the
# bottom-right 2x2 block becomes (7 8)(7 8).
# The middle row and middle column share element (2, 2); the expected grid
# keeps the value from the later slice (the column), i.e. 3 rather than 9.
# NOTE(review): NumPy's indexing docs do not guarantee which write wins when
# an advanced-index assignment repeats an element -- confirm this holds on
# the NumPy versions you target.
expected_2d = np.array([
    [-1, -2, 1, 0, 0],
    [-3, -4, 2, 0, 0],
    [ 9,  9, 3, 9, 9],
    [ 0,  0, 4, 7, 8],
    [ 0,  0, 5, 7, 8],
])
arr = np.zeros((5, 5), dtype=int)
idx, cast = prepare_slices(
    arr,
    s_[0:2, 0:2],  # top-left block
    s_[2],         # middle row
    s_[3:, -2:],   # bottom-right block
    s_[:, 2],      # middle column
)
arr[idx] = cast(
    [[-1, -2], [-3, -4]],
    9,
    [7, 8],
    range(1, 6),
)
assert np.array_equal(arr, expected_2d)

# 1-D example. The reversed slice s_[8:5:-1] and the tail slice s_[-2:] both
# hit index 8; the expected array keeps the later value (-1), with the same
# repeated-index caveat as the 2-D example.
expected_1d = [0, 1, 2, 0, 0, 0, 8, 8, -1, -1]
arr = np.zeros(10, dtype=int)
idx, cast = prepare_slices_1d(arr, s_[1:3], s_[8:5:-1], s_[-2:])
arr[idx] = cast([1, 2], [8], -1)
assert np.array_equal(arr, expected_1d)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment