Created
April 28, 2019 14:17
-
-
Save ChunML/166784ce5a8448e901de879a8cc70a3e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict(test_source_text=None, max_length=14):
    """Translate one source sentence by greedy decoding and print the result.

    The sentence is tokenized with ``en_tokenizer``, passed through
    ``encoder``, and then ``decoder`` is called repeatedly, each time on the
    target sequence produced so far, until ``<end>`` is predicted or
    ``max_length`` tokens have been generated.

    Args:
        test_source_text: Sentence to translate. When ``None``, a random
            entry of ``raw_data_en`` is used instead.
        max_length: Upper bound on the number of generated tokens. Defaults
            to 14, the limit that used to be hard-coded. At least one token
            is always generated because the limit is checked after each step.

    Side effects:
        Prints the source text, its token-id sequence, and the space-joined
        predicted tokens. Nothing is returned.
    """
    # If test sentence is not provided
    # randomly pick up one from the training data
    if test_source_text is None:
        test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]

    print(test_source_text)

    # Tokenize the test sentence to obtain source sequence
    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
    print(test_source_seq)

    en_output = encoder(tf.constant(test_source_seq))

    # Decoding starts from a batch of one sequence holding only <start>
    de_input = tf.constant(
        [[fr_tokenizer.word_index['<start>']]], dtype=tf.int64)

    out_words = []

    while True:
        de_output = decoder(de_input, en_output)

        # Take the last token as the predicted token
        # NOTE(review): assumes de_output is (batch, seq, vocab) -- confirm
        new_word = tf.expand_dims(tf.argmax(de_output, -1)[:, -1], axis=1)

        # The model can emit an id that has no vocabulary entry (presumably
        # id 0, which a Keras Tokenizer reserves and never assigns to a
        # word); fall back to a placeholder instead of raising KeyError.
        token_id = int(new_word.numpy()[0][0])
        out_words.append(fr_tokenizer.index_word.get(token_id, '<unk>'))

        # The next input is a new sequence
        # contains both the input sequence and the predicted token
        de_input = tf.concat((de_input, new_word), axis=-1)

        # End if hitting <end> or length reaches max_length
        if out_words[-1] == '<end>' or len(out_words) >= max_length:
            break

    print(' '.join(out_words))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment