Skip to content

Instantly share code, notes, and snippets.

@yaroslavvb
Created January 24, 2020 00:35
Show Gist options
  • Save yaroslavvb/3335705e63121f3f6e892a4a21bf8a6b to your computer and use it in GitHub Desktop.
Save yaroslavvb/3335705e63121f3f6e892a4a21bf8a6b to your computer and use it in GitHub Desktop.
index_reduce
def index_reduce(values: "torch.Tensor", indices: "torch.Tensor", dim: int) -> "torch.Tensor":
    """Reduce ``values`` by selecting a single element along dimension ``dim``.

    For a rank-3 ``values`` tensor the result is a rank-2 tensor with entries:

        dim=0: values[indices[i, j], i, j]
        dim=1: values[i, indices[i, j], j]
        dim=2: values[i, j, indices[i, j]]

    When every entry of ``indices`` equals ``p``, the result is equivalent to
    slicing along that dimension:

        dim=0: values[p, :, :]
        dim=1: values[:, p, :]
        dim=2: values[:, :, p]

    Args:
        values: tensor to select from.
        indices: integer (``long``) tensor whose shape is the shape of
            ``values`` with dimension ``dim`` removed.
        dim: dimension of ``values`` to reduce; negative values count from
            the end.

    Returns:
        Tensor with the same shape as ``indices`` holding the selected
        elements of ``values``.

    Raises:
        ValueError: if the shape of ``indices`` does not equal the shape of
            ``values`` with dimension ``dim`` removed.
        IndexError: if ``dim`` is out of range for ``values``.
    """
    values_shape = list(values.shape)
    indices_shape = list(indices.shape)
    if len(indices_shape) != len(values_shape) - 1:
        raise ValueError(
            f"indices must have rank {len(values_shape) - 1} "
            f"(one less than values), got rank {len(indices_shape)}"
        )

    expected_shape = list(values_shape)
    del expected_shape[dim]  # raises IndexError when dim is out of range

    # Compare the shapes themselves, not just the element counts: a transposed
    # ``indices`` has the right number of elements but the wrong layout.
    if indices_shape != expected_shape:
        raise ValueError(
            f"indices has shape {tuple(indices_shape)}, expected "
            f"{tuple(expected_shape)} (values shape {tuple(values_shape)} "
            f"with dim {dim} removed)"
        )

    # gather needs an index tensor of the same rank as values, so insert a
    # size-1 axis at ``dim`` and drop it again from the result.
    gathered = torch.gather(values, dim, indices.unsqueeze(dim))
    return gathered.squeeze(dim)
def test_index_reduce():
    """Check index_reduce against plain slicing and per-element indexing."""
    # Distinct sizes per axis so that confusing two dimensions cannot pass.
    values = torch.arange(0, 24).reshape(2, 3, 4)

    # Constant indices must reproduce slicing, for every position on each axis.
    for dim, size in enumerate(values.shape):
        reduced_shape = [s for axis, s in enumerate(values.shape) if axis != dim]
        for pos in range(size):
            indices = pos * torch.ones(reduced_shape).long()
            expected = values.select(dim, pos)
            # Integer tensors: exact comparison rather than a tolerance check.
            assert torch.equal(index_reduce(values, indices, dim), expected)

    # Non-constant indices: each output entry follows its own index.
    indices = torch.tensor([[0, 1, 0, 1], [1, 0, 1, 0], [0, 0, 1, 1]])
    result = index_reduce(values, indices, 0)
    assert result.shape == indices.shape
    for i in range(indices.shape[0]):
        for j in range(indices.shape[1]):
            assert result[i, j] == values[indices[i, j], i, j]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment