PyTorch now supports a subset of NumPy-style advanced indexing. Users can select arbitrary indices along each dimension of a Tensor, including non-adjacent and duplicate indices, using the same []-style operation. This makes indexing more flexible without requiring calls to PyTorch's Index[Select, Add, ...] functions.
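As a rough illustration of the point above, the sketch below compares bracket-style indexing with an explicit `torch.index_select` call along the first dimension. It assumes a recent PyTorch release; the tensor contents and the names `x`, `idx`, `via_brackets`, and `via_function` are chosen here for illustration and do not come from the release notes.

```python
import torch

# Illustrative tensor with known contents (an assumption for this sketch).
x = torch.arange(125).reshape(5, 5, 5)

# Non-adjacent and duplicate indices along dimension 0.
idx = torch.tensor([0, 3, 3])

via_brackets = x[idx]                         # advanced indexing with []
via_function = torch.index_select(x, 0, idx)  # explicit function call

# Both approaches select the same 3x5x5 result.
print(torch.equal(via_brackets, via_function))
```

The bracket form reads the same whether the indexer is a list or a tensor, which is the convenience the paragraph above describes.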
x = torch.Tensor(5, 5, 5)
# Pure Integer Array Indexing - specify arbitrary indices at each dim
x[[1, 2], [3, 2], [1, 0]]
--> yields a 2-element Tensor (x[1][3][1], x[2][2][0])
# also supports broadcasting, duplicates
x[[2, 3, 2], [0], [1]]
--> yields a 3-element Tensor (x[2][0][1], x[3][0][1], x[2][0][1])
# arbitrary indexer shapes allowed
x[[[1, 0], [0, 1]], [0], [1]]
--> yields a 2x2 Tensor [[x[1][0][1], x[0][0][1]],
                         [x[0][0][1], x[1][0][1]]]
# can use colon, ellipsis
x[[0, 3], :, :]
x[[0, 3], ...]
--> both yield a 2x5x5 Tensor [x[0], x[3]]
# also use Tensors to index!
y = torch.LongTensor([0, 2, 4])
x[y, :, :]
--> yields a 3x5x5 Tensor [x[0], x[2], x[4]]
# selection with fewer indexers than ndim, note the trailing comma to differentiate
x[[1, 3], ]
--> yields a 2x5x5 Tensor [x[1], x[3]]
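
The examples above can be checked with a small runnable sketch. It assumes a recent PyTorch release and fills `x` with known values via `torch.arange` (an illustrative choice, since `torch.Tensor(5, 5, 5)` leaves the contents uninitialized); the variable names `a` through `f` are likewise introduced only for this sketch.

```python
import torch

# Known contents so each selected element can be checked (illustrative choice).
x = torch.arange(125).reshape(5, 5, 5)

# Pure integer array indexing: one index list per dimension.
a = x[[1, 2], [3, 2], [1, 0]]
assert torch.equal(a, torch.stack([x[1][3][1], x[2][2][0]]))

# Broadcasting and duplicates: [0] and [1] are broadcast against [2, 3, 2].
b = x[[2, 3, 2], [0], [1]]
assert torch.equal(b, torch.stack([x[2][0][1], x[3][0][1], x[2][0][1]]))

# The result takes the broadcast shape of the indexers, here 2x2.
c = x[[[1, 0], [0, 1]], [0], [1]]
assert c.shape == (2, 2)

# Colon and ellipsis leave the remaining dimensions untouched.
d = x[[0, 3], :, :]
e = x[[0, 3], ...]
assert torch.equal(d, e) and d.shape == (2, 5, 5)

# A tensor of integer indices can be used as an indexer.
y = torch.tensor([0, 2, 4])
assert torch.equal(x[y, :, :], torch.stack([x[0], x[2], x[4]]))

# Fewer indexers than dimensions, written with a trailing comma.
f = x[[1, 3], ]
assert torch.equal(f, torch.stack([x[1], x[3]]))

print("all checks passed")
```

Each assertion mirrors one `-->` line in the listing, so a failure would point at the specific indexing form whose behavior differs in the installed version.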