Skip to content

Instantly share code, notes, and snippets.

@strubell
Created January 15, 2018 17:33
Show Gist options
  • Save strubell/d07f780449a6f4d2d735cb84514da024 to your computer and use it in GitHub Desktop.
Save strubell/d07f780449a6f4d2d735cb84514da024 to your computer and use it in GitHub Desktop.
Example using tf.gather_nd / tf.scatter_nd
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def gather_scatter_example():
    """Demonstrate ``tf.gather_nd`` and ``tf.scatter_nd`` on a random 4-D tensor.

    Builds a ``[batch_size, seq_len, seq_len, num_classes]`` tensor of normally
    distributed "logits". ``tf.gather_nd`` then picks out the
    ``num_classes``-vectors at four ``(batch, row, col)`` coordinates.
    ``tf.scatter_nd`` writes those vectors back to the same coordinates of an
    all-zero tensor with the original shape. Every intermediate value is
    printed under a label.

    Uses the TensorFlow 1.x graph/session API. The ops are built in a fresh
    ``tf.Graph`` so that:

    * the graph-level seed produces the same random values on every call, and
    * the process-wide default graph is not modified.
    """
    batch_size = 2
    seq_len = 5
    num_classes = 3

    with tf.Graph().as_default():
        # Graph-level seed. Because the graph is fresh on every call, the op
        # seed derived for random_normal below is the same each time.
        tf.set_random_seed(23)
        logits = tf.random_normal([batch_size, seq_len, seq_len, num_classes])
        logits_shape = tf.shape(logits)

        # Each row is a (batch, row, col) coordinate into the first three axes
        # of `logits`. The index depth (3) is one less than the rank of
        # `logits` (4), so gather_nd returns a whole num_classes-vector per row.
        gather_indices = tf.constant([[0, 1, 1],
                                      [0, 1, 2],
                                      [0, 2, 3],
                                      [1, 0, 1]])

        # Result shape: [num_indices, num_classes] == [4, 3].
        gathered = tf.gather_nd(logits, gather_indices)

        # Inverse direction: place the gathered vectors back at the same
        # coordinates of a zero tensor shaped like `logits`. scatter_nd sums
        # updates that share an index; the indices above are all distinct.
        scattered = tf.scatter_nd(gather_indices, gathered, logits_shape)

        fetches = [
            ("logits", logits),
            ("logits shape", logits_shape),
            ("gather_indices", gather_indices),
            ("gather_nd result", gathered),
            ("scatter_nd result", scattered),
        ]

        # Evaluate everything in one run so the gathered/scattered values are
        # derived from the same random sample of `logits` that gets printed.
        with tf.Session() as sess:
            values = sess.run([tensor for _, tensor in fetches])

    for (label, _), value in zip(fetches, values):
        print(label + ":")
        print(value)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment