@killeent
Last active July 19, 2017 21:44

PyTorch now supports a subset of NumPy-style advanced indexing. Users can select arbitrary indices along each dimension of a Tensor, including non-adjacent and duplicate indices, with the same []-style operation. This makes indexing more flexible, without needing calls to PyTorch's index_* functions (index_select, index_add_, ...).
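As a rough illustration of the claim above, the sketch below selects non-adjacent, duplicated indices along one dimension both with index_select and with []-style indexing, and checks that they agree. It assumes a recent PyTorch release (torch.arange, torch.tensor and torch.equal are used here and may postdate this note), and it fills x with arange values instead of leaving it uninitialized so the results are checkable.

```python
import torch

# Deterministic contents (x[i, j, k] == 25*i + 5*j + k) instead of the
# uninitialized torch.Tensor(5, 5, 5), so selections are easy to check.
x = torch.arange(125).reshape(5, 5, 5)

# Non-adjacent, duplicated indices along dim 1, chosen with index_select...
idx = torch.tensor([4, 0, 4])
via_index_select = x.index_select(1, idx)

# ...and the same selection written with []-style indexing.
via_indexing = x[:, [4, 0, 4], :]

assert torch.equal(via_index_select, via_indexing)
print(via_indexing.shape)  # torch.Size([5, 3, 5])
```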

import torch

x = torch.Tensor(5, 5, 5)  # uninitialized 5x5x5 float tensor

# Pure Integer Array Indexing - specify arbitrary indices at each dim
x[[1, 2], [3, 2], [1, 0]] 
--> yields a 2-element Tensor (x[1][3][1], x[2][2][0])
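A runnable check of the pure integer array indexing example, assuming a recent PyTorch. The index lists are paired up element-wise, one list per dimension; x is filled with arange values (an assumption for checkability) so the picked elements have known values.

```python
import torch

x = torch.arange(125).reshape(5, 5, 5)  # x[i, j, k] == 25*i + 5*j + k

# One index list per dimension; the lists are zipped element-wise,
# so this picks x[1, 3, 1] and x[2, 2, 0].
picked = x[[1, 2], [3, 2], [1, 0]]

# The same two elements gathered by hand.
expected = torch.stack([x[1, 3, 1], x[2, 2, 0]])
assert torch.equal(picked, expected)
print(picked.tolist())  # [41, 60]
```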

# index lists are broadcast against each other, and duplicate indices are allowed
x[[2, 3, 2], [0], [1]]
--> yields a 3-element Tensor (x[2][0][1], x[3][0][1], x[2][0][1])
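To make the broadcasting step explicit, this sketch (assuming a recent PyTorch, with x filled by arange for checkable values) compares the broadcast form against the same index lists written out at full length.

```python
import torch

x = torch.arange(125).reshape(5, 5, 5)  # x[i, j, k] == 25*i + 5*j + k

# Index lists of shape (3,), (1,) and (1,) broadcast to a common shape (3,),
# so [0] and [1] are reused for every entry of [2, 3, 2].
picked = x[[2, 3, 2], [0], [1]]

# Spelling out the broadcast explicitly gives the same result.
explicit = x[[2, 3, 2], [0, 0, 0], [1, 1, 1]]
assert torch.equal(picked, explicit)
print(picked.tolist())  # [51, 76, 51]
```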

# indexers can have any shape; the result takes their broadcast shape
x[[[1, 0], [0, 1]], [0], [1]]
--> yields a 2x2 Tensor [[x[1][0][1], x[0][0][1]],
                         [x[0][0][1], x[1][0][1]]]
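A small sketch of how the indexer shape carries over to the result, assuming a recent PyTorch and an arange-filled x. The (2, 2) first indexer sets the output shape, and the other two indexers broadcast against it.

```python
import torch

x = torch.arange(125).reshape(5, 5, 5)  # x[i, j, k] == 25*i + 5*j + k

# The first indexer has shape (2, 2); [0] and [1] broadcast against it,
# so the result has shape (2, 2) as well.
picked = x[[[1, 0], [0, 1]], [0], [1]]

# Element [r][c] of the result is x[first[r][c], 0, 1];
# here x[1, 0, 1] == 26 and x[0, 0, 1] == 1.
expected = torch.tensor([[26, 1],
                         [1, 26]])
assert torch.equal(picked, expected)
print(picked.shape)  # torch.Size([2, 2])
```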

# can be combined with colons (:) and ellipsis (...)
x[[0, 3], :, :]
x[[0, 3], ...]
--> both yield a 2x5x5 Tensor [x[0], x[3]]
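The following sketch, assuming a recent PyTorch and an arange-filled x, confirms that the colon and ellipsis forms agree. It also shows, as an extra illustration, that an index list on a middle dimension keeps that dimension in place.

```python
import torch

x = torch.arange(125).reshape(5, 5, 5)

# An index list on dim 0 combined with full slices on the remaining dims.
a = x[[0, 3], :, :]
# Ellipsis stands in for "all remaining dimensions".
b = x[[0, 3], ...]

assert torch.equal(a, b)
assert torch.equal(a, torch.stack([x[0], x[3]]))
print(a.shape)  # torch.Size([2, 5, 5])

# With a single index list on a middle dim, the selected dim stays in place.
c = x[:, [0, 3], :]
print(c.shape)  # torch.Size([5, 2, 5])
```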

# LongTensors can also be used as indices
y = torch.LongTensor([0, 2, 4])
x[y, :, :]
--> yields a 3x5x5 Tensor [x[0], x[2], x[4]]
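A sketch of tensor indices, assuming a recent PyTorch where torch.tensor of Python ints yields an int64 tensor (the same dtype as torch.LongTensor). Indexing several dimensions with the same tensor is an extra illustration, not taken from the note.

```python
import torch

x = torch.arange(125).reshape(5, 5, 5)  # x[i, j, k] == 25*i + 5*j + k

# torch.tensor of Python ints defaults to int64, the same dtype as the
# torch.LongTensor([0, 2, 4]) used above.
y = torch.tensor([0, 2, 4])
assert y.dtype == torch.long

picked = x[y, :, :]
assert torch.equal(picked, torch.stack([x[0], x[2], x[4]]))

# Index tensors can be used on several dimensions at once; here they
# pick x[0, 0, 0], x[2, 2, 2] and x[4, 4, 4].
diag = x[y, y, y]
print(diag.tolist())  # [0, 62, 124]
```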

# selection with fewer indices than dims; note the trailing comma, which makes the index a one-element tuple rather than a bare list
x[[1, 3], ]
--> yields a 2x5x5 Tensor [x[1], x[3]]
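A runnable sketch of the trailing-comma form, assuming a recent PyTorch and an arange-filled x. In Python, x[[1, 3], ] is the same subscript as x[([1, 3],)], and any dimension without an index is taken whole.

```python
import torch

x = torch.arange(125).reshape(5, 5, 5)

# In Python, x[[1, 3], ] is the same as x[([1, 3],)]: a one-element tuple.
picked = x[[1, 3], ]
assert torch.equal(picked, x[([1, 3],)])

# Dimensions that receive no index are kept whole, as if ":" were written.
assert torch.equal(picked, x[[1, 3], :, :])
print(picked.shape)  # torch.Size([2, 5, 5])
```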