-
-
Save fmassa/f8158d1dfd25a8047c2c668a44ff57f4 to your computer and use it in GitHub Desktop.
First version of advanced indexing in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from functools import reduce | |
from operator import mul | |
import numpy as np | |
def _linear_index(sizes, indices): | |
indices = [i.view(-1) for i in indices] | |
linear_idx = indices[0].new(indices[0].numel()).zero_() | |
stride = 1 | |
for i, idx in enumerate(indices[::-1], 1): | |
linear_idx += stride*idx | |
stride *= sizes[-i] | |
return linear_idx | |
def _expand_index(index, ndim):
    """Normalize an index tuple to a list of exactly ``ndim`` entries.

    A single ``Ellipsis`` is replaced by as many full slices as needed, and
    the list is right-padded with full slices. Entry ``d`` of the result
    then applies to dimension ``d`` of the tensor.

    Raises:
        NotImplementedError: on an entry that is not a ``torch.LongTensor``,
            an integer, a slice or ``Ellipsis`` (e.g. ``None``, booleans,
            Python lists).
        IndexError: on more than one ellipsis, or more entries than ``ndim``.
    """
    for el in index:
        is_int = isinstance(el, (int, np.integer)) and not isinstance(el, bool)
        if not (is_int or el is Ellipsis
                or isinstance(el, (slice, torch.LongTensor))):
            raise NotImplementedError(
                "unsupported index element of type {}".format(type(el).__name__))
    n_ellipsis = sum(1 for el in index if el is Ellipsis)
    if n_ellipsis > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")
    n_fill = ndim - (len(index) - n_ellipsis)
    if n_fill < 0:
        raise IndexError(
            "too many indices for tensor of dimension {}".format(ndim))
    expanded = []
    for el in index:
        if el is Ellipsis:
            expanded.extend([slice(None)] * n_fill)
        else:
            expanded.append(el)
    if not n_ellipsis:
        expanded.extend([slice(None)] * n_fill)
    return expanded


def advanced_indexing(tensor, index):
    """Index ``tensor`` following NumPy's advanced-indexing rules.

    The index is handed to ``tensor[index]`` unchanged unless it is a tuple
    containing at least two ``torch.LongTensor`` entries. In that case the
    tuple may mix LongTensors, Python/NumPy integers, slices and a single
    ``Ellipsis``. As in NumPy, integers count as advanced indices.

    The advanced entries (index tensors and integers) are broadcast against
    each other to a common shape ``B``. If they occupy consecutive
    dimensions, those dimensions are replaced, in place, by ``B``. Otherwise
    ``B`` comes first, followed by the remaining (sliced) dimensions in their
    original order.

    Args:
        tensor: the tensor to index.
        index: an index object (see above).

    Returns:
        The indexed tensor.

    Raises:
        IndexError: on an out-of-range index, too many indices, more than
            one ellipsis, or index tensors that cannot be broadcast together.
        NotImplementedError: on tuple entries other than those listed above
            (only on the multi-LongTensor path).
    """
    if not isinstance(index, tuple):
        return tensor[index]
    if sum(1 for el in index if isinstance(el, torch.LongTensor)) < 2:
        # Zero or one index tensor: defer to the tensor's own indexing.
        return tensor[index]

    index = _expand_index(index, tensor.dim())

    # Apply every slice first. Advanced entries keep their whole dimension
    # for now, so dimension d of the sliced tensor still matches index[d].
    tensor = tensor[tuple(el if isinstance(el, slice) else slice(None)
                          for el in index)]

    adv_dims = [d for d, el in enumerate(index) if not isinstance(el, slice)]
    rest_dims = [d for d in range(tensor.dim()) if d not in adv_dims]
    sizes = [tensor.size(d) for d in adv_dims]
    rest_sizes = [tensor.size(d) for d in rest_dims]

    # Integers become 0-d tensors so they broadcast against the index tensors.
    adv = [index[d] if isinstance(index[d], torch.LongTensor)
           else torch.tensor(int(index[d]), dtype=torch.long)
           for d in adv_dims]
    try:
        adv = torch.broadcast_tensors(*adv)
    except RuntimeError as err:
        raise IndexError(
            "shape mismatch: indexing tensors could not be broadcast together"
        ) from err
    broadcast_shape = list(adv[0].size())

    # Bounds-check and wrap negative indices, so the linear index below only
    # sees values in [0, size).
    normalized = []
    for d, idx, size in zip(adv_dims, adv, sizes):
        if idx.numel() and (int(idx.min()) < -size or int(idx.max()) >= size):
            raise IndexError(
                "index out of range for dimension {} with size {}".format(d, size))
        # contiguous(): broadcast_tensors returns expanded views, which a
        # later view(-1) cannot flatten.
        normalized.append((idx + (idx < 0).long() * size).contiguous())

    if adv_dims[-1] - adv_dims[0] + 1 == len(adv_dims):
        # Consecutive advanced dims: the broadcast block replaces them in place.
        start = adv_dims[0]
    else:
        # Non-consecutive: NumPy puts the broadcast block first, so move the
        # advanced dims to the front (remaining dims keep their order).
        tensor = tensor.permute(*(adv_dims + rest_dims))
        start = 0

    # Merge the advanced dims into one, so a single index_select gathers them.
    # contiguous(): slicing/permuting may have produced a non-viewable layout.
    flat_shape = rest_sizes[:start] + [reduce(mul, sizes, 1)] + rest_sizes[start:]
    tensor = tensor.contiguous().view(flat_shape)
    tensor = tensor.index_select(start, _linear_index(sizes, normalized))
    return tensor.view(rest_sizes[:start] + broadcast_shape + rest_sizes[start:])
def compare_numpy(t, idxs):
    """Check ``advanced_indexing(t, idxs)`` against NumPy on the same index.

    ``torch.LongTensor`` entries of the tuple ``idxs`` are converted to
    nested Python lists for NumPy. All other entries are passed through
    unchanged.

    Raises:
        AssertionError: if the result's shape or values differ from NumPy's.
            Raised explicitly, so the check still runs under ``python -O``.
    """
    result = advanced_indexing(t, idxs).numpy()
    np_idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else i
                    for i in idxs)
    expected = t.numpy()[np_idxs]
    # Check shapes first: np.allclose would broadcast mismatched shapes (or
    # fail with an unrelated error) before a shape check could run.
    if result.shape != expected.shape:
        raise AssertionError(
            "shape mismatch: got {}, NumPy gives {}".format(result.shape, expected.shape))
    # Indexing only copies elements, so values must match exactly.
    if not np.array_equal(result, expected):
        raise AssertionError("values differ from NumPy's result")
# Smoke tests: compare advanced_indexing against NumPy on several patterns.
torch.manual_seed(0)  # reproducible random index tensors

t = torch.rand(3, 3, 3)
idx1 = torch.LongTensor([0, 1])
idx2 = torch.LongTensor([1, 1])
compare_numpy(t, (idx1, slice(0, 3), idx2))  # non-adjacent index tensors
compare_numpy(t, (slice(0, 3), idx1, idx2))  # adjacent index tensors

t = torch.rand(10, 20, 30, 40, 50)
idx_dim = (2, 3, 4)
# random_(low, high) samples from [low, high - 1], so the upper bound is the
# dimension length itself; "size - 1" never produced the last valid index.
idx1 = torch.LongTensor(*idx_dim).random_(0, 20)
idx2 = torch.LongTensor(*idx_dim).random_(0, 30)
compare_numpy(t, (slice(0, None), idx1, idx2))
idx2 = torch.LongTensor(*idx_dim).random_(0, 40)
compare_numpy(t, (slice(0, None), idx1, slice(0, None), idx2))
idx3 = torch.LongTensor(*idx_dim).random_(0, 50)
compare_numpy(t, (slice(0, None), idx1, slice(0, None), idx2, idx3))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@fmassa what's the status of this feature in master?