
@cjlovering
Created December 8, 2020 17:40
Index into a 3D tensor (PyTorch).

This lets you index into a 3D tensor of shape (batch_size, seq_len, embed_dim) and select one vector per batch element. For example, say you want the final output of each sequence in a batch of variable-length sequences that were padded together into one tensor.
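To see what the indexing does, here is a small sketch on a toy tensor (the sizes and the chosen time steps are made up for illustration). Pairing torch.arange(batch_size) with an index tensor picks output[i, indices[i], :] for each row i, which a plain Python loop confirms.

```python
import torch

# Toy batch: 3 sequences, 4 time steps, 2-dim embeddings (illustrative sizes).
batch_size, seq_len, embed_dim = 3, 4, 2
output = torch.arange(batch_size * seq_len * embed_dim, dtype=torch.float32).reshape(
    batch_size, seq_len, embed_dim
)

# One time step per batch element (hypothetical choice).
indices = torch.tensor([0, 3, 1])

# The two index tensors are paired elementwise: row i uses time step indices[i].
selected = output[torch.arange(batch_size), indices, ...]

# Same result with an explicit loop.
expected = torch.stack([output[i, indices[i]] for i in range(batch_size)])

print(selected.shape)                   # torch.Size([3, 2])
print(torch.equal(selected, expected))  # True
```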

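Set the indices to each sequence's length minus one (e.g. indices = lengths - 1), because positions are zero-based. Using the raw lengths would select the first padding position, or run past the end of a sequence that fills the whole seq_len.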
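A sketch of building the indices, assuming a right-padded batch described by a hypothetical 0/1 padding mask (the kind often passed to transformer models) and random stand-in outputs. Summing the mask gives each length, and subtracting one gives the zero-based position of the last real token. The last few lines show that indexing with the raw lengths fails for the full-length row.

```python
import torch

# Hypothetical right-padded batch: 1 marks a real token, 0 marks padding.
mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
    [1, 1, 0, 0, 0],
])
batch_size, seq_len = mask.shape
embed_dim = 4
output = torch.randn(batch_size, seq_len, embed_dim)  # stand-in for model outputs

lengths = mask.sum(dim=1)  # int64 tensor of per-sequence lengths: [3, 5, 2]
indices = lengths - 1      # zero-based position of each sequence's last real token

last = output[torch.arange(batch_size), indices, ...]
print(indices.tolist())  # [2, 4, 1]
print(last.shape)        # torch.Size([3, 4])

# Indexing with the raw lengths points one step past the end;
# for the full-length sequence, position 5 does not exist.
try:
    output[torch.arange(batch_size), lengths, ...]
except IndexError as err:
    print("IndexError:", err)
```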

(Note: normally you can use PyTorch's pack/unpack utilities, pack_padded_sequence and pad_packed_sequence, but they only work with PyTorch's own RNN modules and, as of this writing, not with transformers.)
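For comparison, here is a sketch of the RNN route the note refers to, assuming a single-layer, unidirectional nn.GRU with arbitrary sizes and random inputs. With a packed input, the final hidden state h_n already holds each sequence's last output. It matches what the indexing trick picks from the unpacked outputs.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
batch_size, seq_len, input_dim, hidden_dim = 3, 5, 4, 6
x = torch.randn(batch_size, seq_len, input_dim)  # padded input batch
lengths = torch.tensor([3, 5, 2])                 # true length of each sequence

gru = nn.GRU(input_dim, hidden_dim, batch_first=True)  # single layer, unidirectional

# Pack so the GRU skips padding; enforce_sorted=False allows unsorted lengths.
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, h_n = gru(packed)
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# output: (batch_size, max(lengths), hidden_dim); padding positions are zero.
last = output[torch.arange(batch_size), lengths - 1, ...]

# h_n: (num_layers, batch_size, hidden_dim), restored to the original batch order.
print(torch.allclose(last, h_n[-1]))  # True
```

The packed route needs an RNN module that accepts a PackedSequence. The indexing trick only needs the padded output tensor and the lengths, which is why it also works for transformer outputs.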

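import torch

# output:  (batch_size, seq_len, embed_dim) tensor
# indices: (batch_size,) LongTensor of zero-based time steps, e.g. lengths - 1
batch_size, seq_len, embed_dim = output.size()
selected = output[
    torch.arange(batch_size, device=output.device),  # row i of the batch...
    indices,                                         # ...paired with time step indices[i]
    ...,                                             # keep the whole embedding vector
]
# selected: (batch_size, embed_dim)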
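An equivalent formulation, sketched here with torch.gather (a swapped-in technique, not the one used above), expands the indices across the embedding dimension and gathers along the time axis. The helper names select_by_index and select_by_gather are hypothetical, and the tensor sizes and lengths are illustrative.

```python
import torch


def select_by_index(output: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Advanced indexing: pick output[i, indices[i], :] for every batch row i."""
    batch_size = output.size(0)
    return output[torch.arange(batch_size, device=output.device), indices, ...]


def select_by_gather(output: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Same selection via torch.gather along the time dimension (dim=1)."""
    embed_dim = output.size(2)
    # (batch_size,) -> (batch_size, 1, embed_dim) so gather copies whole vectors.
    index = indices.view(-1, 1, 1).expand(-1, 1, embed_dim)
    return output.gather(1, index).squeeze(1)


torch.manual_seed(0)
output = torch.randn(4, 7, 3)         # (batch_size, seq_len, embed_dim)
lengths = torch.tensor([7, 1, 4, 5])  # illustrative sequence lengths
indices = lengths - 1                 # zero-based last positions

a = select_by_index(output, indices)
b = select_by_gather(output, indices)
print(a.shape)            # torch.Size([4, 3])
print(torch.equal(a, b))  # True
```

Both give the same result. The indexing version is shorter. The gather version makes the per-row index shape explicit, which some readers may find easier to check.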