Tensorflow script to get Luong-style attention context vector from an rnn output sequence and a previous hidden state
""" | |
The TL;DR: | |
What is this? A function to get the global Luong-style context vector for an attention mechanism in tensorflow. | |
Use case? tensorflow's attention api's are quite restrictive to seq2seq models and aren't very flexible or | |
require training helpers and sequence feeders that must be used to use their attention api properly. | |
So, this just gives you the context vector given an rnn output sequence and a previous rnn state (traditionally the decoder's last state) | |
Note: Scoring is done between the previous hidden state and the encoder rnn's outputs and not the encoder's hidden states, | |
although for many RNNCells the output is the hidden state | |
Author: Matthew Baas | |
Last updated: 2018-01-14 | |
""" | |
import tensorflow as tf

def attention_context_vector(encoder_outputs, previous_state, time_major=False, weight_initializer=tf.contrib.layers.xavier_initializer(), surprise_bias=False):
    # Gets the context vector from the outputs of a dynamic rnn sequence (also returns the probabilities for each rnn output - the alphas in the paper),
    # based on Luong-style (multiplicative) attention scoring (https://arxiv.org/abs/1508.04025).
    # surprise_bias - a little personal flavor. If surprise_bias is True, include a bias term in the calculation of scores, i.e. score(h', h) = h'Wh + b.
    #   Motivation: if you want some intrinsic attention to be paid to a particular place in the sequence,
    #   e.g. the application is such that the first elements in the encoder_outputs should intrinsically be paid more attention.
    # Note: indexing assumes multiple rnn cells are stacked on top of one another (i.e. somewhere you are using MultiRNNCell or similar).
    if isinstance(encoder_outputs, tuple):
        # In the case of a Bi-RNN, concatenate the forward and the backward RNN outputs.
        encoder_outputs = tf.concat(encoder_outputs, 2)
    if time_major:
        # convert (num_timesteps, batch_size, ...) to (batch_size, num_timesteps, ...)
        encoder_outputs = tf.transpose(encoder_outputs, [1, 0, 2])
    # Note: assumes previous_state is a tuple of states from stacked cells (e.g. LSTMStateTuples),
    # so previous_state[-1][1] is the hidden state h of the top layer (see the note above tf.matmul below).
    hidden_size = previous_state[-1][1].shape[1].value  # shape[0] is the batch size
    output_size = encoder_outputs.shape[2].value
    num_timesteps = encoder_outputs.shape[1].value
    W_alpha = tf.get_variable('alpha', shape=[hidden_size, output_size], initializer=weight_initializer)
    # encoder_outputs has shape (batch_size, num_timesteps, output_size); it is multiplied by
    # previous_state[-1][1] @ W_alpha, which is (batch_size, output_size, 1) after expanding dims,
    # to give the scores in shape (batch_size, num_timesteps).
    # Note: h, the hidden state of the top layer of the multi-rnn cells, is previous_state[-1][1].
    activ = tf.matmul(previous_state[-1][1], W_alpha)
    activ = tf.expand_dims(activ, axis=-1)
    scores = tf.matmul(encoder_outputs, activ)
    scores = tf.reshape(scores, [-1, num_timesteps])
    if surprise_bias:
        b_alpha = tf.get_variable("b_alpha", shape=[num_timesteps], initializer=weight_initializer)
        scores = scores + b_alpha
    alphas = tf.nn.softmax(scores)
    alphas = tf.expand_dims(alphas, axis=-1)  # ensure broadcasting works correctly
    # alphas is (batch_size, num_timesteps, 1); encoder_outputs is (batch_size, num_timesteps, output_size)
    context_vec = tf.multiply(alphas, encoder_outputs)
    context_vec = tf.reduce_sum(context_vec, axis=1)  # axis 1 is the num_timesteps axis
    return context_vec, alphas
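For reference, here is a minimal usage sketch. It is not part of the gist itself: the shapes, placeholder names and the two-layer LSTM encoder are hypothetical, and it assumes TensorFlow 1.x with an LSTM-based MultiRNNCell so that previous_state[-1][1] is the top layer's hidden state h. In this sketch the encoder's own final state stands in for the "previous decoder state".

import numpy as np
import tensorflow as tf

batch_size, num_timesteps, input_dim, hidden_size = 4, 10, 8, 32

# hypothetical encoder: a 2-layer LSTM MultiRNNCell run over a toy input sequence
inputs = tf.placeholder(tf.float32, [batch_size, num_timesteps, input_dim])
cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(hidden_size) for _ in range(2)])
encoder_outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

# final_state is a tuple of LSTMStateTuples, so final_state[-1][1] is the top layer's h
context_vec, alphas = attention_context_vector(encoder_outputs, final_state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ctx, a = sess.run(
        [context_vec, alphas],
        feed_dict={inputs: np.random.randn(batch_size, num_timesteps, input_dim).astype(np.float32)})
    print(ctx.shape)  # (4, 32)    -> (batch_size, output_size)
    print(a.shape)    # (4, 10, 1) -> one attention weight per timestep

Passing surprise_bias=True would additionally learn a per-timestep bias b_alpha of shape (num_timesteps,), which is added to the scores before the softmax.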