@thomasahle
Created June 7, 2022 22:41
How to use PyTorch scatter_add
[nav] In [48]: table = torch.arange(12, dtype=torch.float32).reshape(4,3)
[ins] In [49]: new_table = torch.zeros(4, 3)
[ins] In [50]: index = torch.tensor([1,1,0,3])
[ins] In [51]: index2 = index.unsqueeze(1).expand(4,3)
[ins] In [52]: table
Out[52]:
tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.],
        [ 9., 10., 11.]])
[ins] In [53]: new_table
Out[53]:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
[ins] In [54]: index2
Out[54]:
tensor([[1, 1, 1],
        [1, 1, 1],
        [0, 0, 0],
        [3, 3, 3]])
[ins] In [55]: new_table.scatter_add_(0, index2, table)
Out[55]:
tensor([[ 6.,  7.,  8.],
        [ 3.,  5.,  7.],
        [ 0.,  0.,  0.],
        [ 9., 10., 11.]])
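
The session above can be read as "add row i of table into row index[i] of new_table": rows 0 and 1 both land in row 1 (giving [3., 5., 7.]), row 2 lands in row 0, row 3 stays in row 3, and row 2 of the result receives nothing. The following is a minimal sketch, not part of the original session, that reproduces the result with an explicit loop and with index_add_, which takes the 1-D index directly so the unsqueeze/expand step is not needed. The names scattered, looped, and row_added are made up for illustration.

```python
import torch

table = torch.arange(12, dtype=torch.float32).reshape(4, 3)
index = torch.tensor([1, 1, 0, 3])

# Same call as in the session: the index must have one entry per
# element of the source, hence the expand to shape (4, 3).
index2 = index.unsqueeze(1).expand(4, 3)
scattered = torch.zeros(4, 3).scatter_add_(0, index2, table)

# What scatter_add_ does for dim=0, written out element by element:
#   out[index2[i, j], j] += table[i, j]
looped = torch.zeros(4, 3)
for i in range(table.shape[0]):
    for j in range(table.shape[1]):
        looped[int(index2[i, j]), j] += table[i, j]

# When whole rows are moved together, index_add_ accepts the 1-D index.
row_added = torch.zeros(4, 3).index_add_(0, index, table)

print(torch.equal(scattered, looped))     # compare with the loop version
print(torch.equal(scattered, row_added))  # compare with index_add_
```

scatter_add_ is the more general of the two calls, since each column could use a different target row; index_add_ is the shorter spelling when every element of a row goes to the same place, as it does here.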