Created
March 9, 2017 11:03
-
-
Save shahbazsyed/b72368a3853d1f6744a03a2d9239669b to your computer and use it in GitHub Desktop.
An overview of sequence2sequence model with attention
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple

## Parameters
# Graph mode. The model only ever compares this against 'decode' (feed the
# previous prediction back into the decoder); every other value builds the
# graph that consumes the ground-truth decoder inputs.
# NOTE(review): 'train' is assumed as the default here -- confirm.
mode = 'train'
batch_size = 4
enc_layers = 4
enc_timesteps = 120
dec_timesteps = 30
min_lr = 0.01  # min learning rate.
lr = 0.15  # learning rate
min_input_len = 2
num_hidden = 256
emb_dim = 128
max_grad_norm = 2
num_softmax_samples = 4096
max_run_steps = 100000
max_article_sentences = 2
max_abstract_sentences = 100
beam_size = 4
num_gpus = 2

# Seq2Seq_Model
# The model reads its hyper-parameters by attribute (hps.batch_size,
# hps.mode, ...), so the bundle is a namedtuple. It is still a tuple, so
# positional access and unpacking keep working.
_HParams = namedtuple(
    'HParams',
    ['mode', 'min_lr', 'lr', 'batch_size', 'enc_layers', 'enc_timesteps',
     'dec_timesteps', 'min_input_len', 'num_hidden', 'emb_dim',
     'max_grad_norm', 'num_softmax_samples'])
HParams = _HParams(
    mode=mode,
    min_lr=min_lr,
    lr=lr,
    batch_size=batch_size,
    enc_layers=enc_layers,
    enc_timesteps=enc_timesteps,
    dec_timesteps=dec_timesteps,
    min_input_len=min_input_len,
    num_hidden=num_hidden,
    emb_dim=emb_dim,
    max_grad_norm=max_grad_norm,
    num_softmax_samples=num_softmax_samples,
)
def _extract_argmax_and_embed(embedding, output_projection=None,update_embedding=True): | |
"""Get a loop_function that extracts the previous symbol and embeds it. | |
Args: | |
embedding: embedding tensor for symbols. | |
output_projection: None or a pair (W, B). If provided, each fed previous | |
output will first be multiplied by W and added B. | |
update_embedding: Boolean; if False, the gradients will not propagate | |
through the embeddings. | |
Returns: | |
A loop function. | |
""" | |
return loop_function | |
class Seq2SeqAttentionModel(object):
    """Seq2seq graph: stacked bidirectional LSTM encoder + attention decoder.

    NOTE(review): run_train_step reads self._train_op, self._summaries,
    self._loss and self.global_step, none of which are created by the code in
    this class. They must be built elsewhere before that method is called.
    """

    def __init__(self, hps, vocab, num_gpus=2):
        """Store the configuration; no graph is built here.

        Args:
          hps: hyper-parameter bundle read by attribute (batch_size,
            enc_timesteps, dec_timesteps, enc_layers, num_hidden, emb_dim,
            mode).
          vocab: vocabulary object; only its NumIds() method is used.
          num_gpus: number of GPUs to spread the layers over; 0 means no
            explicit device placement.
        """
        self._hps = hps
        self._vocab = vocab
        self._num_gpus = num_gpus
        self._cur_gpu = 0

    def _next_device(self):
        """Return the next device string, cycling round-robin over the GPUs.

        Returns:
          '/gpu:<k>' for the next GPU in turn, or '' when num_gpus is 0.
        """
        if not self._num_gpus:
            return ''
        device = '/gpu:%d' % self._cur_gpu
        self._cur_gpu = (self._cur_gpu + 1) % self._num_gpus
        return device

    def _batch_feed(self, article_batch, abstract_batch, targets,
                    article_lens, abstract_lens, loss_weights):
        """Map one batch of inputs onto the six graph placeholders."""
        return {
            self._articles: article_batch,
            self._abstracts: abstract_batch,
            self._targets: targets,
            self._article_lens: article_lens,
            self._abstract_lens: abstract_lens,
            self._loss_weights: loss_weights,
        }

    def run_train_step(self, sess, article_batch, abstract_batch, targets,
                       article_lens, abstract_lens, loss_weights):
        """Run one training step.

        Returns:
          The sess.run result for [train_op, summaries, loss, global_step].
        """
        to_return = [self._train_op, self._summaries, self._loss,
                     self.global_step]
        return sess.run(
            to_return,
            feed_dict=self._batch_feed(article_batch, abstract_batch, targets,
                                       article_lens, abstract_lens,
                                       loss_weights))

    def run_decode_step(self, sess, article_batch, abstract_batch, targets,
                        article_lens, abstract_lens, loss_weights):
        """Run one decode step.

        self._outputs only exists when the graph was built with
        hps.mode == 'decode'.

        Returns:
          The sess.run result for [outputs, global_step].
        """
        to_return = [self._outputs, self.global_step]
        return sess.run(
            to_return,
            feed_dict=self._batch_feed(article_batch, abstract_batch, targets,
                                       article_lens, abstract_lens,
                                       loss_weights))

    def _add_placeholders(self):
        """Inputs to be fed to the graph."""
        hps = self._hps
        self._articles = tf.placeholder(
            tf.int32, [hps.batch_size, hps.enc_timesteps], name='articles')
        self._abstracts = tf.placeholder(
            tf.int32, [hps.batch_size, hps.dec_timesteps], name='abstracts')
        self._targets = tf.placeholder(
            tf.int32, [hps.batch_size, hps.dec_timesteps], name='targets')
        self._article_lens = tf.placeholder(
            tf.int32, [hps.batch_size], name='article_lens')
        self._abstract_lens = tf.placeholder(
            tf.int32, [hps.batch_size], name='abstract_lens')
        self._loss_weights = tf.placeholder(
            tf.float32, [hps.batch_size, hps.dec_timesteps],
            name='loss_weights')

    ###########################################################################
    # Complete network
    ###########################################################################
    # NOTE(review): the double leading underscore name-mangles this method to
    # _Seq2SeqAttentionModel__add_seq2seq, unlike _add_placeholders. The name
    # is kept as-is so existing callers keep working.
    def __add_seq2seq(self):
        """Build embedding, encoder, decoder and output layers.

        Sets self._enc_top_states, self._dec_in_state and self._dec_out_state.
        When hps.mode == 'decode' it also sets self._outputs,
        self._topk_log_probs and self._topk_ids.
        """
        hps = self._hps
        vsize = self._vocab.NumIds()  # Number of words in the vocabulary

        with tf.variable_scope('seq2seq'):
            # Placeholders are [batch_size, timesteps]; transpose + unpack
            # gives one [batch_size] tensor per time step.
            encoder_inputs = tf.unpack(tf.transpose(self._articles))
            decoder_inputs = tf.unpack(tf.transpose(self._abstracts))
            # targets / loss_weights are not used by the code shown here;
            # presumably a loss defined elsewhere consumes them -- TODO confirm.
            targets = tf.unpack(tf.transpose(self._targets))
            loss_weights = tf.unpack(tf.transpose(self._loss_weights))
            article_lens = self._article_lens

            with tf.variable_scope('embedding'), tf.device('/cpu:0'):
                # One [vsize, emb_dim] matrix, randomly initialised and shared
                # by encoder and decoder inputs.
                embedding = tf.get_variable(
                    'embedding', [vsize, hps.emb_dim], dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(stddev=1e-4))
                # Look up each id in the matrix, per time step.
                emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
                                      for x in encoder_inputs]
                emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
                                      for x in decoder_inputs]

            ###################################################################
            # Encoder network
            ###################################################################
            # Stack hps.enc_layers bidirectional layers; each layer consumes
            # the outputs of the previous one.
            for layer_i in range(hps.enc_layers):
                with tf.variable_scope('encoder%d' % layer_i), tf.device(
                        self._next_device()):
                    cell_fw = tf.nn.rnn_cell.LSTMCell(
                        hps.num_hidden,
                        initializer=tf.random_uniform_initializer(
                            -0.1, 0.1, seed=123),
                        state_is_tuple=False)
                    cell_bw = tf.nn.rnn_cell.LSTMCell(
                        hps.num_hidden,
                        initializer=tf.random_uniform_initializer(
                            -0.1, 0.1, seed=113),
                        state_is_tuple=False)
                    (emb_encoder_inputs, fw_state, _) = (
                        tf.nn.bidirectional_rnn(
                            cell_fw, cell_bw, emb_encoder_inputs,
                            dtype=tf.float32, sequence_length=article_lens))
            encoder_outputs = emb_encoder_inputs  # Output of the last layer

            # Projection from decoder output (num_hidden) to vocabulary
            # logits. It is used both by the decode-time loop function and by
            # the 'output' layer below.
            with tf.variable_scope('output_projection'):
                w = tf.get_variable(
                    'w', [hps.num_hidden, vsize], dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(stddev=1e-4))
                v = tf.get_variable(
                    'v', [vsize], dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(stddev=1e-4))

            ###################################################################
            # Decoder network
            ###################################################################
            with tf.variable_scope('decoder'), tf.device(self._next_device()):
                # When decoding, use model output from the previous step
                # for the next step.
                loop_function = None
                if hps.mode == 'decode':
                    loop_function = _extract_argmax_and_embed(
                        embedding, (w, v), update_embedding=False)

                cell = tf.nn.rnn_cell.LSTMCell(
                    hps.num_hidden,
                    initializer=tf.random_uniform_initializer(
                        -0.1, 0.1, seed=113),
                    state_is_tuple=False)

                # Stack the per-step encoder outputs along axis 1 to form the
                # attention memory: [batch_size, steps, 2 * num_hidden].
                encoder_outputs = [
                    tf.reshape(x, [hps.batch_size, 1, 2 * hps.num_hidden])
                    for x in encoder_outputs]
                self._enc_top_states = tf.concat(1, encoder_outputs)
                self._dec_in_state = fw_state
                # During decoding, follow up _dec_in_state are fed from
                # beam_search. dec_out_state are stored by beam_search for
                # next step feeding.
                initial_state_attention = (hps.mode == 'decode')
                # Signature, for reference:
                #   tf.contrib.legacy_seq2seq.attention_decoder(
                #       decoder_inputs, initial_state, attention_states, cell,
                #       output_size=None, num_heads=1, loop_function=None,
                #       dtype=None, scope=None, initial_state_attention=False)
                # Attention means that while decoding, the RNN can look up
                # additional information in the attention_states tensor.
                decoder_outputs, self._dec_out_state = (
                    tf.nn.seq2seq.attention_decoder(
                        emb_decoder_inputs, self._dec_in_state,
                        self._enc_top_states, cell, num_heads=1,
                        loop_function=loop_function,
                        initial_state_attention=initial_state_attention))

            # Multiply decoder_outputs with w, add v: one logits tensor per
            # decoder step.
            with tf.variable_scope('output'), tf.device(self._next_device()):
                model_outputs = []
                for i, decoder_output in enumerate(decoder_outputs):
                    if i > 0:
                        tf.get_variable_scope().reuse_variables()
                    model_outputs.append(
                        tf.nn.xw_plus_b(decoder_output, w, v))

            if hps.mode == 'decode':
                with tf.variable_scope('decode_output'), tf.device('/cpu:0'):
                    best_outputs = [tf.argmax(x, 1) for x in model_outputs]
                    tf.logging.info('best_outputs%s',
                                    best_outputs[0].get_shape())
                    self._outputs = tf.concat(
                        1, [tf.reshape(x, [hps.batch_size, 1])
                            for x in best_outputs])
                    # Keep the top 2 * batch_size log-probabilities of the
                    # last decoder step (8 with the default batch_size of 4).
                    self._topk_log_probs, self._topk_ids = tf.nn.top_k(
                        tf.log(tf.nn.softmax(model_outputs[-1])),
                        hps.batch_size * 2)

    def encode_top_state(self, sess, enc_inputs, enc_len):
        """Return the top states from encoder for decoder.

        Args:
          sess: tensorflow session.
          enc_inputs: encoder inputs of shape [batch_size, enc_timesteps].
          enc_len: encoder input length of shape [batch_size]

        Returns:
          enc_top_states: The top level encoder states.
          dec_in_state: The decoder layer initial state (first batch entry
            only).
        """
        results = sess.run([self._enc_top_states, self._dec_in_state],
                           feed_dict={self._articles: enc_inputs,
                                      self._article_lens: enc_len})
        return results[0], results[1][0]

    def decode_topk(self, sess, latest_tokens, enc_top_states,
                    dec_init_states):
        """Return the topK results and new decoder states.

        Args:
          sess: tensorflow session.
          latest_tokens: most recent token id per hypothesis; fed as the
            abstracts placeholder after being transposed into one column.
          enc_top_states: encoder states, fed straight into the
            _enc_top_states tensor so the encoder is not re-run.
          dec_init_states: one decoder state per hypothesis, fed straight into
            the _dec_in_state tensor.

        Returns:
          (ids, probs, new_states): top-k ids, their log-probabilities, and a
          list of the new decoder states.
        """
        feed = {
            self._enc_top_states: enc_top_states,
            self._dec_in_state: np.squeeze(np.array(dec_init_states)),
            self._abstracts: np.transpose(np.array([latest_tokens])),
            self._abstract_lens: np.ones([len(dec_init_states)], np.int32),
        }
        ids, probs, states = sess.run(
            [self._topk_ids, self._topk_log_probs, self._dec_out_state],
            feed_dict=feed)
        new_states = list(states)
        return ids, probs, new_states
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment