Skip to content

Instantly share code, notes, and snippets.

@ChunML
Created April 28, 2019 14:17
Show Gist options
  • Save ChunML/166784ce5a8448e901de879a8cc70a3e to your computer and use it in GitHub Desktop.
def predict(test_source_text=None, max_length=14):
    """Translate one English sentence to French with the trained model.

    Greedy decoding: the encoder is run once over the source sequence,
    then the decoder is called repeatedly, each time appending the
    argmax token of the last position to the decoder input.

    Args:
        test_source_text: English sentence to translate. If None, a
            sentence is drawn at random from the training data
            (module-level ``raw_data_en``).
        max_length: maximum number of tokens to generate before giving
            up (was previously hard-coded to 14).

    Side effects:
        Prints the source text, its token sequence, and the decoded
        French sentence. Returns None.
    """
    # If no test sentence is provided, randomly pick one from the
    # training data.
    if test_source_text is None:
        test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]
    print(test_source_text)

    # Tokenize the test sentence to obtain the source sequence.
    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
    print(test_source_seq)

    # Encode once; the encoder output is reused at every decoding step.
    en_output = encoder(tf.constant(test_source_seq))

    # Seed the decoder input with the <start> token.
    de_input = tf.constant(
        [[fr_tokenizer.word_index['<start>']]], dtype=tf.int64)

    out_words = []
    while True:
        de_output = decoder(de_input, en_output)
        # Greedy step: take the argmax over the vocabulary at the last
        # time step as the newly predicted token.
        new_word = tf.expand_dims(tf.argmax(de_output, -1)[:, -1], axis=1)
        out_words.append(fr_tokenizer.index_word[new_word.numpy()[0][0]])
        # The next decoder input is the whole sequence so far plus the
        # newly predicted token.
        de_input = tf.concat((de_input, new_word), axis=-1)
        # Stop on the <end> token or when the length limit is reached.
        if out_words[-1] == '<end>' or len(out_words) >= max_length:
            break
    print(' '.join(out_words))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment