Skip to content

Instantly share code, notes, and snippets.

@aravindpai
Created May 27, 2019 06:50
Show Gist options
  • Save aravindpai/a9485ae25f1686f5a531abfc7b669766 to your computer and use it in GitHub Desktop.
Save aravindpai/a9485ae25f1686f5a531abfc7b669766 to your computer and use it in GitHub Desktop.
# Builds a sequence-to-sequence model: a 3-layer stacked LSTM encoder, a
# single-LSTM decoder initialised from the encoder's final states, and an
# attention layer over the encoder outputs.
#
# NOTE: this snippet does not define or import Input, Embedding, LSTM,
# Concatenate, TimeDistributed, Dense, Model, AttentionLayer, max_len_text,
# x_voc_size or y_voc_size; they must already be in scope when it runs.
from keras import backend as K

# Drop any previously built graph/state before defining a new model.
K.clear_session()

# Used both as the embedding width and as the number of units in every LSTM.
latent_dim = 500

# ---------------------------------------------------------------- Encoder
encoder_inputs = Input(shape=(max_len_text,))
enc_emb = Embedding(x_voc_size, latent_dim, trainable=True)(encoder_inputs)

# LSTM 1
encoder_lstm1 = LSTM(latent_dim, return_sequences=True, return_state=True)
encoder_output1, state_h1, state_c1 = encoder_lstm1(enc_emb)

# LSTM 2
encoder_lstm2 = LSTM(latent_dim, return_sequences=True, return_state=True)
encoder_output2, state_h2, state_c2 = encoder_lstm2(encoder_output1)

# LSTM 3 -- its per-timestep outputs feed the attention layer and its final
# states initialise the decoder.
encoder_lstm3 = LSTM(latent_dim, return_state=True, return_sequences=True)
encoder_outputs, state_h, state_c = encoder_lstm3(encoder_output2)

# ---------------------------------------------------------------- Decoder
# The time dimension is left as None, so the decoder input length is not fixed.
decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(y_voc_size, latent_dim, trainable=True)
dec_emb = dec_emb_layer(decoder_inputs)

# LSTM using the last encoder layer's states as its initial state.
# NOTE(review): despite their names, decoder_fwd_state / decoder_back_state
# are the hidden and cell states of this (unidirectional) LSTM; the names
# are kept unchanged in case later code refers to them.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, decoder_fwd_state, decoder_back_state = decoder_lstm(
    dec_emb, initial_state=[state_h, state_c]
)

# Attention layer.
# The original line read "Attention layer attn_layer = ...", i.e. stray
# comment text had been fused onto the statement, which is a SyntaxError.
attn_layer = AttentionLayer(name='attention_layer')
attn_out, attn_states = attn_layer([encoder_outputs, decoder_outputs])

# Concatenate the attention output and the decoder LSTM output along the
# last axis.
decoder_concat_input = Concatenate(axis=-1, name='concat_layer')(
    [decoder_outputs, attn_out]
)

# Dense softmax over the target vocabulary, applied at every timestep.
decoder_dense = TimeDistributed(Dense(y_voc_size, activation='softmax'))
decoder_outputs = decoder_dense(decoder_concat_input)

# Define the model: (encoder input, decoder input) -> per-timestep
# distribution over the target vocabulary.
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.summary()
@PahadiASD
Copy link

#Attention Layer Attention layer attn_layer = AttentionLayer(name='attention_layer') attn_out, attn_states = attn_layer([encoder_outputs, decoder_outputs])

is not working; I am getting an invalid syntax error.

Have you got any answer? Please help.

@AbdikadarM
Copy link

How can I change the above code to use GRU instead of LSTM?

@AbdikadarM
Copy link

I need this. Can anyone help me with this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment