@jcjohnson
Created February 14, 2017 02:51
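import torch


def index(x, axis, idxs):
    """
    Select slices of x along one axis, in the order given by idxs.

    Inputs:
    - x: torch.Tensor with x.dim() == N
    - axis: Integer with 0 <= axis < N
    - idxs: List of integers, with 0 <= idxs[i] < x.size(axis)

    Returns:
    - y: torch.Tensor satisfying
      y.select(axis, i) == x.select(axis, idxs[i])
    """
    # Shape that is 1 everywhere except along `axis`, so idxs lines up with it.
    view_size = [1] * x.dim()
    view_size[axis] = len(idxs)
    view_size = torch.Size(view_size)
    # Shape of the output: same as x, but with len(idxs) entries along `axis`.
    expand_size = list(x.size())
    expand_size[axis] = len(idxs)
    expand_size = torch.Size(expand_size)
    # Broadcast the indices to the output shape; gather needs one index per element.
    idxs = torch.LongTensor(idxs).view(view_size).expand(expand_size)
    return x.gather(axis, idxs)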
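
A minimal usage sketch, assuming a recent PyTorch release. It repeats the function in compact form so the snippet runs on its own, then compares the result against the built-in `torch.index_select`, which should give the same output for valid indices. The sample tensor and index list are made up for illustration and are not part of the original gist.

```python
import torch


def index(x, axis, idxs):
    # Same gather-based approach as the function above, condensed.
    view_size = [1] * x.dim()
    view_size[axis] = len(idxs)
    expand_size = list(x.size())
    expand_size[axis] = len(idxs)
    idx = torch.LongTensor(idxs).view(view_size).expand(expand_size)
    return x.gather(axis, idx)


# Illustrative input: a 3x4 matrix holding 0..11 in row-major order.
x = torch.arange(12).view(3, 4)

# Pick columns 3, 0, 0 (repeats are allowed).
y = index(x, 1, [3, 0, 0])

# Built-in equivalent, used here only as a cross-check.
expected = torch.index_select(x, 1, torch.tensor([3, 0, 0]))
same = torch.equal(y, expected)
```

If the cross-check holds in your environment, `torch.index_select` is likely the simpler choice in new code; the gather-based version is mainly useful for seeing how the index tensor is broadcast.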