Created
May 8, 2017 04:47
-
-
Save 0xpantera/03b0c76885c4f9f4cf84e77470349808 to your computer and use it in GitHub Desktop.
First dlnd language translation submission
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def decoding_layer(dec_embed_input, dec_embeddings, encoder_state, vocab_size, sequence_length, rnn_size,
                   num_layers, target_vocab_to_int, keep_prob):
    """
    Create decoding layer

    Builds a stacked LSTM cell, a linear projection to ``vocab_size`` and then
    the training and inference decoders, both under the "decoding" variable
    scope (the inference decoder enters it with ``reuse=True``).

    :param dec_embed_input: Decoder embedded input
    :param dec_embeddings: Decoder embeddings
    :param encoder_state: The encoded state
    :param vocab_size: Size of vocabulary
    :param sequence_length: Sequence Length
    :param rnn_size: RNN Size
    :param num_layers: Number of layers
    :param target_vocab_to_int: Dictionary to go from the target words to an id
    :param keep_prob: Dropout keep probability
    :return: Tuple of (Training Logits, Inference Logits)
    """
    def make_layer():
        # One LSTM + dropout wrapper per call, so every layer of the stack
        # owns a distinct cell object.
        lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)
        return tf.contrib.rnn.DropoutWrapper(lstm, keep_prob)

    # Previously `[dropout] * num_layers` put the *same* cell object in every
    # layer; TensorFlow releases after 1.0 reject or weight-share such a stack.
    cell = tf.contrib.rnn.MultiRNNCell([make_layer() for _ in range(num_layers)])

    with tf.variable_scope("decoding") as decoding_scope:
        # Output layer
        # `decoding_scope` is looked up when `output` is *called*, not when it
        # is defined, so each decoder below projects inside the scope bound by
        # its own `with` block (the inference one has reuse=True).
        # NOTE(review): nothing is built inside this first scope; it is kept
        # so graph name scopes stay as before - confirm before removing.
        def output(x):
            return tf.contrib.layers.fully_connected(x, vocab_size, None, scope=decoding_scope)

    with tf.variable_scope("decoding") as decoding_scope:
        # Training decoder
        training_logits = decoding_layer_train(encoder_state, cell, dec_embed_input, sequence_length,
                                               decoding_scope, output, keep_prob)

    with tf.variable_scope("decoding", reuse=True) as decoding_scope:
        # Inference decoder
        inference_logits = decoding_layer_infer(encoder_state, cell, dec_embeddings, target_vocab_to_int['<GO>'],
                                                target_vocab_to_int["<EOS>"], sequence_length, vocab_size,
                                                decoding_scope, output, keep_prob)

    return (training_logits, inference_logits)
# Runs the project-provided check against decoding_layer; `tests` is expected
# to be imported elsewhere in the notebook/file.
"""
DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE
"""
tests.test_decoding_layer(decoding_layer)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment