Replicating TensorFlow's scatter_nd function in PyTorch.
# See: https://www.tensorflow.org/api_docs/python/tf/scatter_nd.

# TensorFlow
import tensorflow as tf

sess = tf.InteractiveSession()
indices = tf.constant([[0, 1], [2, 3]])
updates = tf.constant([[5, 5, 5, 5],
                       [6, 6, 6, 6]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print("TensorFlow")
print(sess.run(scatter), end="\n\n")

# PyTorch
import torch

indices = torch.tensor([[0, 1], [2, 3]])
updates = torch.tensor([[5, 5, 5, 5],
                        [6, 6, 6, 6]])
result = torch.zeros((4, 4, 4), dtype=torch.int64)
result[indices[:, 0], indices[:, 1]] = updates
print("PyTorch")
print(result.numpy())
siddharthanpr commented Nov 16, 2018

This has a mistake. Test with indices = [[0, 1], [0, 1]].

lkhphuc commented Dec 3, 2018

Yes: tf.scatter_nd accumulates values when indices are duplicated, while PyTorch's fancy indexing just replaces them.
Does anyone know how to achieve the same as tf.scatter_nd using torch.Tensor.scatter_add_?
I already asked in this thread: https://discuss.pytorch.org/t/how-to-implement-tf-scatter-nd-function-of-tensorflow-in-pytorch/18358/5
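One possible answer to this question, sketched for the 3-D target from the gist: flatten the two indexed dimensions into one so that scatter_add_ (which works along a single dimension) can accumulate duplicate rows, then reshape back. The variable names (flat, flat_idx) are illustrative, not from the gist.

```python
import torch

# Duplicated index on purpose, to show scatter_nd-style accumulation.
indices = torch.tensor([[0, 1], [0, 1]])
updates = torch.tensor([[5, 5, 5, 5],
                        [6, 6, 6, 6]])
shape = (4, 4, 4)

# Flatten the two indexed dims (4, 4) into one of size 16.
flat = torch.zeros(shape[0] * shape[1], shape[2], dtype=updates.dtype)
flat_idx = indices[:, 0] * shape[1] + indices[:, 1]  # shape (2,)

# scatter_add_ needs an index tensor shaped like the source, so expand it.
flat.scatter_add_(0, flat_idx.unsqueeze(1).expand_as(updates), updates)
result = flat.view(shape)

print(result[0, 1])  # tensor([11, 11, 11, 11]) -- duplicates accumulated
```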

pg2455 commented Feb 6, 2020

You can use index_add_
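A sketch of the index_add_ approach, using the same flatten-then-reshape trick as above; the flat/linear-index names are illustrative, not from the gist.

```python
import torch

indices = torch.tensor([[0, 1], [0, 1]])   # duplicated on purpose
updates = torch.tensor([[5, 5, 5, 5],
                        [6, 6, 6, 6]])
shape = (4, 4, 4)

# View the (4, 4, 4) target as (16, 4) rows, accumulate rows by linear
# index with index_add_, then reshape back.
flat = torch.zeros(shape[0] * shape[1], shape[2], dtype=updates.dtype)
flat.index_add_(0, indices[:, 0] * shape[1] + indices[:, 1], updates)
result = flat.view(shape)

print(result[0, 1])  # tensor([11, 11, 11, 11])
```

Unlike scatter_add_, index_add_ takes a 1-D index vector directly, so no expand is needed.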

Mengxue12 commented May 25, 2022

This also works:

result[indices.t().numpy()] = updates

so you don't need torch.Tensor.scatter_().
