@gngdb
Last active November 20, 2018 17:21
Is index_select followed by a reshape faster than plain indexing?
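import torch
from timeit import timeit

if __name__ == '__main__':
    X = torch.randn(100)
    out_shape = (100, 100)
    # 100x100 matrix of random indices into X; cast to int64 so it can be used as an index
    idxs = torch.randint(high=100, size=out_shape).long()
    # Sanity check: both approaches gather the same elements into the same layout
    assert torch.abs(X[idxs] - X.index_select(0, idxs.view(-1)).view(*out_shape)).max() < 1e-3

    # timeit runs in its own namespace, so the setup string rebuilds the same inputs
    setup = 'import torch; X = torch.randn(100); out_shape = (100, 100); idxs = torch.randint(high=100, size=out_shape).long()'
    # Advanced indexing with a 2-D index tensor
    print("X[idxs]: ", timeit("_ = X[idxs]", setup=setup, number=100))
    # Flatten the indices, gather along dim 0 with index_select, then reshape to out_shape
    print("X.index_select and reshape: ", timeit("_ = X.index_select(0, idxs.view(-1)).view(*out_shape)", setup=setup, number=100))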
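To make the two approaches concrete, here is a minimal sketch on a tiny tensor showing that flattening the index tensor, gathering along dim 0 with index_select, and reshaping back gives exactly the same result as X[idxs]. The helper name gather_via_index_select is hypothetical, not part of PyTorch; it assumes X is 1-D, as in the benchmark.

```python
import torch


def gather_via_index_select(X: torch.Tensor, idxs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same result as X[idxs] for a 1-D tensor X.

    Flattens idxs, gathers along dim 0 with index_select, then restores
    the index tensor's shape. reshape (rather than view) on idxs also
    accepts a non-contiguous index tensor, e.g. a transposed matrix.
    """
    flat = idxs.reshape(-1)
    # index_select returns a fresh contiguous 1-D tensor, so view is safe here
    return X.index_select(0, flat).view(idxs.shape)


X = torch.tensor([10.0, 20.0, 30.0, 40.0])
idxs = torch.tensor([[0, 3, 3],
                     [2, 1, 0]])

out = gather_via_index_select(X, idxs)
# values: [[10, 40, 40], [30, 20, 10]]

# Exact match with advanced indexing, including for a non-contiguous index tensor
assert torch.equal(out, X[idxs])
assert torch.equal(gather_via_index_select(X, idxs.t()), X[idxs.t()])
```

The output is an exact gather, so torch.equal is a stricter check than the tolerance-based assert in the benchmark.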

gngdb commented Nov 20, 2018

Results (total seconds for 100 calls of each statement, as returned by timeit):

X[idxs]:  0.021933638956397772
X.index_select and reshape:  0.00275177089497447
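These figures come from a single timeit call of 100 iterations each, so they can be noisy. A minimal sketch of a steadier measurement, assuming only the standard-library timeit module: pass the tensors in through globals instead of a setup string, repeat the measurement several times, and keep the fastest run, which is the approach the timeit documentation recommends. The function compare_gather_timings and its parameters are hypothetical.

```python
import timeit

import torch


def compare_gather_timings(n=100, shape=(100, 100), number=100, repeat=5):
    """Hypothetical helper: best-of-`repeat` time per call, in seconds, for
    plain advanced indexing vs. index_select followed by a reshape."""
    X = torch.randn(n)
    idxs = torch.randint(0, n, shape, dtype=torch.long)
    # Share the same tensors with both statements via the globals namespace
    env = {"X": X, "idxs": idxs, "shape": shape}

    stmts = {
        "X[idxs]": "X[idxs]",
        "index_select + view": "X.index_select(0, idxs.view(-1)).view(*shape)",
    }
    results = {}
    for name, stmt in stmts.items():
        # timeit.repeat returns `repeat` totals, each covering `number` calls;
        # the minimum is the estimate least affected by background noise.
        totals = timeit.repeat(stmt, globals=env, number=number, repeat=repeat)
        results[name] = min(totals) / number
    return results


if __name__ == "__main__":
    timings = compare_gather_timings()
    for name, per_call in timings.items():
        print(f"{name}: {per_call * 1e6:.1f} us per call")
    ratio = timings["X[idxs]"] / timings["index_select + view"]
    print(f"X[idxs] / index_select ratio: {ratio:.2f}")
```

Absolute timings depend on hardware, thread count, and PyTorch version, so the ratio is the more useful number to compare across machines.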
