Skip to content

Instantly share code, notes, and snippets.

@EricCousineau-TRI
Created December 3, 2021 18:14
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save EricCousineau-TRI/cc2dc27c7413ea8e5b4fd9675050b1c0 to your computer and use it in GitHub Desktop.
Save EricCousineau-TRI/cc2dc27c7413ea8e5b4fd9675050b1c0 to your computer and use it in GitHub Desktop.
import einops
import torch
def vector_gather(vectors, indices):
    """
    Gathers (batched) vectors according to indices.

    For each batch element ``n``, picks rows of ``vectors[n]`` (along dim 1)
    at the positions listed in ``indices[n]``.

    Arguments:
        vectors: Tensor[N, L, D]
        indices: Tensor[N, K] or Tensor[N]. Positions along the L axis; this
            is passed to ``torch.gather`` as its index, so it must be an
            integer (int64) tensor.
    Returns:
        Tensor[N, K, D] or Tensor[N, D]
    Raises:
        ValueError: If ``vectors`` is not 3-D, ``indices`` is not 1-D or 2-D,
            or the batch sizes of the two tensors differ.
    """
    if vectors.ndim != 3:
        raise ValueError(
            f"vectors must have shape [N, L, D], got {tuple(vectors.shape)}"
        )
    N, _, D = vectors.shape

    # A 1-D index selects a single vector per batch element; treat it as
    # K == 1 and drop that axis again before returning.
    squeeze = indices.ndim == 1
    if squeeze:
        indices = indices.unsqueeze(-1)
    if indices.ndim != 2:
        raise ValueError(
            f"indices must have shape [N] or [N, K], got {tuple(indices.shape)}"
        )

    # Explicit check rather than `assert`: torch.gather accepts an index whose
    # batch dimension is smaller than the input's, so a mismatch could
    # otherwise go unnoticed when assertions are disabled (python -O).
    if indices.shape[0] != N:
        raise ValueError(
            f"batch size mismatch: vectors has N={N}, "
            f"indices has N={indices.shape[0]}"
        )

    # torch.gather needs an index with the same number of dims as the input:
    # broadcast each position across the D axis, [N, K] -> [N, K, D].
    index = indices.unsqueeze(-1).expand(-1, -1, D)
    out = torch.gather(vectors, dim=1, index=index)
    if squeeze:
        out = out.squeeze(1)
    return out
import unittest
import einops
import torch
class Test(unittest.TestCase):
    """Unit tests for ``vector_gather``."""

    def test_vector_gather(self):
        """Checks 1-D and 2-D index inputs against hand-computed results."""
        # NOTE(review): `mut` is not defined in the visible source; it is
        # presumably the module under test that provides `vector_gather`.
        # Confirm the corresponding import exists in the real test file.
        vector = torch.tensor([
            [
                [1, 2, 3],
                [4, 5, 6],
                [7, 8, 9],
            ],
            [
                [10, 11, 12],
                [13, 14, 15],
                [16, 17, 18],
            ],
        ])

        # Case 1: Squeezed (for argmin).
        indices = torch.tensor([1, 2])
        expected = torch.tensor([
            [4, 5, 6],
            [16, 17, 18],
        ])
        # assert_close replaces the deprecated assert_allclose; zero
        # tolerances keep the comparison exact.
        torch.testing.assert_close(
            mut.vector_gather(vector, indices),
            expected,
            atol=0,
            rtol=0,
        )

        # Case 2: Unsqueezed (for multinomial, etc).
        indices = torch.tensor([[1, 0], [2, 1]])
        expected = torch.tensor([
            [
                [4, 5, 6],
                [1, 2, 3],
            ],
            [
                [16, 17, 18],
                [13, 14, 15],
            ],
        ])
        torch.testing.assert_close(
            mut.vector_gather(vector, indices),
            expected,
            atol=0,
            rtol=0,
        )
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment