Skip to content

Instantly share code, notes, and snippets.

@ayushidalmia
Created December 15, 2017 14:37
Show Gist options
  • Save ayushidalmia/bd3e039ed836f1eae50db18a07b752a0 to your computer and use it in GitHub Desktop.
Save ayushidalmia/bd3e039ed836f1eae50db18a07b752a0 to your computer and use it in GitHub Desktop.
def decoder(self, encoder_outputs, dropout, mode):
    """Run the target-side decoder and return its outputs.

    In every mode other than ``"INFER"`` the decoder is teacher-forced: the
    reference target tokens ``self.tgt_text`` are embedded with
    ``self.tgt_embedding`` and fed step by step through a ``TrainingHelper``.
    In ``"INFER"`` mode a ``BeamSearchDecoder`` generates tokens starting from
    ``self.start_token`` until ``self.end_token`` or an iteration cap.

    Args:
        encoder_outputs: handed to ``self.decoder_cell`` and used (beam-tiled
            in inference) as the decoder's initial state. NOTE(review): the
            name says "outputs" but the value is consumed as an RNN state --
            confirm the caller passes the encoder's final state.
        dropout: forwarded to ``self.decoder_cell`` for cell construction.
        mode: mode string; only ``"INFER"`` is special-cased.

    Returns:
        A tuple ``(logits, sample_id, final_context_state)``. Outside
        ``"INFER"``, ``logits`` is ``self.output_layer`` applied to the RNN
        outputs and ``sample_id`` is the helper's ``sample_id``. In
        ``"INFER"`` mode ``logits`` is a ``tf.no_op()`` placeholder (an
        Operation, not a tensor) and ``sample_id`` is the beam search's
        ``predicted_ids``.
    """
    cell, decoder_initial_state = self.decoder_cell(encoder_outputs, dropout, mode)
    if mode != "INFER":
        tgt_text = self.tgt_text
        if self.time_major:
            # Presumably tgt_text is (batch, time); transpose to (time, batch)
            # so the embedded inputs are time-major -- TODO confirm layout.
            tgt_text = tf.transpose(self.tgt_text)
        inputs = tf.nn.embedding_lookup(self.tgt_embedding, tgt_text)
        # The helper must be told the actual layout of `inputs`: hard-coding
        # time_major=True mis-slices batch-major inputs when self.time_major
        # is False.
        helper = tf.contrib.seq2seq.TrainingHelper(
            inputs, self.tgt_sequence_length, time_major=self.time_major)
        basic_decoder = tf.contrib.seq2seq.BasicDecoder(
            cell, helper, decoder_initial_state)
        outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
            basic_decoder, output_time_major=self.time_major)
        sample_id = outputs.sample_id
        # The projection is applied once after decoding; BasicDecoder above
        # is constructed without an output_layer.
        logits = self.output_layer(outputs.rnn_output)
        # TODO: add attention.
    else:
        start_tokens = tf.fill([self.batch_size], self.start_token)
        end_token = self.end_token
        # Cap generation length at twice the longest source sequence.
        maximum_iterations = tf.round(tf.reduce_max(self.src_sequence_length) * 2)
        # Only beam search is implemented. decoder_cell tiles the initial
        # state only when beam_width != 0, so a non-zero beam_width is
        # presumably required here -- TODO confirm.
        beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell=cell,
            embedding=self.tgt_embedding,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=decoder_initial_state,
            beam_width=self.beam_width,
            output_layer=self.output_layer)
        outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(
            beam_decoder,
            maximum_iterations=maximum_iterations,
            output_time_major=self.time_major)
        # No logits are produced during inference; return a placeholder op.
        logits = tf.no_op()
        sample_id = outputs.predicted_ids
    return logits, sample_id, final_context_state
def decoder_cell(self, encoder_state, dropout, mode):
    """Create the decoder RNN cell and pick the state it starts from.

    The cell comes from ``self.create_rnn_cell(dropout)``. When ``mode`` is
    ``"INFER"`` and ``self.beam_width`` is non-zero, ``encoder_state`` is
    repeated ``self.beam_width`` times with
    ``tf.contrib.seq2seq.tile_batch`` so a beam-search decoder can consume
    it. In every other case ``encoder_state`` is returned unchanged.

    Returns:
        A tuple ``(cell, decoder_initial_state)``.
    """
    rnn_cell = self.create_rnn_cell(dropout)
    needs_tiling = mode == "INFER" and self.beam_width != 0
    if not needs_tiling:
        return rnn_cell, encoder_state
    tiled_state = tf.contrib.seq2seq.tile_batch(
        encoder_state, multiplier=self.beam_width)
    return rnn_cell, tiled_state
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment