
@0x00b1
Last active January 31, 2024 19:34
PyTorch `segment_sum` operator
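from typing import Optional

import torch
from torch import Tensor


def segment_sum(input: Tensor, indexes: Tensor, n: Optional[int] = None, **kwargs) -> Tensor:
    # Sum the rows of `input` that share a segment id; indexes[i] is the segment of row input[i].
    if indexes.ndim == 1:
        # Broadcast the 1-D segment ids across the trailing dimensions of `input`.
        indexes = indexes.view(-1, *([1] * (input.ndim - 1))).expand_as(input)
    if n is None:
        # Number of segments defaults to the largest segment id plus one.
        n = int(indexes.max()) + 1
    # Accumulate in float32 on the input's device, then cast with **kwargs (e.g. dtype=torch.int32).
    output = torch.zeros(n, *input.shape[1:], dtype=torch.float32, device=input.device)
    return output.scatter_add(0, indexes, input.to(torch.float32)).to(**kwargs)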
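`segment_sum` computes, for each segment id s, the sum of every row input[i] with indexes[i] == s. A minimal alternative sketch, assuming the segment ids are 1-D with one id per row, uses PyTorch's `Tensor.index_add_` instead of `scatter_add`. `index_add_` takes the 1-D index directly, so there is no need to broadcast the ids to the shape of `input`, and this version accumulates in the input's own dtype rather than converting to float32. The name `segment_sum_index_add` is hypothetical and not part of the gist.

```python
from typing import Optional

import torch
from torch import Tensor


def segment_sum_index_add(input: Tensor, indexes: Tensor, n: Optional[int] = None) -> Tensor:
    # Hypothetical alternative to segment_sum; assumes `indexes` is 1-D, one id per row of `input`.
    if n is None:
        # Default number of segments: largest id plus one.
        n = int(indexes.max()) + 1
    # One output row per segment, with the same trailing shape, dtype, and device as `input`.
    out = torch.zeros((n, *input.shape[1:]), dtype=input.dtype, device=input.device)
    # Add input[i] into out[indexes[i]] along dimension 0.
    return out.index_add_(0, indexes, input)
```

Segment ids that never occur (for example, ids [0, 2, 2] with n inferred as 3) simply leave their output row at zero, which matches the behavior of the `scatter_add` version.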

0x00b1 commented Jan 31, 2024

>>> input = torch.tensor([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]])
>>> indexes = torch.tensor([0, 0, 1])
>>> segment_sum(input, indexes, 2, dtype=torch.int32)
tensor([[5, 5, 5, 5],
        [5, 6, 7, 8]], dtype=torch.int32)
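To see where each output row of that example comes from, here is a deliberately naive reference loop. The helper `segment_sum_reference` is hypothetical and not part of the gist; it visits each row once and adds it to the output row of its segment. Rows 0 and 1 share segment 0, so [1, 2, 3, 4] + [4, 3, 2, 1] = [5, 5, 5, 5], and row 2 alone forms segment 1.

```python
import torch
from torch import Tensor


def segment_sum_reference(input: Tensor, indexes: Tensor, n: int) -> Tensor:
    # Hypothetical, slow reference: one Python-level addition per row of `input`.
    out = torch.zeros((n, *input.shape[1:]), dtype=input.dtype)
    for row, segment in zip(input, indexes.tolist()):
        # Add this row into the output row of its segment id.
        out[segment] += row
    return out


data = torch.tensor([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]])
segment_ids = torch.tensor([0, 0, 1])
result = segment_sum_reference(data, segment_ids, 2)
# result equals torch.tensor([[5, 5, 5, 5], [5, 6, 7, 8]])
```

The loop is only useful for checking results on small inputs; the vectorized `scatter_add` version avoids the per-row Python overhead.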
