Skip to content

Instantly share code, notes, and snippets.

@fchollet
Last active October 24, 2019 14:19
Show Gist options
  • Star 18 You must be signed in to star a gist
  • Fork 8 You must be signed in to fork a gist
  • Save fchollet/0ecc151189b997fd4400bc2fecf2489f to your computer and use it in GitHub Desktop.
Save fchollet/0ecc151189b997fd4400bc2fecf2489f to your computer and use it in GitHub Desktop.
from keras.layers import Activation, Input, Dense, Embedding, merge, LSTM, Lambda
from keras.models import Model
from keras import backend as K
from deep_learning_models import VGG16
# --- Configuration ----------------------------------------------------------
# The `...` values are placeholders that must be filled in before this script
# can run.

# Fixed caption length: the shape of `text_input` and the Embedding input_length.
MAX_SEQUENCE_LENGTH = 100
# Width of each word vector; must equal the column count of GLOVE_MATRIX.
# NOTE(review): published GloVe vectors are at most 300-d -- confirm 3000 is
# intended and not a typo for 300.
EMBEDDING_DIM = 3000
# Pretrained embedding matrix, used as Embedding weights of shape
# (VOC_SIZE, EMBEDDING_DIM).
GLOVE_MATRIX = ...
# Mapping from token to integer index (presumably from a tokenizer -- confirm).
word_index = ...
# One extra row beyond the indexed words (presumably index 0 / padding -- confirm).
VOC_SIZE = 1 + len(word_index)
# assuming dim_ordering=tf
# --- Vision branch ----------------------------------------------------------
# A single pretrained VGG16 is built once and reused for both image tasks, so
# captioning and ImageNet classification share (and jointly train) its weights.
imagenet_model = VGG16(weights='imagenet')
image_input = imagenet_model.input  # input for image associated with captions
# Bottleneck features of VGG16. Read this (and `image_input` above) before the
# model is called again below: once a layer has several inbound nodes, Keras
# can reject `.input` / `.output` as ambiguous.
visual_features = imagenet_model.get_layer('fc2').output
visual_word_predictions = Dense(VOC_SIZE, activation='softmax')(visual_features)
vision_model = Model(image_input, visual_word_predictions)

imagenet_input = Input(shape=(224, 224, 3))  # input for ImageNet images in classification task
# 1000-way predictions for ImageNet images. Calling the existing model on the
# new input reuses its weights and yields an output *tensor*. Calling the
# VGG16 constructor here would instead try to build a second, unrelated Model,
# treating the tensor as a configuration argument.
imagenet_preds = imagenet_model(imagenet_input)
# --- Language-model branch --------------------------------------------------
text_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')  # input for captions
# Embedding initialised from GloVe and frozen (trainable=False), so the
# inverse projection below always uses the same fixed word vectors.
embed_layer = Embedding(VOC_SIZE,
                        EMBEDDING_DIM,
                        weights=[GLOVE_MATRIX],
                        input_length=MAX_SEQUENCE_LENGTH,
                        trainable=False)
x = embed_layer(text_input)


def inverse_embed(h):
    """Project embedding-space vectors back onto the vocabulary.

    Computes ``h . W^T``, where ``W`` is the embedding matrix of shape
    (VOC_SIZE, EMBEDDING_DIM). This gives one score per vocabulary word.
    """
    # `W` is the Keras 1 attribute name of the embedding matrix (Keras 2
    # renamed it to `embeddings`).
    return K.dot(h, K.transpose(embed_layer.W))


x = LSTM(2048)(x)
x = Dense(EMBEDDING_DIM, activation='relu')(x)
# The projection changes the last dimension from EMBEDDING_DIM to VOC_SIZE.
# Say so explicitly, because Lambda may otherwise assume the output shape
# equals the input shape on backends where it cannot infer it.
x = Lambda(inverse_embed, output_shape=(VOC_SIZE,))(x)
lm_word_predictions = Activation('softmax')(x)
# --- Output heads and models ------------------------------------------------
# Concatenate the vision-based and language-model word distributions
# (2 * VOC_SIZE features), then re-score them over the vocabulary.
combined_predictions = merge([visual_word_predictions, lm_word_predictions], mode='concat')
# Sized by the module constant VOC_SIZE; the lowercase `voc_size` previously
# used here is not defined in this script and raised NameError.
final_predictions = Dense(VOC_SIZE, activation='softmax')(combined_predictions)
lm_model = Model(text_input, lm_word_predictions)
caption_model = Model([image_input, text_input], final_predictions)
# NOTE: this rebinds `imagenet_model` from the bare VGG16 to the classifier
# model over `imagenet_input`.
imagenet_model = Model(imagenet_input, imagenet_preds)
# One model holding all three inputs and outputs, so layers common to several
# tasks are trained jointly.
multi_task_model = Model([imagenet_input, image_input, text_input],
                         [imagenet_preds, lm_word_predictions, final_predictions])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment