@rockt
Created August 25, 2016 19:45
TensorFlow utility function that extracts, for each sequence in a batch, the output of `tf.nn.dynamic_rnn` at its last valid time step, as determined by that sequence's length.
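```python
import tensorflow as tf


def gather_by_lengths(outputs, seq_lengths):
    """
    :param outputs: [batch_size x max_seq_length x hidden_size] tensor of dynamic_rnn outputs
    :param seq_lengths: [batch_size] tensor of sequence lengths
    :return: [batch_size x hidden_size] tensor of last outputs
    """
    # tf.unpack was renamed tf.unstack in TensorFlow 1.0
    batch_size, max_seq_length, hidden_size = tf.unstack(tf.shape(outputs))
    # Flat row index of each sequence's last valid step; tf.shape yields int32, so cast lengths to match
    index = tf.range(0, batch_size) * max_seq_length + (tf.cast(seq_lengths, tf.int32) - 1)
    # Merge the batch and time axes, then pick one row per sequence
    return tf.gather(tf.reshape(outputs, [-1, hidden_size]), index)
```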
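The flat-index arithmetic can be checked without TensorFlow. The sketch below is a hypothetical NumPy mirror of the function (`gather_by_lengths_np` is a name introduced here, not part of the gist): `reshape` plays the role of `tf.reshape` and `np.take(..., axis=0)` the role of `tf.gather`. It assumes every sequence length is at least 1; with a length of 0 the computed index would point into the previous sequence's rows.

```python
import numpy as np


def gather_by_lengths_np(outputs, seq_lengths):
    """Hypothetical NumPy mirror of gather_by_lengths (assumes all lengths >= 1).

    outputs: [batch_size, max_seq_length, hidden_size] array
    seq_lengths: [batch_size] integer array
    returns: [batch_size, hidden_size] array of each sequence's last valid output
    """
    batch_size, max_seq_length, hidden_size = outputs.shape
    # After flattening, (batch b, step t) lives at row b * max_seq_length + t
    index = np.arange(batch_size) * max_seq_length + (np.asarray(seq_lengths) - 1)
    flat = outputs.reshape(-1, hidden_size)  # [batch_size * max_seq_length, hidden_size]
    return np.take(flat, index, axis=0)


# Toy batch: 2 sequences padded to 4 steps, hidden size 3
outputs = np.arange(2 * 4 * 3).reshape(2, 4, 3)
seq_lengths = np.array([2, 4])
last = gather_by_lengths_np(outputs, seq_lengths)  # flat rows 1 and 7
# last == [[3, 4, 5], [21, 22, 23]]

# Cross-check against direct (batch, time) indexing
assert np.array_equal(last, outputs[np.arange(2), seq_lengths - 1])
```

Direct indexing `outputs[np.arange(batch_size), seq_lengths - 1]` selects the same rows; the reshape-and-gather form mirrors what the gist does with `tf.gather`. In TensorFlow, `tf.gather_nd` with stacked `[batch_index, seq_length - 1]` pairs is a common alternative.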