TensorFlow script to get a Luong-style attention context vector from an RNN output sequence and a previous hidden state
"""
The TL;DR:
What is this? A function to get the global Luong-style context vector for an attention mechanism in tensorflow.
Use case? tensorflow's attention api's are quite restrictive to seq2seq models and aren't very flexible or
require training helpers and sequence feeders that must be used to use their attention api properly.
So, this just gives you the context vector given an rnn output sequence and a previous rnn state (traditionally the decoder's last state)
Note: Scoring is done between the previous hidden state and the encoder rnn's outputs and not the encoder's hidden states,
although for many RNNCells the output is the hidden state
Author: Matthew Baas
Last updated: 2018-01-14
"""
import tensorflow as tf

def attention_context_vector(encoder_outputs, previous_state, time_major=False,
                             weight_initializer=tf.contrib.layers.xavier_initializer(),
                             surprise_bias=False):
    # Gets the context vector from the output of a dynamic rnn sequence (also returns the
    # probabilities for each rnn output - the alphas in the paper), based on Luong-style
    # (multiplicative) attention scoring (https://arxiv.org/abs/1508.04025).
    # surprise_bias - a little personal flavor. If surprise_bias is True, a bias term is
    # included in the calculation of scores, i.e. score(h', h) = h'Wh + b.
    # Motivation: if you want some intrinsic attention to be paid to a particular place in
    # the sequence, e.g. when the application is such that the first elements in
    # encoder_outputs should intrinsically be paid more attention.
    # Note: indexing assumes multiple rnn cells are stacked on top of one another
    # (i.e. somewhere you are using MultiRNNCell or similar).
    if type(encoder_outputs) == tuple:
        # In case of a Bi-RNN, concatenate the forward and the backward RNN outputs.
        encoder_outputs = tf.concat(encoder_outputs, 2)
    if time_major:
        encoder_outputs = tf.transpose(encoder_outputs, [1, 0, 2])
    # Note: assumes shapes are statically known, and that previous_state is a tuple of
    # LSTMStateTuple-like (c, h) pairs, one per stacked layer.
    hidden_size = previous_state[-1][1].shape[1].value  # as shape[0] is the batch size
    output_size = encoder_outputs.shape[2].value
    num_timesteps = encoder_outputs.shape[1].value
    W_alpha = tf.get_variable('alpha', shape=[hidden_size, output_size], initializer=weight_initializer)
    # encoder_outputs has shape (batch_size, num_timesteps, output_size); it is multiplied by
    # previous_state[-1][1] x W_alpha, which has shape (batch_size, output_size, 1) after the
    # expand_dims, to give the scores in shape (batch_size, num_timesteps).
    # Note: h, the hidden state of the top layer of multi-rnn cells, is previous_state[-1][1]
    activ = tf.matmul(previous_state[-1][1], W_alpha)
    activ = tf.expand_dims(activ, axis=-1)
    scores = tf.matmul(encoder_outputs, activ)
    scores = tf.reshape(scores, [-1, num_timesteps])
    if surprise_bias:
        b_alpha = tf.get_variable("b_alpha", shape=[num_timesteps], initializer=weight_initializer)
        scores = scores + b_alpha
    alphas = tf.nn.softmax(scores)
    alphas = tf.expand_dims(alphas, axis=-1)  # ensure broadcasting works correctly
    # alphas is (batch_size, num_timesteps, 1); encoder_outputs is (batch_size, num_timesteps, output_size)
    context_vec = tf.multiply(alphas, encoder_outputs)
    context_vec = tf.reduce_sum(context_vec, axis=1)  # axis 1 is the num_timesteps axis
    return context_vec, alphas
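

# A minimal usage sketch (hypothetical shapes and layer sizes, not part of the original gist),
# assuming a stacked LSTM encoder built with TF1-style graph code, so that final_state is a
# tuple of LSTMStateTuples and final_state[-1][1] is the top layer's hidden state h:

batch_size, num_timesteps, input_dim, hidden_size = 32, 20, 64, 128

inputs = tf.placeholder(tf.float32, [batch_size, num_timesteps, input_dim])
cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(hidden_size) for _ in range(2)])
# encoder_outputs: (batch_size, num_timesteps, hidden_size); final_state: tuple of LSTMStateTuples
encoder_outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

# Score the top layer's h against every encoder output and pool the outputs into a context vector.
context_vec, alphas = attention_context_vector(encoder_outputs, final_state)
# context_vec: (batch_size, hidden_size); alphas: (batch_size, num_timesteps, 1)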