Skip to content

Instantly share code, notes, and snippets.

@bryanlimy
Created May 20, 2019 15:12
Show Gist options
  • Save bryanlimy/106a54bb573cbd2819aa1eb651d4c473 to your computer and use it in GitHub Desktop.
def evaluate(sentence):
    """Greedily decode a token-id sequence responding to *sentence*.

    The input is preprocessed, subword-encoded, and wrapped with the
    start/end tokens; the transformer `model` is then run autoregressively,
    taking the argmax token at each step until END_TOKEN is produced or
    MAX_LENGTH steps have elapsed.

    NOTE(review): relies on module-level names defined elsewhere in the
    original script — `preprocess_sentence`, `tf` (TensorFlow), `tokenizer`
    (presumably a SubwordTextEncoder), `model`, `START_TOKEN`, `END_TOKEN`
    and `MAX_LENGTH` — confirm against the full tutorial source.

    Args:
        sentence: raw input string.

    Returns:
        A rank-1 int32 tensor of predicted token ids, including the
        leading START_TOKEN but excluding the END_TOKEN.
    """
    sentence = preprocess_sentence(sentence)
    # Encode to subword ids, add start/end markers, and make a batch of 1.
    sentence = tf.expand_dims(
        START_TOKEN + tokenizer.encode(sentence) + END_TOKEN, axis=0)

    # Decoder input starts with only the start token (batch of 1).
    output = tf.expand_dims(START_TOKEN, 0)

    for _ in range(MAX_LENGTH):
        predictions = model(inputs=[sentence, output], training=False)

        # Select the last position from the seq_len dimension — the
        # model's prediction for the next token.
        predictions = predictions[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        # Stop as soon as the model emits the end token.
        if tf.equal(predicted_id, END_TOKEN[0]):
            break

        # Append the predicted id; the grown sequence is fed back to the
        # decoder on the next iteration.
        output = tf.concat([output, predicted_id], axis=-1)

    # Drop the batch dimension before returning.
    return tf.squeeze(output, axis=0)
def predict(sentence):
    """Return the model's text reply to *sentence*.

    Decodes the token ids produced by `evaluate`, discarding any id at or
    above `tokenizer.vocab_size` (the start/end tokens live in that range
    and cannot be mapped back to text by `tokenizer.decode`).

    NOTE(review): relies on the sibling `evaluate` and the module-level
    `tokenizer` defined elsewhere in the original script.

    Args:
        sentence: raw input string.

    Returns:
        The decoded response as a string.
    """
    prediction = evaluate(sentence)
    # Keep only in-vocabulary ids before decoding back to text.
    predicted_sentence = tokenizer.decode(
        [i for i in prediction if i < tokenizer.vocab_size])
    return predicted_sentence
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment