This lets you index into a 3D tensor of shape (batch_size, seq_len, embed_dim) and select one vector per batch element. For example, say you want the final output of each sequence when sequences of different lengths have been padded and batched together.
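Set the indices to the last valid position of each sequence in the batch, i.e. its length minus one, since positions are zero-based. Indexing with the raw lengths would be off by one and would run out of bounds for the longest sequence.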
(Note: Normally you can use pack_padded_sequence/pad_packed_sequence in PyTorch, but these only work with PyTorch's own RNN implementations and do not yet work with transformers.)
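import torch

# Assumes `output` is (batch_size, seq_len, embed_dim) and `indices` is a (batch_size,) LongTensor.
batch_size, seq_len, embed_dim = output.size()
selected = output[
    torch.arange(batch_size, device=output.device),  # one row index per batch element
    indices,  # time step to pick in each row
    ...,  # keep the whole embedding dimension
]
# selected has shape (batch_size, embed_dim)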
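To check the indexing on a small case, here is a minimal self-contained sketch. The helper name `select_last_valid` and the toy tensors are made up for illustration. It assumes `lengths` holds the number of valid (non-padded) steps in each sequence, each at least 1, so the last valid position is `lengths - 1`.

```python
import torch


def select_last_valid(output: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Pick the vector at the last non-padded time step of each sequence.

    output:  (batch_size, seq_len, embed_dim), padded along seq_len
    lengths: (batch_size,) count of valid steps per sequence (assumed >= 1)
    """
    batch_size = output.size(0)
    # Positions are zero-based, so the last valid step sits at length - 1.
    indices = lengths.to(output.device).long() - 1
    # Pair row i with time step indices[i]; the embedding dimension is kept whole.
    return output[torch.arange(batch_size, device=output.device), indices]


# Toy batch: 3 sequences padded to length 4, embedding size 2.
output = torch.arange(3 * 4 * 2, dtype=torch.float32).reshape(3, 4, 2)
lengths = torch.tensor([4, 2, 1])

selected = select_last_valid(output, lengths)
print(selected.shape)  # torch.Size([3, 2])
print(selected)        # rows taken from output[0, 3], output[1, 1], output[2, 0]
```

Because the two index tensors are paired element by element, the result has one vector per batch element rather than a full batch_size x batch_size grid of selections.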