An overview of a sequence-to-sequence model with attention
## Parameters
batch_size = 4
enc_layers = 4
enc_timesteps = 120
dec_timesteps = 30
min_lr = 0.01  # minimum learning rate
lr = 0.15      # learning rate
min_input_len = 2
num_hidden = 256
emb_dim = 128
max_grad_norm = 2
num_softmax_samples = 4096
max_run_steps = 100000
max_article_sentences = 2
max_abstract_sentences = 100
beam_size = 4
num_gpus = 2
# Seq2Seq_Model
from collections import namedtuple

import numpy as np
import tensorflow as tf

# All hyperparameters are bundled into a single named tuple.
HParams = namedtuple('HParams',
                     'mode min_lr lr batch_size enc_layers enc_timesteps '
                     'dec_timesteps min_input_len num_hidden emb_dim '
                     'max_grad_norm num_softmax_samples')
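# A minimal sketch of wiring the values above into HParams ('train' as the
# mode string is an assumption; the gist itself only checks for 'decode'):
hps = HParams(mode='train', min_lr=0.01, lr=0.15, batch_size=4,
              enc_layers=4, enc_timesteps=120, dec_timesteps=30,
              min_input_len=2, num_hidden=256, emb_dim=128,
              max_grad_norm=2, num_softmax_samples=4096)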
def _extract_argmax_and_embed(embedding, output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.
  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and added B.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.
  Returns:
    A loop function.
  """
  def loop_function(prev, _):
    # Optionally project the previous raw output to vocabulary logits.
    if output_projection is not None:
      prev = tf.nn.xw_plus_b(prev, output_projection[0], output_projection[1])
    # Greedily pick the most likely symbol and embed it as the next input.
    prev_symbol = tf.argmax(prev, 1)
    emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = tf.stop_gradient(emb_prev)
    return emb_prev
  return loop_function
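# Usage note: at decode time, attention_decoder calls this loop function at
# every step after the first, so the model is fed its own previous prediction
# rather than the ground-truth token. Roughly, inside the decoder:
#   emb_prev = loop_function(previous_step_output, step_index)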
class Seq2SeqAttentionModel(object):
  def __init__(self, hps, vocab, num_gpus=2):
    self._hps, self._vocab, self._num_gpus = hps, vocab, num_gpus
    self._cur_gpu = 0  # round-robin counter used by self._next_device()
def run_train_step(self, sess, article_batch, abstract_batch, targets,
article_lens, abstract_lens, loss_weights):
to_return = [self._train_op, self._summaries, self._loss, self.global_step]
return sess.run(to_return,
feed_dict={self._articles: article_batch,
self._abstracts: abstract_batch,
self._targets: targets,
self._article_lens: article_lens,
self._abstract_lens: abstract_lens,
self._loss_weights: loss_weights})
def run_decode_step(self, sess, article_batch, abstract_batch, targets,
article_lens, abstract_lens, loss_weights):
to_return = [self._outputs, self.global_step]
return sess.run(to_return,
feed_dict={self._articles: article_batch,
self._abstracts: abstract_batch,
self._targets: targets,
self._article_lens: article_lens,
self._abstract_lens: abstract_lens,
self._loss_weights: loss_weights})
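  # A hedged usage sketch (the batcher and its NextBatch() method are
  # assumptions, not part of this gist). One training iteration driven from
  # the outside would look roughly like:
  #   (article_batch, abstract_batch, targets, article_lens, abstract_lens,
  #    loss_weights) = batcher.NextBatch()
  #   _, summaries, loss, step = model.run_train_step(
  #       sess, article_batch, abstract_batch, targets,
  #       article_lens, abstract_lens, loss_weights)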
def _add_placeholders(self):
"""Inputs to be fed to the graph."""
hps = self._hps
self._articles = tf.placeholder(tf.int32,
[hps.batch_size, hps.enc_timesteps],
name='articles')
self._abstracts = tf.placeholder(tf.int32,
[hps.batch_size, hps.dec_timesteps],
name='abstracts')
self._targets = tf.placeholder(tf.int32,
[hps.batch_size, hps.dec_timesteps],
name='targets')
self._article_lens = tf.placeholder(tf.int32, [hps.batch_size],
name='article_lens')
self._abstract_lens = tf.placeholder(tf.int32, [hps.batch_size],
name='abstract_lens')
self._loss_weights = tf.placeholder(tf.float32,
[hps.batch_size, hps.dec_timesteps],
name='loss_weights')
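  # With the parameter values listed at the top, the feeds have these shapes:
  #   articles: [4, 120] int32        abstracts / targets: [4, 30] int32
  #   article_lens / abstract_lens: [4] int32
  #   loss_weights: [4, 30] float32 (presumably 1.0 for real tokens and 0.0
  #   for padding, so padded positions do not contribute to the loss)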
#####################################################################################################################
# Complete network
#####################################################################################################################
  def _add_seq2seq(self):
hps = self._hps
vsize = self._vocab.NumIds() # Number of words in the vocabulary
with tf.variable_scope('seq2seq'):
encoder_inputs = tf.unpack(tf.transpose(self._articles))
decoder_inputs = tf.unpack(tf.transpose(self._abstracts))
targets = tf.unpack(tf.transpose(self._targets))
loss_weights = tf.unpack(tf.transpose(self._loss_weights))
article_lens = self._article_lens
with tf.variable_scope('embedding'), tf.device('/cpu:0'):
        # Create one large embedding matrix and fill it with small random values.
embedding = tf.get_variable(
'embedding', [vsize, hps.emb_dim], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=1e-4))
# Look up each id in the large matrix, for each symbol in the encoder_inputs and decoder_inputs
emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
for x in encoder_inputs]
emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
for x in decoder_inputs]
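        # A toy illustration of the lookup (hypothetical values, not part of
        # the model): each int32 id simply indexes a row of the matrix.
        #   embedding = [[0.1, 0.2],   # id 0
        #                [0.3, 0.4],   # id 1
        #                [0.5, 0.6]]   # id 2
        #   tf.nn.embedding_lookup(embedding, [2, 0]) -> [[0.5, 0.6], [0.1, 0.2]]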
#####################################################################################################################
# Encoder network
#####################################################################################################################
      # Stack hps.enc_layers (here 4) bidirectional LSTM layers for the encoder.
for layer_i in xrange(hps.enc_layers):
with tf.variable_scope('encoder%d'%layer_i), tf.device(
self._next_device()):
cell_fw = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
state_is_tuple=False)
cell_bw = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
state_is_tuple=False)
(emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
sequence_length=article_lens)
      encoder_outputs = emb_encoder_inputs  # Output of the top encoder layer
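      # Note: each bidirectional layer concatenates its forward and backward
      # outputs, so every element of the list now has shape
      # [batch_size, 2 * num_hidden]; feeding the list back in as the next
      # layer's input is what stacks the layers.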
#####################################################################################################################
# decoder network
#####################################################################################################################
      with tf.variable_scope('output_projection'):
        # (w, v) project decoder outputs to vocabulary logits.
        w = tf.get_variable('w', [hps.num_hidden, vsize], dtype=tf.float32,
                            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        v = tf.get_variable('v', [vsize], dtype=tf.float32,
                            initializer=tf.truncated_normal_initializer(stddev=1e-4))
      with tf.variable_scope('decoder'), tf.device(self._next_device()):
        # When decoding, use model output from the previous step
        # for the next step.
        loop_function = None
        if hps.mode == 'decode':
          loop_function = _extract_argmax_and_embed(
              embedding, (w, v), update_embedding=False)
cell = tf.nn.rnn_cell.LSTMCell(
hps.num_hidden,
initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
state_is_tuple=False)
encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
for x in encoder_outputs]
self._enc_top_states = tf.concat(1, encoder_outputs)
self._dec_in_state = fw_state
        # During decoding, subsequent _dec_in_state values are fed in by
        # beam_search; _dec_out_state is stored by beam_search for the next step.
initial_state_attention = (hps.mode == 'decode')
"""
tf.contrib.legacy_seq2seq.attention_decoder(decoder_inputs, initial_state, attention_states, cell,
output_size=None, num_heads=1, loop_function=None, dtype=None, scope=None, initial_state_attention=False)
Attention here means that during decoding, the RNN can look up for additional information from the attention_states
tensor.
"""
decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder(
emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
cell, num_heads=1, loop_function=loop_function,
initial_state_attention=initial_state_attention)
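        # What attention_decoder computes per step, sketched informally (the
        # names below are illustrative, not the library's internals):
        #   scores[t] = v_a . tanh(W1 * enc_state[t] + W2 * dec_state)
        #   alpha     = softmax(scores)                 # attention weights
        #   context   = sum_t alpha[t] * enc_state[t]   # weighted encoder read
        #   output    = linear([cell_output, context])  # goes to the projection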
      # Project each decoder output: multiply by w, add v, giving
      # [batch_size, vsize] logits per step.
with tf.variable_scope('output'), tf.device(self._next_device()):
model_outputs = []
for i in xrange(len(decoder_outputs)):
if i > 0:
tf.get_variable_scope().reuse_variables()
model_outputs.append(tf.nn.xw_plus_b(decoder_outputs[i], w, v))
if hps.mode == 'decode':
with tf.variable_scope('decode_output'), tf.device('/cpu:0'):
best_outputs = [tf.argmax(x, 1) for x in model_outputs]
tf.logging.info('best_outputs%s', best_outputs[0].get_shape())
self._outputs = tf.concat(
1, [tf.reshape(x, [hps.batch_size, 1]) for x in best_outputs])
          # Keep the top 2 * batch_size (here 8) log-probabilities and ids
          # as candidates for beam search.
          self._topk_log_probs, self._topk_ids = tf.nn.top_k(
              tf.log(tf.nn.softmax(model_outputs[-1])), hps.batch_size*2)
def encode_top_state(self, sess, enc_inputs, enc_len):
"""Return the top states from encoder for decoder.
Args:
sess: tensorflow session.
enc_inputs: encoder inputs of shape [batch_size, enc_timesteps].
enc_len: encoder input length of shape [batch_size]
Returns:
enc_top_states: The top level encoder states.
dec_in_state: The decoder layer initial state.
"""
results = sess.run([self._enc_top_states, self._dec_in_state],
feed_dict={self._articles: enc_inputs,
self._article_lens: enc_len})
return results[0], results[1][0]
def decode_topk(self, sess, latest_tokens, enc_top_states, dec_init_states):
"""Return the topK results and new decoder states."""
feed = {
self._enc_top_states: enc_top_states,
self._dec_in_state:
np.squeeze(np.array(dec_init_states)),
self._abstracts:
np.transpose(np.array([latest_tokens])),
self._abstract_lens: np.ones([len(dec_init_states)], np.int32)}
results = sess.run(
[self._topk_ids, self._topk_log_probs, self._dec_out_state],
feed_dict=feed)
ids, probs, states = results[0], results[1], results[2]
new_states = [s for s in states]
return ids, probs, new_states
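# A hedged sketch of how encode_top_state and decode_topk would be driven by
# beam search; the hypothesis objects (hyps) and their fields are assumptions,
# not shown in this gist. One expansion step looks roughly like:
#   enc_top_states, dec_in_state = model.encode_top_state(sess, enc_inputs, enc_len)
#   latest_tokens = [hyp.latest_token for hyp in hyps]  # one per live hypothesis
#   states = [hyp.state for hyp in hyps]
#   topk_ids, topk_log_probs, new_states = model.decode_topk(
#       sess, latest_tokens, enc_top_states, states)
# Each hypothesis is then extended with its top-k candidate ids, its score
# updated by adding the corresponding topk_log_probs.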