== Some context ==
Notes:
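- numpy.take is an integer-array (gather) indexing operation along one axis, not a slice: np.take(arr, indices, axis=3) -> arr[:, :, :, indices, ...]; the result is a new array, and indices may repeat or be multi-dimensional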
- numpy.take can be implemented with numpy.take_along_axis by moving the target axis last, reshaping indices so that they broadcast against arr, and reshaping the result back
- torch.take -> limited numpy.take where the input is always flattened (like np.take with axis=None);
- torch.index_select -> limited numpy.take where the axis (dim) is always given and indices is a 1-d array;
- torch.take_along_dim -> numpy.take_along_axis;
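- torch.gather -> like numpy.take_along_axis, but with different shape rules: index must have the same number of dims as input, and for every d != dim only index.size(d) <= input.size(d) is required (no broadcasting; the output has the shape of index);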
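The NumPy-side equivalences in these notes can be checked with NumPy alone. This is a minimal sketch, assuming NumPy >= 1.15 (for np.take_along_axis); take_via_take_along_axis is a hypothetical helper name, not a NumPy API. It checks that np.take along an axis matches advanced indexing, that axis=None flattens the input (the torch.take behaviour), and rebuilds np.take from np.take_along_axis.

```python
import numpy as np


def take_via_take_along_axis(arr, indices, axis):
    """Hypothetical helper: reproduce np.take(arr, indices, axis=axis)
    with np.take_along_axis (assumes an integer axis and in-range indices)."""
    arr = np.asarray(arr)
    indices = np.asarray(indices)
    axis = axis % arr.ndim                      # normalise negative axes
    pre, post = arr.shape[:axis], arr.shape[axis + 1:]

    # Move the target axis last: shape becomes (*pre, *post, n).
    moved = np.moveaxis(arr, axis, -1)
    # Flatten indices to (k,) and prepend singleton dims so they broadcast
    # against every non-target dimension of `moved`.
    flat_idx = indices.reshape((1,) * (moved.ndim - 1) + (indices.size,))
    picked = np.take_along_axis(moved, flat_idx, axis=-1)  # (*pre, *post, k)

    # Put the gathered axis back in place, then unflatten it to indices.shape.
    picked = np.moveaxis(picked, -1, axis)                 # (*pre, k, *post)
    return picked.reshape(pre + indices.shape + post)


arr = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
idx = np.array([[4, 0], [1, 1]])  # 2-d indices, including a repeat

# np.take along an axis is integer-array indexing, not a slice:
# it returns a new array and accepts repeated / multi-dimensional indices.
taken = np.take(arr, idx, axis=3)
assert taken.shape == (2, 3, 4, 2, 2, 6)
assert np.array_equal(taken, arr[:, :, :, idx, ...])
assert not np.shares_memory(taken, arr)

# axis=None flattens `arr` first, which is what torch.take always does.
assert np.array_equal(np.take(arr, [0, 7]), arr.ravel()[[0, 7]])

# np.take rebuilt from np.take_along_axis.
assert np.array_equal(take_via_take_along_axis(arr, idx, 3), taken)
```

Moving the axis last keeps the reshape logic identical for every axis; the final reshape relies on C-order splitting of the flattened index axis back into indices.shape.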
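To make the torch.gather shape rules concrete without depending on PyTorch, the sketch below emulates the semantics given in the PyTorch documentation (for dim=0, out[i][j][k] = input[index[i][j][k]][j][k]; index has the same number of dims as input; index.size(d) <= input.size(d) for d != dim; output shaped like index) in NumPy. gather_like_torch is a hypothetical helper, not a real API, and it assumes integer, in-range indices.

```python
import numpy as np


def gather_like_torch(inp, dim, index):
    """Hypothetical NumPy emulation of torch.gather(inp, dim, index)."""
    inp = np.asarray(inp)
    index = np.asarray(index)
    if index.ndim != inp.ndim:
        raise ValueError("index must have the same number of dimensions as input")
    dim = dim % inp.ndim
    for d in range(inp.ndim):
        # No broadcasting: only index.shape[d] <= inp.shape[d] is required.
        if d != dim and index.shape[d] > inp.shape[d]:
            raise ValueError(f"index.shape[{d}] exceeds input.shape[{d}]")
    # Open grid of coordinates over index.shape; replace the `dim` axis with index.
    coords = list(np.indices(index.shape, sparse=True))
    coords[dim] = index
    return inp[tuple(coords)]  # result has shape index.shape


inp = np.arange(12).reshape(3, 4)

# Same size on the non-gathered axis: identical to np.take_along_axis.
full = np.array([[2, 0, 1, 2], [0, 0, 0, 0]])
assert np.array_equal(gather_like_torch(inp, 0, full),
                      np.take_along_axis(inp, full, axis=0))

# Smaller index on the non-gathered axis: the gather rules accept it,
# whereas np.take_along_axis cannot broadcast shape (1, 2) against (3, 4).
part = np.array([[2, 0]])
print(gather_like_torch(inp, 0, part))  # [[8 1]]
```

The sparse coordinate grid from np.indices broadcasts against index, so the result takes index's shape, matching the documented torch.gather output shape.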