Skip to content

Instantly share code, notes, and snippets.

@d4l3k
Created February 17, 2023 18:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save d4l3k/44548e97ee11153c4be026f62c62d38e to your computer and use it in GitHub Desktop.
Save d4l3k/44548e97ee11153c4be026f62c62d38e to your computer and use it in GitHub Desktop.
A PyTorch implementation of tf.gather_nd (torch_gather_nd) that supports multiple batch dimensions and multiple channel dimensions.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
import unittest
from math import prod

import torch
class TestTorchGatherND(unittest.TestCase):
    """Unit tests for ``torch_gather_nd``."""

    def test_torch_gather_nd(self) -> None:
        """Exercise plain, batched and multi-channel gathers.

        Each scenario runs in its own ``subTest`` so that a failure in one
        scenario does not prevent the remaining ones from being reported.
        """
        with self.subTest(msg="gather 1 dim"):
            params = torch.arange(10)
            indices = torch.tensor([[0], [2], [3]])
            out = torch_gather_nd(params, indices, batch_dim=0)
            self.assertEqual(out.tolist(), [0, 2, 3])

        with self.subTest(msg="gather 2 dims"):
            params = torch.tensor([[1, 2], [3, 4]])
            indices = torch.tensor([[1, 1], [0, 0], [0, 1]])
            out = torch_gather_nd(params, indices, batch_dim=0)
            self.assertEqual(out.tolist(), [4, 1, 2])

        with self.subTest(msg="gather 1 dim, 1 ch dim"):
            params = torch.tensor([[1, 2], [2, 3], [3, 4]])
            indices = torch.tensor([[2], [0]])
            out = torch_gather_nd(params, indices, batch_dim=0)
            self.assertEqual(out.tolist(), [[3, 4], [1, 2]])

        with self.subTest(msg="gather 1 batch dim, 2 dims, 1 ch dim"):
            params = torch.tensor([
                [[1, 2], [2, 3], [3, 4]],
                [[5, 6], [7, 8], [9, 10]],
            ])
            indices = torch.tensor([[[2], [0]]] * 2)
            out = torch_gather_nd(params, indices, batch_dim=1)
            self.assertEqual(out.tolist(), [
                [[3, 4], [1, 2]],
                [[9, 10], [5, 6]],
            ])

        with self.subTest(msg="gather 2 batch dims, 2 dims, 1 ch dims"):
            params = torch.rand(1, 2, 10, 6, 3)
            indices = torch.zeros(1, 2, 5, 2, dtype=torch.long)
            out = torch_gather_nd(params, indices, batch_dim=2)
            self.assertEqual(out.shape, (1, 2, 5, 3))
            # Every index is (0, 0), so each gathered slice must equal
            # params[b0, b1, 0, 0] for the corresponding batch entry.
            expected = params[:, :, 0, 0].unsqueeze(2).expand(1, 2, 5, 3)
            self.assertTrue(torch.equal(out, expected))

        with self.subTest(msg="gather 2 dims, 2 ch dims"):
            params = torch.rand(10, 6, 3, 4)
            indices = torch.zeros(5, 2, dtype=torch.long)
            out = torch_gather_nd(params, indices, batch_dim=0)
            self.assertEqual(out.shape, (5, 3, 4))
            # Every index is (0, 0), so each gathered slice is params[0, 0].
            expected = params[0, 0].expand(5, 3, 4)
            self.assertTrue(torch.equal(out, expected))
def torch_gather_nd(params: torch.Tensor, indices: torch.Tensor, batch_dim: int = 0) -> torch.Tensor:
    """
    torch_gather_nd implements tf.gather_nd in PyTorch.

    This supports multiple batch dimensions as well as multiple channel dimensions.

    The last axis of ``indices`` holds one coordinate per gathered dimension of
    ``params``. The gathered dimensions are the ``indices.size(-1)`` dimensions
    of ``params`` that follow the first ``batch_dim`` dimensions. Any dimensions
    of ``params`` after those are returned whole for every index.

    Args:
        params: Tensor to gather from.
        indices: Integer tensor whose last axis holds the coordinates and whose
            first ``batch_dim`` axes are the batch axes. It is never modified.
        batch_dim: Number of leading batch dimensions.

    Returns:
        A tensor of shape ``indices.shape[:-1] + params.shape[batch_dim + indices.size(-1):]``.

    Raises:
        ValueError: If ``batch_dim`` is negative or leaves no coordinate axis in
            ``indices``, or if ``indices`` addresses more dimensions than
            ``params`` has after the batch dimensions.
    """
    if batch_dim < 0 or batch_dim > indices.dim() - 1:
        raise ValueError(
            f"batch_dim must be in [0, {indices.dim() - 1}] for indices of "
            f"shape {tuple(indices.shape)}, got {batch_dim}"
        )

    index_shape = indices.shape[:-1]
    num_dim = indices.size(-1)
    if batch_dim + num_dim > params.dim():
        raise ValueError(
            f"indices address {num_dim} dimension(s) after {batch_dim} batch "
            f"dimension(s), but params only has {params.dim()} dimension(s)"
        )

    gathered_sizes = params.shape[batch_dim:batch_dim + num_dim]
    tail_sizes = params.shape[batch_dim + num_dim:]
    num_slices = prod(index_shape[batch_dim:])
    tail_numel = prod(tail_sizes)

    # torch.gather is given int64 indices; the cast happens out of place.
    indices = indices.to(torch.long)

    # Collapse each coordinate tuple into a single row-major offset over the
    # gathered dimensions. The accumulation is out of place so the caller's
    # tensor is never written to.
    flat_indices = torch.zeros(index_shape, dtype=torch.long, device=indices.device)
    for i in range(num_dim):
        stride = prod(gathered_sizes[i + 1:])
        flat_indices = flat_indices + indices[..., i] * stride

    # View params as (*batch, gathered, tail) so one gather along the
    # "gathered" axis pulls out whole tail slices.
    flat_params = params.reshape(*params.shape[:batch_dim], prod(gathered_sizes), tail_numel)

    # Broadcast every offset across the tail axis; expand() adds no copy.
    flat_indices = flat_indices.reshape(*index_shape[:batch_dim], num_slices, 1)
    flat_indices = flat_indices.expand(*index_shape[:batch_dim], num_slices, tail_numel)

    out = torch.gather(flat_params, dim=batch_dim, index=flat_indices)
    return out.reshape(*index_shape, *tail_sizes)
@WASSER2545
Copy link

What is prod() in 'size = prod(params.shape[batch_dim+i+1:batch_dim+num_dim])' ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment