@maxrohleder
Created June 21, 2021 09:43
tensorflow slice assignment
import numpy as np
import tensorflow as tf

# one batch containing a stack of 2 images, each of shape (3, 4)
test = tf.constant(np.arange(24).reshape(1, 2, 3, 4), dtype=tf.int64)
# <tf.Tensor: shape=(1, 2, 3, 4), dtype=int64, numpy=
# array([[[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],
#
#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]]])>
# an update of shape (1, 3, 4) should replace the i-th image of the stack (here i = 1)
update = tf.zeros((1, 3, 4), dtype=tf.int64)
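# zero out image 1 of batch 0: the index row [0, 1] addresses test[0, 1], a (3, 4) slice,
# so the updates need shape (1, 3, 4). Tensors are immutable, so this returns a new tensor.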
tf.tensor_scatter_nd_update(test, [[0, 1]], update)
# <tf.Tensor: shape=(1, 2, 3, 4), dtype=int64, numpy=
# array([[[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],
#
#         [[ 0,  0,  0,  0],
#          [ 0,  0,  0,  0],
#          [ 0,  0,  0,  0]]]])>
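The snippet above hard-codes the index `[[0, 1]]`, so it only touches batch element 0. A minimal sketch of how the same `tf.tensor_scatter_nd_update` call might generalize to every batch element follows; the helper name `replace_stack_slice` and the `(B, N, H, W)` layout are assumptions for illustration, not part of the original gist. It builds one index row `[b, i]` per batch element, so the updates must have shape `(B, H, W)`.

```python
import numpy as np
import tensorflow as tf


def replace_stack_slice(stack, i, update):
    """Hypothetical helper: return a copy of `stack` with stack[:, i] replaced.

    Assumes `stack` has shape (B, N, H, W) and `update` has shape (B, H, W),
    i.e. one (H, W) image per batch element.
    """
    # Batch size as an int64 scalar so all index components share one dtype.
    batch_size = tf.shape(stack, out_type=tf.int64)[0]
    batch_idx = tf.range(batch_size)                      # [0, 1, ..., B-1]
    stack_idx = tf.fill([batch_size], tf.cast(i, tf.int64))  # [i, i, ..., i]
    # indices has shape (B, 2): row b is [b, i] and selects stack[b, i].
    indices = tf.stack([batch_idx, stack_idx], axis=1)
    # Each index row addresses a full (H, W) slice, so updates is (B, H, W).
    return tf.tensor_scatter_nd_update(stack, indices, update)


# Two batches, each a stack of 2 images of shape (3, 4).
test = tf.constant(np.arange(48).reshape(2, 2, 3, 4), dtype=tf.int64)
update = tf.zeros((2, 3, 4), dtype=tf.int64)
result = replace_stack_slice(test, 1, update)
```

With the gist's original single-batch `test` and a `(1, 3, 4)` zero update, `replace_stack_slice(test, 1, update)` builds exactly the index `[[0, 1]]` used above. If the data lives in a `tf.Variable` rather than a constant tensor, sliced assignment such as `var[0, 1].assign(...)` is an alternative that updates the variable in place.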