Created
October 11, 2020 17:12
-
-
Save edumunozsala/60c379e9f8093d349a4405a7e596d876 to your computer and use it in GitHub Desktop.
Make predictions for our seq2seq with attention model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict_seq2seq_att(input_text, input_max_len, tokenizer_inputs,
                        word2idx_outputs, idx2word_outputs, max_output_len=20):
    """Greedily decode one sentence with the seq2seq-with-attention model.

    Relies on module-level globals defined elsewhere in this file:
    ``encoder``, ``decoder`` (the trained models), ``input_data`` (corpus used
    when ``input_text`` is None), plus ``tf``, ``np`` and ``pad_sequences``.

    Args:
        input_text: Source sentence to translate, or None to pick a random
            sentence from the global ``input_data``.
        input_max_len: Length the tokenized input is padded to.
        tokenizer_inputs: Fitted Keras tokenizer for the source language.
        word2idx_outputs: Target-language word -> index mapping; must contain
            the ``'<sos>'`` token.
        idx2word_outputs: Target-language index -> word mapping.
        max_output_len: Maximum number of tokens to decode before stopping
            (default 20, matching the original hard-coded cap).

    Returns:
        Tuple of (alignment matrices as an ``np.ndarray``, the input sentence
        split on spaces, the list of decoded output words).
    """
    if input_text is None:
        # No sentence given: sample one from the training/input corpus.
        input_text = input_data[np.random.choice(len(input_data))]
    print(input_text)
    # Tokenize the input text
    input_seq = tokenizer_inputs.texts_to_sequences([input_text])
    # Pad the sentence
    input_seq = pad_sequences(input_seq, maxlen=input_max_len, padding='post')
    # Get the encoder initial states (batch size 1)
    en_initial_states = encoder.init_states(1)
    # Get the encoder outputs or hidden states
    en_outputs = encoder(tf.constant(input_seq), en_initial_states)
    # Set the decoder input to the sos token
    de_input = tf.constant([[word2idx_outputs['<sos>']]])
    # Seed the decoder state with the encoder's final (h, c) states
    de_state_h, de_state_c = en_outputs[1:]
    out_words = []
    alignments = []
    while True:
        # One decode step: feed back previous token, encoder outputs for attention
        de_output, de_state_h, de_state_c, alignment = decoder(
            de_input, (de_state_h, de_state_c), en_outputs[0])
        # Greedy choice: argmax over the vocabulary dimension
        de_input = tf.expand_dims(tf.argmax(de_output, -1), 0)
        # Detokenize the output
        out_words.append(idx2word_outputs[de_input.numpy()[0][0]])
        # Save the alignment matrix for later attention visualization
        alignments.append(alignment.numpy())
        # Stop at end-of-sentence or when the length cap is reached
        if out_words[-1] == '<eos>' or len(out_words) >= max_output_len:
            break
    # Join the output words
    print(' '.join(out_words))
    return np.array(alignments), input_text.split(' '), out_words
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment