Skip to content

Instantly share code, notes, and snippets.

@MSWon
Last active June 15, 2019 14:12
Show Gist options
  • Save MSWon/ad28d40d5a057713c6763472bb505fc8 to your computer and use it in GitHub Desktop.
Save MSWon/ad28d40d5a057713c6763472bb505fc8 to your computer and use it in GitHub Desktop.
Slicing input tensor by its positions
import tensorflow as tf
import numpy as np
# Demo data: 10 sequences of 10 timesteps, each a 64-dim vector of uniform
# random floats, plus one position to pick out of each sequence.
input_data = np.random.rand(10, 10, 64)
positions = [0, 9, 8, 2, 3, 5, 1, 2, 4, 6]

# TF1 graph inputs; the batch dimension is left dynamic (None).
input_tensor = tf.placeholder(dtype=tf.float32, shape=(None, 10, 64))
positions_tensor = tf.placeholder(dtype=tf.int32, shape=(None,))
def gather_indexes(input_tensor, positions):
    """Gather the vectors at specific sequence positions over a minibatch.

    Args:
        input_tensor: Rank-3 tensor of shape ``[batch_size, seq_length, dim]``.
        positions: int32 positions, either of shape ``[batch_size]`` (one
            position per example) or ``[batch_size, num_positions]`` (several
            positions per example). Values are expected in ``[0, seq_length)``.

    Returns:
        Tensor of shape ``[batch_size * num_positions, dim]`` (``num_positions``
        is 1 for rank-1 ``positions``). Row ``b * num_positions + p`` is
        ``input_tensor[b, positions[b, p]]``.

    Note:
        Positions are not range-checked. The lookup happens in a flattened
        ``[batch_size * seq_length, dim]`` view, so a position
        ``>= seq_length`` silently returns a row from the following example
        instead of failing.
    """
    positions = tf.convert_to_tensor(positions)

    input_shape = tf.shape(input_tensor)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    dim = input_shape[2]

    # Bring positions to [batch_size, num_positions]. A statically known
    # rank-2 input is used as-is. Anything else is treated as one position
    # per example (the original behaviour). Unconditionally expanding a
    # rank-2 input would broadcast (B, P, 1) against the (B, 1) offsets,
    # which either fails or silently builds a wrong (B, B, 1) index grid.
    if positions.shape.ndims != 2:
        positions = tf.expand_dims(positions, axis=-1)

    # Row index of each example's first timestep in the flattened view.
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(
        input_tensor, [batch_size * seq_length, dim])
    # Slice the flattened tensor by the absolute row indices.
    return tf.gather(flat_sequence_tensor, flat_positions)
# Build the gather op once, outside the session. The graph has no
# tf.Variables, so no initializer needs to run.
gathered = gather_indexes(input_tensor, positions_tensor)

with tf.Session() as sess:
    result = sess.run(
        gathered,
        feed_dict={input_tensor: input_data, positions_tensor: positions},
    )

## check result
print(result[0])
print(input_data[0][positions[0]])
# Verify every gathered row, not just the first, against a NumPy reference:
# row b of the result should equal input_data[b, positions[b]]. The tolerance
# covers the float64 -> float32 conversion done by the placeholder.
expected = input_data[np.arange(len(positions)), positions]
np.testing.assert_allclose(result, expected, rtol=1e-6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment