@HarshSingh16
Created February 5, 2019 07:16
### Decoding the Training Set
import tensorflow as tf  # TensorFlow 1.x (uses the tf.contrib.seq2seq API)

def decode_training_set(encoder_state, decoder_cell, decoder_embedded_input, sequence_length,
                        decoding_scope, output_function, keep_prob, batch_size):
    """Decodes the training set with Bahdanau attention and returns the decoder logits."""
    # Initialize the attention states to zeros: [batch_size, 1, hidden_size]
    attention_states = tf.zeros([batch_size, 1, decoder_cell.output_size])
    # Prepare the keys, values, and score/construct functions for Bahdanau attention
    attention_keys, attention_values, attention_score_function, attention_construct_function = \
        tf.contrib.seq2seq.prepare_attention(attention_states,
                                             attention_option="bahdanau",
                                             num_units=decoder_cell.output_size)
    # Build the training decoder function from the final encoder state
    training_decoder_function = tf.contrib.seq2seq.attention_decoder_fn_train(encoder_state[0],
                                                                              attention_keys,
                                                                              attention_values,
                                                                              attention_score_function,
                                                                              attention_construct_function,
                                                                              name="attn_dec_train")
    # Run the dynamic RNN decoder over the embedded target inputs
    decoder_output, _, _ = tf.contrib.seq2seq.dynamic_rnn_decoder(decoder_cell,
                                                                  training_decoder_function,
                                                                  decoder_embedded_input,
                                                                  sequence_length,
                                                                  scope=decoding_scope)
    # Apply dropout before the final projection to the vocabulary
    decoder_output_dropout = tf.nn.dropout(decoder_output, keep_prob)
    return output_function(decoder_output_dropout)
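For intuition about what `prepare_attention` with `attention_option="bahdanau"` computes under the hood, here is a minimal NumPy sketch of Bahdanau (additive) attention scoring. All names, shapes, and weight matrices here are hypothetical illustrations, not part of the TensorFlow API used above:

```python
import numpy as np

def bahdanau_attention(decoder_state, encoder_states, W1, W2, v):
    """Additive attention: score_t = v . tanh(W1 @ enc_t + W2 @ dec)."""
    # encoder_states: [T, H]; decoder_state: [H]
    scores = np.tanh(encoder_states @ W1.T + decoder_state @ W2.T) @ v  # [T]
    # Softmax over time steps gives the attention weights
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # Context vector: weighted sum of encoder states
    context = weights @ encoder_states  # [H]
    return weights, context

rng = np.random.default_rng(0)
T, H = 5, 8  # hypothetical sequence length and hidden size
encoder_states = rng.normal(size=(T, H))
decoder_state = rng.normal(size=(H,))
W1, W2 = rng.normal(size=(H, H)), rng.normal(size=(H, H))
v = rng.normal(size=(H,))
weights, context = bahdanau_attention(decoder_state, encoder_states, W1, W2, v)
```

The weights sum to 1 and the context vector has the encoder hidden size; in the gist, the decoder consumes this context at each step via `attention_construct_function`.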