== Some context ==
Notes:
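- numpy.take is integer-array (advanced) indexing along one axis, not a true slice, so it returns a copy: np.take(arr, indices, axis=3) -> arr[:, :, :, indices, ...]; with axis=None (the default) arr is flattened first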
- numpy.take can be implemented with numpy.take_along_axis via some reshaping & broadcasting (move the axis to the end, flatten the indices, broadcast them over the other dims)
- torch.take -> limited numpy.take where `a` is always flattened (numpy.take with axis=None); the output has the shape of the indices;
- torch.index_select -> limited numpy.take where the axis (`dim`) must always be given and indices must be a 1-d tensor;
- torch.take_along_dim -> numpy.take_along_axis;
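- torch.gather -> relaxed numpy.take_along_axis: index must have the same number of dims as input, but on every dim other than `dim` it only needs index.size(d) <= input.size(d) instead of matching; the output has index's shape, and unlike take_along_axis it never broadcasts;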
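A minimal NumPy sketch of the first two notes: it checks that np.take along one axis matches integer-array indexing, and rebuilds np.take from np.take_along_axis by moving the axis to the end, flattening the indices and letting broadcasting cover the other dims. The helper name `take_via_take_along_axis` is hypothetical, and it only handles an integer axis with the default mode="raise".

```python
import numpy as np


def take_via_take_along_axis(arr, indices, axis):
    """Reproduce np.take(arr, indices, axis=axis) using np.take_along_axis.

    Sketch only: integer `axis`, default mode="raise" semantics (no wrap/clip).
    """
    arr = np.asarray(arr)
    indices = np.asarray(indices)
    axis = axis % arr.ndim                          # allow a negative axis
    moved = np.moveaxis(arr, axis, -1)              # (*rest, n): target axis last
    rest = moved.shape[:-1]
    # Flatten the indices and give them size-1 leading dims;
    # take_along_axis broadcasts those against `rest`.
    flat = indices.reshape((1,) * len(rest) + (-1,))
    out = np.take_along_axis(moved, flat, axis=-1)  # (*rest, indices.size)
    # Unflatten the index dims, then move them back to position `axis`.
    out = out.reshape(rest + indices.shape)
    src = list(range(len(rest), out.ndim))
    dst = list(range(axis, axis + indices.ndim))
    return np.moveaxis(out, src, dst)


rng = np.random.default_rng(0)
arr = rng.standard_normal((2, 3, 4, 5))
idx = np.array([[0, 4], [2, 2], [1, 3]])            # values must be < arr.shape[3]

ref = np.take(arr, idx, axis=3)                     # shape (2, 3, 4, 3, 2)
# Same result as integer-array indexing on axis 3 (a copy, not a view).
assert np.array_equal(ref, arr[:, :, :, idx])
assert np.array_equal(ref, take_via_take_along_axis(arr, idx, axis=3))
```

The moveaxis round trip is one design choice among several; it keeps the reshape logic identical for every axis.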
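The torch.take and torch.index_select notes can be checked against NumPy with a short sketch (it assumes PyTorch and NumPy are both installed; the tensors are arbitrary examples):

```python
import numpy as np
import torch

x = torch.arange(24).reshape(2, 3, 4)

# torch.take: input treated as 1-D (row-major); the result has the index's shape.
flat_idx = torch.tensor([[0, 5], [23, 7]])
t = torch.take(x, flat_idx)
assert t.shape == flat_idx.shape
assert np.array_equal(t.numpy(), np.take(x.numpy(), flat_idx.numpy()))  # axis=None

# torch.index_select: explicit dim, 1-D index only.
idx1d = torch.tensor([3, 0, 3])
s = torch.index_select(x, 2, idx1d)                                    # shape (2, 3, 3)
assert np.array_equal(s.numpy(), np.take(x.numpy(), idx1d.numpy(), axis=2))
```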
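A sketch contrasting torch.take_along_dim and torch.gather on the same index (again assuming PyTorch and NumPy are installed): take_along_dim broadcasts a (1, 2) index over all rows exactly like np.take_along_axis, while gather accepts the smaller index without broadcasting and simply reads the leading row.

```python
import numpy as np
import torch

x = torch.arange(12).reshape(3, 4)   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
idx = torch.tensor([[0, 2]])          # shape (1, 2)

# take_along_dim == np.take_along_axis: the (1, 2) index broadcasts over the 3 rows.
t = torch.take_along_dim(x, idx, dim=1)                          # shape (3, 2)
assert np.array_equal(t.numpy(), np.take_along_axis(x.numpy(), idx.numpy(), axis=1))

# gather: same ndim required, out.shape == index.shape, and on dims other than
# `dim` only index.size(d) <= input.size(d) is needed. No broadcasting, so only
# the leading row of x is read here.
g = torch.gather(x, 1, idx)                                      # shape (1, 2)
assert torch.equal(g, torch.take_along_dim(x[:1], idx, dim=1))
```

In other words, gather behaves like take_along_axis applied to the leading block of the input whose non-`dim` sizes equal the index's.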