Skip to content

Instantly share code, notes, and snippets.

@HarshSingh16
Created February 6, 2019 00:15
Show Gist options
  • Save HarshSingh16/504bffca38ba936e3e1f75cef9c65102 to your computer and use it in GitHub Desktop.
Save HarshSingh16/504bffca38ba936e3e1f75cef9c65102 to your computer and use it in GitHub Desktop.
# Decoding the test/validation set
def decode_test_set(encoder_state, decoder_cell, decoder_embeddings_matrix, sos_id, eos_id, maximum_length, num_words, decoding_scope, output_function, keep_prob, batch_size):
    """Build the attention decoder graph used at test/validation (inference) time.

    Args:
        encoder_state: Encoder final state. Only ``encoder_state[0]`` is
            handed to the decoder function (presumably the forward-direction
            state of a bidirectional encoder -- TODO confirm with the caller).
        decoder_cell: RNN cell run by the decoder. Its ``output_size`` is used
            both as the depth of the attention memory and as ``num_units``
            for ``prepare_attention``.
        decoder_embeddings_matrix: Embedding matrix passed to the inference
            decoder function.
        sos_id: Start-of-sequence token id passed to the decoder function.
        eos_id: End-of-sequence token id passed to the decoder function.
        maximum_length: Maximum decode length passed to the decoder function.
        num_words: Passed as the decoder function's number-of-symbols argument
            (presumably the target vocabulary size -- verify against caller).
        decoding_scope: Variable scope handed to ``dynamic_rnn_decoder``.
        output_function: Output function handed to the decoder function.
        keep_prob: Accepted for signature compatibility; not used here.
        batch_size: Leading dimension of the attention memory tensor.

    Returns:
        The first element returned by ``tf.contrib.seq2seq.dynamic_rnn_decoder``
        (the decoder outputs). The final state and final context state are
        discarded.

    NOTE(review): the attention memory is an all-zeros tensor of shape
    ``[batch_size, 1, decoder_cell.output_size]``, not real encoder outputs,
    so attention presumably contributes only a constant context. Verify
    whether encoder outputs were intended. Any change must stay consistent with
    the training decoder built under the same scope.
    NOTE(review): ``prepare_attention``, ``attention_decoder_fn_inference``
    and ``dynamic_rnn_decoder`` are early TF 1.x ``tf.contrib`` APIs that later
    releases removed. Confirm the pinned TensorFlow version.
    """
    memory_depth = decoder_cell.output_size
    # Placeholder attention memory: zeros, a single time step deep.
    zero_memory = tf.zeros([batch_size, 1, memory_depth])

    (attn_keys,
     attn_values,
     attn_score_fn,
     attn_construct_fn) = tf.contrib.seq2seq.prepare_attention(
        zero_memory,
        attention_option="bahdanau",
        num_units=memory_depth,
    )

    # Arguments are positional, in the order attention_decoder_fn_inference
    # expects them.
    inference_fn = tf.contrib.seq2seq.attention_decoder_fn_inference(
        output_function,
        encoder_state[0],  # only the first element of the encoder state is used
        attn_keys,
        attn_values,
        attn_score_fn,
        attn_construct_fn,
        decoder_embeddings_matrix,
        sos_id,
        eos_id,
        maximum_length,
        num_words,
        name="attn_dec_inf",
    )

    # dynamic_rnn_decoder yields three values. Unpacking (rather than
    # indexing) keeps the original's arity check. Only the outputs are returned.
    predictions, _final_state, _final_context_state = tf.contrib.seq2seq.dynamic_rnn_decoder(
        decoder_cell,
        inference_fn,
        scope=decoding_scope,
    )
    return predictions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment