# encoder-decoder-layers
# Gist by @queirozfcom, created November 21, 2019
# (https://gist.github.com/queirozfcom/20d76e3113c649660df8dc1e59455680).
###################################
############ MODEL ################
###################################
# Encoder-decoder (seq2seq) model: an LSTM encoder reads the keyword
# sequence, and its final hidden/cell states initialise an LSTM decoder
# that predicts a distribution over the title vocabulary at every step.
# During training the decoder is fed `title_sequences` and scored against
# `decoder_target_data` (presumably the same titles shifted by one step,
# i.e. teacher forcing -- TODO confirm against the data-prep code).
#
# Reset Keras' global state (suggested for TensorFlow 2.0) BEFORE entering
# the device scope: in graph mode `clear_session()` resets the default
# graph, and a `tf.device` scope opened on the old graph would then no
# longer apply to the layers created below.
tf.keras.backend.clear_session()

with tf.device("/device:GPU:0"):
    # --- Encoder -------------------------------------------------------
    # Padded integer token ids; id 0 is masked out (mask_zero=True) so the
    # LSTM skips padding positions.
    encoder_inputs = tf.keras.layers.Input(
        shape=(MAX_SEQ_LEN_KW,), name="encoder_input")
    encoder_embedding_layer = tf.keras.layers.Embedding(
        VOCABULARY_SIZE_KW,
        EMBEDDING_DIMS,
        mask_zero=True,
        name="encoder_embedding",
    )
    encoder_embedding = encoder_embedding_layer(encoder_inputs)
    # Only the final LSTM states are kept; the encoder output is discarded.
    _, state_h, state_c = tf.keras.layers.LSTM(
        EMBEDDING_DIMS,
        return_state=True,
        name="encoder_lstm",
    )(encoder_embedding)
    encoder_states = [state_h, state_c]

    # --- Decoder -------------------------------------------------------
    decoder_inputs = tf.keras.layers.Input(
        shape=(MAX_SEQ_LEN_TITLE,), name="decoder_input")
    decoder_embedding_layer = tf.keras.layers.Embedding(
        VOCABULARY_SIZE_TITLE,
        EMBEDDING_DIMS,
        mask_zero=True,
        name="decoder_embedding",
    )
    decoder_embedding = decoder_embedding_layer(decoder_inputs)
    # return_state=True is unused during training, but this same layer
    # object is reused by the inference decoder model, which needs the
    # returned states.
    decoder_lstm = tf.keras.layers.LSTM(
        EMBEDDING_DIMS,
        return_sequences=True,
        return_state=True,
        name="decoder_lstm",
    )
    # The decoder starts from the encoder's final states, so both LSTMs
    # must have the same number of units (EMBEDDING_DIMS here).
    decoder_outputs, _, _ = decoder_lstm(
        decoder_embedding, initial_state=encoder_states)
    # Per-timestep softmax over the title vocabulary.
    decoder_dense = tf.keras.layers.Dense(
        VOCABULARY_SIZE_TITLE, activation="softmax", name="output_layer")
    output = decoder_dense(decoder_outputs)

    model = tf.keras.models.Model([encoder_inputs, decoder_inputs], output)
    # Sparse loss: presumably decoder_target_data holds integer token ids
    # rather than one-hot vectors -- TODO confirm.
    model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy")
    model.summary()
    model.fit(
        [keyword_sequences, title_sequences],
        decoder_target_data,
        batch_size=BATCH_SIZE,
        epochs=NUM_EPOCHS,
        validation_split=0.0,
        verbose=2,
    )
#######################################
######## INFERENCE MODELS #############
#######################################
# Encoder inference model: keyword sequence -> [state_h, state_c].
encoder_model = tf.keras.models.Model(encoder_inputs, encoder_states)

# Decoder inference model: takes decoder token ids plus explicit LSTM
# states and returns the softmax output together with the new states, so
# decoding can be driven step by step. It shares the trained embedding,
# LSTM and dense layers with the training model.
#
# The training `decoder_inputs` has a fixed length of MAX_SEQ_LEN_TITLE.
# A model built on it cannot take the length-1 token batches used in
# step-wise decoding: TF 2.0/2.1 `predict` raises "Error when checking
# input", and later versions warn and retrace. A fresh Input with a
# variable time dimension avoids this and still accepts full-length
# sequences. The layer name is kept so name-based feeding still works.
decoder_step_inputs = tf.keras.layers.Input(
    shape=(None,), name="decoder_input")
decoder_step_embedding = decoder_embedding_layer(decoder_step_inputs)

# Externally supplied initial states for the decoder LSTM. Presumably the
# encoder's states on the first step and the previously returned states
# afterwards (the decoding loop is not shown here). The original layer
# names are kept for compatibility.
decoder_state_input_h = tf.keras.layers.Input(
    shape=(EMBEDDING_DIMS,), name="encoder_input_hidden")
decoder_state_input_c = tf.keras.layers.Input(
    shape=(EMBEDDING_DIMS,), name="encoder_input_cell")
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_step_embedding, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = tf.keras.models.Model(
    [decoder_step_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment