@hadifar
Created December 6, 2018 22:57
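
import tensorflow as tf

# Assumes rnn_step(prev_h, x_t) -> h_t, embed [batch_size, seq_len, dim], batch_size and rnn_size are defined elsewhere.
hidden_states = tf.scan(fn=rnn_step,  # tf.scan calls fn(accumulator, element): accumulator is h_{t-1}, element is x_t
                        elems=tf.transpose(embed, perm=[1, 0, 2]),  # batch_size*seq_len*dim --> seq_len*batch_size*dim (scan iterates over axis 0)
                        initializer=tf.zeros([batch_size, rnn_size]))  # h_0: all-zeros initial hidden state
outputs = tf.transpose(hidden_states, perm=[1, 0, 2])  # seq_len*batch_size*rnn_size --> batch_size*seq_len*rnn_size
last_rnn_output = outputs[:, -1, :]  # hidden state at the final time step: batch_size*rnn_size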
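The snippet depends on names it does not define (`rnn_step`, `embed`, `batch_size`, `rnn_size`). As a framework-free illustration, here is a minimal NumPy sketch of the same pipeline. The `scan` helper below only mimics what `tf.scan` does (apply `fn(acc, x)` along the leading axis and stack every intermediate accumulator); the Elman-style `rnn_step`, the weights `W_x`, `W_h`, `b`, and all sizes are hypothetical choices for illustration, not the gist author's code.

```python
import numpy as np

def scan(fn, elems, initializer):
    """Mimic tf.scan: apply fn(acc, x) along axis 0 and stack every intermediate acc."""
    acc = initializer
    outs = []
    for x in elems:              # iterates over the leading (time) axis
        acc = fn(acc, x)
        outs.append(acc)
    return np.stack(outs)        # shape: [seq_len, *initializer.shape]

# Hypothetical sizes and weights, chosen only for illustration.
batch_size, seq_len, dim, rnn_size = 2, 5, 3, 4
rng = np.random.default_rng(0)
W_x = rng.normal(size=(dim, rnn_size))
W_h = rng.normal(size=(rnn_size, rnn_size))
b = np.zeros(rnn_size)

def rnn_step(prev_h, x_t):
    """Hypothetical Elman step: h_t = tanh(x_t W_x + h_{t-1} W_h + b)."""
    return np.tanh(x_t @ W_x + prev_h @ W_h + b)

embed = rng.normal(size=(batch_size, seq_len, dim))      # batch_size*seq_len*dim
hidden_states = scan(rnn_step,
                     np.transpose(embed, (1, 0, 2)),      # seq_len*batch_size*dim
                     np.zeros((batch_size, rnn_size)))    # h_0
outputs = np.transpose(hidden_states, (1, 0, 2))          # batch_size*seq_len*rnn_size
last_rnn_output = outputs[:, -1, :]                       # batch_size*rnn_size

print(outputs.shape, last_rnn_output.shape)  # → (2, 5, 4) (2, 4)
```

The two transposes exist only because the scan walks the first axis: time has to be moved to the front before scanning and moved back afterwards so that `outputs` is batch-major again.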