Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Last active December 23, 2021 01:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brandonwillard/963393206f48f4dc46e5e9b82f5caed9 to your computer and use it in GitHub Desktop.
Save brandonwillard/963393206f48f4dc46e5e9b82f5caed9 to your computer and use it in GitHub Desktop.
Lift advanced indices through concatenate/stack
from typing import Tuple, Union
import numpy as np
def is_basic_idx(x):
    """Return ``True`` when `x` is a basic index, i.e. a ``slice`` or ``None``.

    Everything else (integers, sequences, arrays, ...) is reported as
    non-basic.
    """
    return x is None or isinstance(x, slice)
def expand_indices(
    indices: Tuple[Union[np.ndarray, int, slice], ...], shape: Tuple[int, ...]
) -> Tuple[np.ndarray, ...]:
    """Convert basic and/or advanced indices into a single, broadcasted advanced indexing operation.

    Slices are replaced by the integer ranges they select and every index is
    padded with broadcastable dimensions so that, for an array ``A`` with
    ``A.shape == shape``, ``A[expand_indices(indices, shape)]`` reproduces
    ``A[indices]``.  Integers are treated as zero-dimensional advanced
    indices.  Index values are not bounds-checked here.

    Example
    -------
    >>> indices = (
    ...     slice(1, 3),
    ...     1,
    ...     slice(None),
    ...     np.array([2, 1]),
    ... )
    >>> full_indices = expand_indices(indices, (5, 4, 3, 3))
    >>> [idx.shape for idx in full_indices]
    [(2, 2, 3), (2, 2, 3), (2, 2, 3), (2, 2, 3)]
    >>> full_indices[3]
    array([[[2, 2, 2],
            [2, 2, 2]],
    <BLANKLINE>
           [[1, 1, 1],
            [1, 1, 1]]])

    Parameters
    ----------
    indices
        The indices to convert.  Slices, integers and integer
        sequences/arrays are supported; trailing dimensions without an index
        are treated as ``slice(None)``.
    shape
        The shape of the array being indexed.

    Returns
    -------
    A tuple with one integer index array per entry of `shape`, all broadcast
    to the shape of the indexing result.

    Raises
    ------
    NotImplementedError
        If `indices` contains ``None`` (new axes), ``Ellipsis`` or boolean
        indices.
    IndexError
        If more indices are given than `shape` has dimensions.
    """
    # Validate and normalize up front: looping over ``zip(indices, shape)``
    # alone would silently skip any surplus or unsupported entries.
    normalized = []
    for idx in indices:
        if idx is None:
            raise NotImplementedError("New axes not supported")
        if idx is Ellipsis:
            raise NotImplementedError("Ellipsis not supported")
        if not isinstance(idx, slice):
            idx = np.asarray(idx)
            if idx.dtype.kind == "b":
                raise NotImplementedError("Boolean indices not supported")
        normalized.append(idx)

    n_dims = len(shape)
    if len(normalized) > n_dims:
        raise IndexError(
            f"too many indices for array: array is {n_dims}-dimensional, "
            f"but {len(normalized)} were indexed"
        )

    full_indices = normalized + [slice(None)] * (n_dims - len(normalized))

    # With ``None`` rejected above, the basic indices are exactly the slices.
    index_types = [isinstance(idx, slice) for idx in full_indices]

    # We need to know if a "subspace" was generated by advanced indices
    # bookending basic indices.  If so, the advanced indexing subspace is
    # moved to the "front" (i.e. the left-most dimensions) of the result.
    first_adv_idx = n_dims
    try:
        first_adv_idx = index_types.index(False)
        first_bsc_after_adv_idx = index_types.index(True, first_adv_idx)
        index_types.index(False, first_bsc_after_adv_idx)
        moved_subspace = True
    except ValueError:
        # No advanced index, or the advanced indices form one contiguous group
        moved_subspace = False

    n_basic_indices = sum(index_types)

    # The number of dimensions in the subspace created by the advanced indices
    n_subspace_dims = max(
        (idx.ndim for idx, is_basic in zip(full_indices, index_types) if not is_basic),
        default=0,
    )

    # The number of dimensions for each expanded index
    n_output_dims = n_subspace_dims + n_basic_indices

    expanded_indices = []
    n_preceding_basics = 0
    for d, (idx, is_basic, dim_len) in enumerate(
        zip(full_indices, index_types, shape)
    ):
        if is_basic:
            idx = np.arange(*idx.indices(dim_len))
            if moved_subspace or d > first_adv_idx:
                # Either the advanced subspace was moved to the front, or this
                # basic index lies past the (single, contiguous) group of
                # advanced indices; in both cases all the subspace dimensions
                # precede this index's own dimension.  A single advanced index
                # can contribute an arbitrary number of dimensions.
                n_lead = n_subspace_dims + n_preceding_basics
            else:
                n_lead = n_preceding_basics
            n_trail = n_output_dims - n_lead - 1
            n_preceding_basics += 1
        else:
            # Advanced indices share one broadcasted subspace.  It sits at the
            # front when moved; otherwise it follows the dimensions of the
            # basic indices that come before the group.
            n_lead = 0 if moved_subspace else n_preceding_basics
            n_trail = n_basic_indices - n_lead

        expanded_idx = idx[(None,) * n_lead + (Ellipsis,) + (None,) * n_trail]

        assert expanded_idx.ndim <= n_output_dims

        expanded_indices.append(expanded_idx)

    return tuple(np.broadcast_arrays(*expanded_indices))
def test_expand_indices():
    """Compare `expand_indices` against direct NumPy indexing.

    Every case stacks freshly drawn random parts into an array ``A``, expands
    the indices for ``A.shape`` and checks that indexing ``A`` with the
    expanded indices gives the same result as indexing with the originals.
    """

    def stack_parts(n_parts, part_shape):
        # The indexed array is a stack of equally shaped random parts along a
        # new leading axis.
        return np.stack(
            tuple(np.random.normal(size=part_shape) for _ in range(n_parts))
        )

    def check(A, indices):
        expected = A[indices]
        expanded = expand_indices(indices, A.shape)
        assert len(expanded) == A.ndim
        assert np.array_equal(A[expanded], expected)

    # ``A.shape == (3, 4, 3)``: an advanced index followed by a slice.
    #
    # By hand, the expanded indices are the broadcast of
    #   np.expand_dims(np.array([[0, 1], [2, 2]]), (-1, -2)),
    #   np.expand_dims(np.arange(2, 3), (-1,)),
    #   np.arange(A.shape[2]),
    # Another way to think about it:
    #   A[indices[0]][:, :, 2:3, :] == A[indices]
    check(stack_parts(3, (4, 3)), (np.array([[0, 1], [2, 2]]), slice(2, 3)))

    # A slice followed by an advanced index.
    #
    # By hand, the expanded indices are the broadcast of
    #   np.expand_dims(np.arange(2, 3), (-1, -2)),
    #   np.expand_dims(np.array([[0, 1], [2, 2]]), (-1,)),
    #   np.expand_dims(np.arange(A.shape[2]), (0, 1, 2)),
    check(stack_parts(3, (4, 3)), (slice(2, 3), np.array([[0, 1], [2, 2]])))

    # Advanced indices interleaved with slices.
    #
    # By hand:
    #   1. Each advanced index gets one trailing broadcastable dimension per
    #      non-advanced index:
    #        np.expand_dims(np.array([[0], [2], [1]]), (-1, -2))
    #        np.expand_dims(np.array([2, 1]), (-1, -2))
    #   2. Each basic index gets leading broadcastable dimensions for the
    #      advanced subspace plus one dimension per other basic index.  The
    #      slice on the second dimension is effectively "moved" to the
    #      dimension *after* the broadcasted advanced subspace:
    #        np.expand_dims(np.arange(A.shape[1]), (0, 1, -1))
    #        np.expand_dims(np.arange(2, 3), (0, 1, 2))
    check(
        stack_parts(3, (5, 4, 3)),
        (
            np.array([[0], [2], [1]]),
            slice(None),
            np.array([2, 1]),
            slice(2, 3),
        ),
    )

    # A slice followed by a one-dimensional advanced index.
    check(stack_parts(3, (4, 3)), (slice(2, 3), np.array([0, 1, 2])))

    # Two leading advanced indices; ``A[indices].shape == (2, 2, 3)``.
    check(
        stack_parts(3, (4, 3)),
        (np.array([[0, 1], [2, 2]]), np.array([[0, 1], [2, 2]])),
    )

    # Advanced indices on every dimension; ``A[indices].shape == (2, 2)``.
    check(
        stack_parts(3, (4, 3)),
        (
            np.array([[0, 1], [2, 2]]),
            np.array([[0, 1], [2, 2]]),
            np.array([[0, 1], [2, 2]]),
        ),
    )

    # Two advanced indices and an integer; ``A[indices].shape == (2, 2)``.
    check(
        stack_parts(3, (4, 3)),
        (np.array([[0, 1], [2, 2]]), np.array([[0, 1], [2, 2]]), 1),
    )

    # No advanced indices; ``A.shape == A[indices].shape == (2, 5, 4, 3)``.
    check(stack_parts(2, (5, 4, 3)), (slice(0, 2),))

    # A slice followed by a random advanced index;
    # ``A[indices].shape == (2, 2, 3, 4, 3)``.
    check(
        stack_parts(2, (5, 4, 3)),
        (slice(0, 2), np.random.randint(3, size=(2, 3))),
    )
def test_expand_indices_moved_subspaces():
    """Check `expand_indices` when advanced indices are separated by slices.

    In these cases NumPy places the broadcasted advanced-index subspace in the
    leading dimensions of the result.
    """

    def stack_parts(part_shape):
        # Three equally shaped random parts stacked along a new leading axis.
        return np.stack(tuple(np.random.normal(size=part_shape) for _ in range(3)))

    def check(A, indices):
        expanded = expand_indices(indices, A.shape)
        assert len(expanded) == A.ndim
        assert np.array_equal(A[expanded], A[indices])

    # By hand, the expanded indices are the broadcast of
    #   np.expand_dims(np.arange(A.shape[0]), (0, 1, -1, -2)),
    #   np.expand_dims(np.array([[0], [2], [1]]), (-1, -2, -3)),
    #   np.expand_dims(np.arange(A.shape[2]), (0, 1, 2, -1)),
    #   np.expand_dims(np.array([2, 1]), (-1, -2, -3)),
    #   np.expand_dims(np.arange(2, 3), (0, 1, 2, 3)),
    check(
        stack_parts((6, 5, 4, 3)),
        (
            slice(None),
            np.array([[0], [2], [1]]),
            slice(None),
            np.array([2, 1]),
            slice(2, 3),
        ),
    )

    # By hand, the expanded indices are the broadcast of
    #   np.expand_dims(np.array([[0, 1], [2, 2]]), (-1,)),
    #   np.expand_dims(np.arange(A.shape[1]), (0, 1)),
    #   np.expand_dims(np.array([[0, 1], [2, 2]]), (-1,)),
    # Although the slice indexes the second dimension, it is effectively
    # "moved" to the dimension *after* the broadcasted advanced subspace.
    check(
        stack_parts((4, 3)),
        (np.array([[0, 1], [2, 2]]), slice(None), np.array([[0, 1], [2, 2]])),
    )
def test_expand_indices_single_indices():
    """Check `expand_indices` when some of the indices are plain integers."""

    def stack_parts():
        # ``A.shape == (3, 4, 3)``
        return np.stack(tuple(np.random.normal(size=(4, 3)) for _ in range(3)))

    def check(A, indices):
        expected = A[indices]
        expanded = expand_indices(indices, A.shape)
        assert len(expanded) == A.ndim
        assert np.array_equal(A[expanded], expected)

    # Trailing integer.  By hand, the expanded indices are the broadcast of
    #   np.expand_dims(np.arange(2, 3), (-1,)),
    #   np.expand_dims(np.array([0, 1, 2]), (0,)),
    #   np.expand_dims(np.array(1), (0,)),
    check(stack_parts(), (slice(2, 3), np.array([0, 1, 2]), 1))

    # Integer between a slice and an advanced index.
    check(stack_parts(), (slice(2, 3), 1, np.array([0, 1, 2])))

    # Leading integer; ``A[indices].shape == (3, 1)``.  By hand:
    #   np.expand_dims(np.array(1), (-1,)),
    #   np.expand_dims(np.arange(2, 3), (0,)),
    #   np.expand_dims(np.array([0, 1, 2]), (-1,)),
    check(stack_parts(), (1, slice(2, 3), np.array([0, 1, 2])))

    # Random advanced index followed by two integers;
    # ``A[indices].shape == (4, 3)``.
    check(stack_parts(), (np.random.randint(2, size=(4, 3)), 1, 0))
def reorder_index(A_parts, indices, join_index=0):
    """Compute ``A[indices]`` for ``A = np.stack(A_parts, axis=join_index)``.

    The stacked array ``A`` is never built: the indices are expanded into
    broadcasted integer arrays, and the entries that select part ``m`` along
    the join axis are read directly from ``A_parts[m]``.

    Parameters
    ----------
    A_parts
        A non-empty sequence of arrays; the shape of the first part is used
        for all of them.
    indices
        The indices to apply to the (virtual) stacked array, in any form
        accepted by `expand_indices`.
    join_index
        The axis along which the parts are (virtually) stacked.  Negative
        values count from the end, as for ``np.stack``.

    Returns
    -------
    An array holding ``A[indices]``, with the result dtype of the parts.

    Raises
    ------
    IndexError
        If `join_index` is out of range, or an index along the join axis does
        not refer to one of the parts.
    """
    n_parts = len(A_parts)
    A_shape = list(A_parts[0].shape)

    # Normalize a negative join axis; otherwise ``list.insert`` and the
    # ``i != join_index`` filter below would disagree with ``np.stack``.
    n_axes = len(A_shape) + 1
    if not -n_axes <= join_index < n_axes:
        raise IndexError(
            f"join_index {join_index} is out of bounds for a stacked array "
            f"of dimension {n_axes}"
        )
    join_index %= n_axes

    A_shape.insert(join_index, n_parts)

    bcast_indices = expand_indices(indices, A_shape)

    # Wrap negative indices on the join axis so that they match a part number,
    # and reject the ones that match no part: those entries would otherwise
    # be left as uninitialized memory in the result.
    part_ids = bcast_indices[join_index]
    if np.any((part_ids < -n_parts) | (part_ids >= n_parts)):
        raise IndexError(
            f"index out of bounds for axis {join_index} with size {n_parts}"
        )
    part_ids = np.where(part_ids < 0, part_ids + n_parts, part_ids)

    # Use the dtype that stacking the parts would produce, rather than
    # always returning floats.
    res_dtype = np.result_type(*(A_part.dtype for A_part in A_parts))
    res = np.empty(part_ids.shape, dtype=res_dtype)

    for m, A_part in enumerate(A_parts):
        # Locations in the result that read from part `m`
        m_0 = np.nonzero(part_ids == m)
        # The corresponding indices for all the non-join dimensions
        m_idx = tuple(v[m_0] for i, v in enumerate(bcast_indices) if i != join_index)
        # Read those entries directly from part `m`
        res[m_0] = A_part[m_idx]

    return res
def test_reorder_index():
    """Check `reorder_index` against indexing the explicitly stacked array."""
    parts = tuple(np.random.normal(size=(4, 3)) for _ in range(3))

    # Parts stacked along the leading axis.
    stacked = np.stack(parts)
    mixed_indices = (np.random.randint(2, size=(4, 3)), 1, 0)
    assert np.array_equal(reorder_index(parts, mixed_indices), stacked[mixed_indices])

    # The same indices with the parts stacked along the second axis.
    stacked = np.stack(parts, axis=1)
    assert np.array_equal(
        reorder_index(parts, mixed_indices, join_index=1), stacked[mixed_indices]
    )

    # A single advanced index; the join axis is covered by an implicit slice.
    lone_index = (np.random.randint(2, size=(4, 3)),)
    stacked = np.stack(parts, axis=1)
    assert np.array_equal(
        reorder_index(parts, lone_index, join_index=1), stacked[lone_index]
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment