Created
May 8, 2017 04:48
-
-
Save 0xpantera/867886274cee90fe61b23aab2bb3354b to your computer and use it in GitHub Desktop.
Second DLND project submission, after applying the suggested changes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def decoding_layer(dec_embed_input, dec_embeddings, encoder_state, vocab_size, sequence_length, rnn_size,
                   num_layers, target_vocab_to_int, keep_prob):
    """
    Create decoding layer
    :param dec_embed_input: Decoder embedded input
    :param dec_embeddings: Decoder embeddings
    :param encoder_state: The encoded state
    :param vocab_size: Size of vocabulary
    :param sequence_length: Sequence Length
    :param rnn_size: RNN Size
    :param num_layers: Number of layers
    :param target_vocab_to_int: Dictionary to go from the target words to an id
    :param keep_prob: Dropout keep probability
    :return: Tuple of (Training Logits, Inference Logits)
    """
    def make_cell():
        # Build a brand-new LSTM + dropout cell. Each stacked layer must be a
        # distinct cell object: ``[cell] * num_layers`` puts the *same* object in
        # every slot, so layers do not get independent weights, and newer TF 1.x
        # releases raise a variable-reuse error for it.
        lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)
        return tf.contrib.rnn.DropoutWrapper(lstm, input_keep_prob=keep_prob)

    cell = tf.contrib.rnn.MultiRNNCell([make_cell() for _ in range(num_layers)])

    with tf.variable_scope("decoding") as decoding_scope:
        def output_fn(x):
            # Linear projection (activation_fn=None) from RNN outputs to vocab
            # logits, created under decoding_scope so the inference decoder can
            # reuse the same weights.
            return tf.contrib.layers.fully_connected(x, vocab_size, None, scope=decoding_scope)

        # Training decoder
        training_logits = decoding_layer_train(encoder_state, cell, dec_embed_input, sequence_length,
                                               decoding_scope, output_fn, keep_prob)

        # The inference decoder must share the variables created above.
        decoding_scope.reuse_variables()

        # Inference decoder
        inference_logits = decoding_layer_infer(encoder_state, cell, dec_embeddings, target_vocab_to_int['<GO>'],
                                                target_vocab_to_int['<EOS>'], sequence_length, vocab_size,
                                                decoding_scope, output_fn, keep_prob)
    return (training_logits, inference_logits)
# Course-provided footer of the notebook cell: the marker string below and the
# unit-test call are supplied by the assignment and run the project's checks
# against decoding_layer.
"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_decoding_layer(decoding_layer)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment