Skip to content

Instantly share code, notes, and snippets.

@edumunozsala
Created October 11, 2020 17:12
Show Gist options
  • Save edumunozsala/60c379e9f8093d349a4405a7e596d876 to your computer and use it in GitHub Desktop.
Save edumunozsala/60c379e9f8093d349a4405a7e596d876 to your computer and use it in GitHub Desktop.
Make predictions for our seq2seq with attention model
def predict_seq2seq_att(input_text, input_max_len, tokenizer_inputs, word2idx_outputs, idx2word_outputs, max_output_len=20):
    """Greedily decode a translation of ``input_text`` with the attention seq2seq model.

    Relies on the module-level ``encoder``, ``decoder`` and ``input_data``
    objects, plus ``tf``, ``np`` and ``pad_sequences`` being in scope.

    Args:
        input_text: Source sentence to translate. When ``None`` a random
            sentence is drawn from the module-level ``input_data``.
        input_max_len: Length the tokenized input is padded to.
        tokenizer_inputs: Fitted tokenizer for the source language.
        word2idx_outputs: Target-language word -> index map (must contain
            the '<sos>' token).
        idx2word_outputs: Target-language index -> word map.
        max_output_len: Maximum number of tokens to generate before giving
            up (default 20, matching the previous hard-coded limit).

    Returns:
        Tuple of (attention alignment matrices as an ``np.ndarray``, the
        source sentence split on spaces, the generated target tokens —
        including the terminating '<eos>' when one was produced).
    """
    if input_text is None:
        # No sentence supplied: sample a random one from the inputs.
        input_text = input_data[np.random.choice(len(input_data))]
    print(input_text)
    # Tokenize the input text and pad it to the encoder's expected length.
    input_seq = tokenizer_inputs.texts_to_sequences([input_text])
    input_seq = pad_sequences(input_seq, maxlen=input_max_len, padding='post')
    # Run the encoder on a batch of one sentence.
    en_initial_states = encoder.init_states(1)
    en_outputs = encoder(tf.constant(input_seq), en_initial_states)
    # Seed the decoder with the start-of-sequence token; its initial
    # hidden/cell states come from the encoder's final states.
    de_input = tf.constant([[word2idx_outputs['<sos>']]])
    de_state_h, de_state_c = en_outputs[1:]
    out_words = []
    alignments = []
    while True:
        # One greedy decoding step, attending over the encoder outputs.
        de_output, de_state_h, de_state_c, alignment = decoder(
            de_input, (de_state_h, de_state_c), en_outputs[0])
        # Feed the argmax token back in as the next decoder input.
        de_input = tf.expand_dims(tf.argmax(de_output, -1), 0)
        # Detokenize the predicted index.
        out_words.append(idx2word_outputs[de_input.numpy()[0][0]])
        # Keep the attention weights so the caller can plot an alignment map.
        alignments.append(alignment.numpy())
        # Stop on end-of-sequence or once the length cap is reached.
        if out_words[-1] == '<eos>' or len(out_words) >= max_output_len:
            break
    # Join and show the generated sentence.
    print(' '.join(out_words))
    return np.array(alignments), input_text.split(' '), out_words
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment