Last active
May 27, 2019 07:03
# Encoder inference model: maps the input sequence to the encoder outputs and the final LSTM states
encoder_model = Model(inputs=encoder_inputs, outputs=[encoder_outputs, state_h, state_c])

# Decoder inference
# These input tensors hold the decoder LSTM states from the previous time step
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
# Input for the sequence of hidden states that the attention layer attends over
decoder_hidden_state_input = Input(shape=(max_len_text, latent_dim))

# Get the embeddings of the decoder sequence
dec_emb2 = dec_emb_layer(decoder_inputs)

# To predict the next word in the sequence, set the initial states to the states from the previous time step
decoder_outputs2, state_h2, state_c2 = decoder_lstm(dec_emb2, initial_state=[decoder_state_input_h, decoder_state_input_c])

# Attention inference
attn_out_inf, attn_states_inf = attn_layer([decoder_hidden_state_input, decoder_outputs2])
decoder_inf_concat = Concatenate(axis=-1, name='concat')([decoder_outputs2, attn_out_inf])

# A dense softmax layer to generate a probability distribution over the target vocabulary
decoder_outputs2 = decoder_dense(decoder_inf_concat)

# Final decoder model: takes the previous token, the attended hidden states and the previous LSTM states,
# and returns the next-token distribution together with the updated LSTM states
decoder_model = Model(
    [decoder_inputs] + [decoder_hidden_state_input, decoder_state_input_h, decoder_state_input_c],
    [decoder_outputs2] + [state_h2, state_c2])
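
The two inference models above are typically driven by a step-by-step decoding loop. The following is a minimal sketch of greedy decoding, not the author's own code. The function name `greedy_decode` and the parameters `start_id`, `end_id` and `max_len` are hypothetical and do not appear in the snippet. It assumes that `decoder_model` takes its inputs in the order defined above and returns a probability tensor of shape `(1, 1, vocab_size)`.

```python
import numpy as np


def greedy_decode(encoder_model, decoder_model, input_seq,
                  start_id, end_id, max_len):
    """Greedy decoding sketch; start_id, end_id and max_len are assumed names."""
    # Run the encoder once: hidden states for attention plus the final LSTM states.
    enc_out, h, c = encoder_model.predict(input_seq)

    # The decoder is fed one token at a time as a (1, 1) batch of token ids.
    target_seq = np.array([[start_id]])
    decoded = []
    for _ in range(max_len):
        # Input order mirrors decoder_model: token, attended states, h, c.
        probs, h, c = decoder_model.predict([target_seq, enc_out, h, c])
        # Pick the most probable token at the last (only) time step.
        next_id = int(np.argmax(probs[0, -1, :]))
        if next_id == end_id:
            break
        decoded.append(next_id)
        # Feed the predicted token back in as the next decoder input.
        target_seq = np.array([[next_id]])
    return decoded
```

The function returns a list of token ids, so mapping them back to words requires the target tokenizer's index-to-word table, which is not part of this snippet. Greedy search is only one option; beam search could replace the `argmax` step.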