Skip to content

Instantly share code, notes, and snippets.

@shshemi
Last active August 14, 2019 11:32
Show Gist options
  • Save shshemi/ccfa8bd65852b08928bbea5a77145ef4 to your computer and use it in GitHub Desktop.
Save shshemi/ccfa8bd65852b08928bbea5a77145ef4 to your computer and use it in GitHub Desktop.
def get_model(image_shape, sentence_len, dict_len, embedding_dim=100):
    """Build the image/sentence encoder, the sentence decoder, and the joint model.

    The encoder embeds the input sentence, reshapes the flattened embedding
    into a single-channel map with the image's height and width, concatenates
    it with a 20-filter 1x1 convolution of the image, and reduces the result
    to a 3-channel map (layer name ``image_reconstruction``).  The decoder
    (model name ``sentence_reconstruction``) collapses that 3-channel map to
    one channel, reshapes it to ``(sentence_len, embedding_dim)`` and applies
    a per-step softmax over ``dict_len`` classes.

    Args:
        image_shape: Shape passed to the image ``Input``; its first two
            entries are used as height and width.  NOTE(review): presumably
            ``(height, width, channels)`` -- confirm against the caller.
        sentence_len: Number of token positions in the sentence input.
        dict_len: Vocabulary size, used for the embedding's input dimension
            and for the width of the decoder's softmax output.
        embedding_dim: Width of each token embedding.  Defaults to 100, the
            value that was previously hard-coded.

    Returns:
        A tuple ``(model, encoder_model, decoder_model)``: the compiled joint
        model with outputs ``[out_img, out_sen]``, the encoder alone, and the
        decoder alone.

    Raises:
        ValueError: If ``sentence_len * embedding_dim`` differs from
            ``image_shape[0] * image_shape[1]``; both reshapes below need
            these element counts to match.
    """
    height, width = image_shape[0], image_shape[1]
    # Fail early with a readable message instead of an opaque reshape error
    # from inside the framework.
    if sentence_len * embedding_dim != height * width:
        raise ValueError(
            "sentence_len * embedding_dim (%d * %d) must equal "
            "image height * width (%d * %d)"
            % (sentence_len, embedding_dim, height, width)
        )

    # the encoder part
    input_img = Input(image_shape)
    input_sen = Input((sentence_len,))
    embed_sen = Embedding(dict_len, embedding_dim)(input_sen)
    embed_sen = Flatten()(embed_sen)
    # Lay the flattened embedding out as one extra image channel.
    embed_sen = Reshape((height, width, 1))(embed_sen)
    convd_img = Conv2D(20, 1, activation="relu")(input_img)
    cat_tenrs = Concatenate(axis=-1)([embed_sen, convd_img])
    out_img = Conv2D(3, 1, activation='relu', name='image_reconstruction')(cat_tenrs)

    # the decoder part
    decoder_model = Sequential(name="sentence_reconstruction")
    # The decoder consumes out_img, so its input shape follows image_shape
    # (3 channels come from the 'image_reconstruction' convolution) rather
    # than a fixed 100x100.
    decoder_model.add(Conv2D(1, 1, input_shape=(height, width, 3)))
    decoder_model.add(Reshape((sentence_len, embedding_dim)))
    decoder_model.add(TimeDistributed(Dense(dict_len, activation="softmax")))
    out_sen = decoder_model(out_img)

    # creating models
    model = Model(inputs=[input_img, input_sen], outputs=[out_img, out_sen])
    model.compile(
        'adam',
        loss=[mean_absolute_error, categorical_crossentropy],
        metrics={'sentence_reconstruction': categorical_accuracy},
    )
    encoder_model = Model(inputs=[input_img, input_sen], outputs=[out_img])
    return model, encoder_model, decoder_model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment