@alexminnaar
Created March 27, 2017 00:56
Joint image/text classifier in Keras
import numpy as np
from keras import applications
from keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input, concatenate
from keras.models import Model

max_words = 10000
epochs = 50
batch_size = 32

X_train_image = ...  # images training input
X_train_text = ...  # text training input
y_train = ...  # training output
# Number of classes, assuming y_train holds integer class labels starting at 0
num_classes = int(np.max(y_train)) + 1

# Text input branch - just a simple MLP
text_inputs = Input(shape=(max_words,))
branch_1 = Dense(512, activation='relu')(text_inputs)

# Image input branch - a pre-trained InceptionV3 network followed by an added fully connected layer
base_model = applications.InceptionV3(weights='imagenet', include_top=False)

# Freeze Inception's weights - we don't want to train these
for layer in base_model.layers:
    layer.trainable = False

# Add a fully connected layer after Inception - we do want to train this one
branch_2 = base_model.output
branch_2 = GlobalAveragePooling2D()(branch_2)
branch_2 = Dense(1024, activation='relu')(branch_2)

# Merge the text input branch and the image input branch and add another fully connected layer.
# concatenate() replaces the Keras 1 helper merge(..., mode='concat'), which later Keras releases removed.
joint = concatenate([branch_1, branch_2])
joint = Dense(512, activation='relu')(joint)
joint = Dropout(0.5)(joint)
# softmax, so the class scores sum to 1 as the cross-entropy loss expects
predictions = Dense(num_classes, activation='softmax')(joint)
full_model = Model(inputs=[base_model.input, text_inputs], outputs=[predictions])
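# sparse_categorical_crossentropy accepts the integer labels that num_classes assumes;
# with one-hot labels, use 'categorical_crossentropy' instead
full_model.compile(loss='sparse_categorical_crossentropy',
                   optimizer='rmsprop',
                   metrics=['accuracy'])

history = full_model.fit([X_train_image, X_train_text], y_train,
                         epochs=epochs, batch_size=batch_size,
                         verbose=1, validation_split=0.2, shuffle=True)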
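
The gist leaves X_train_text and y_train unspecified. As a rough sketch only: Input(shape=(max_words,)) suggests each text example is a fixed-length vector over the vocabulary, for instance a multi-hot bag-of-words row, and np.max(y_train) + 1 suggests integer class labels starting at 0. The helper multi_hot and the toy data below are hypothetical illustrations of those assumed shapes, not part of the original code.

```python
def multi_hot(token_ids, max_words):
    """Return a length-`max_words` list with 1.0 at every index present in `token_ids`."""
    row = [0.0] * max_words
    for idx in token_ids:
        # Indices outside the vocabulary are skipped rather than raising.
        if 0 <= idx < max_words:
            row[idx] = 1.0
    return row


# Hypothetical toy data: three documents as word-index sequences, vocabulary of 8 words.
docs = [[1, 3, 3], [0, 7], [2, 9]]
labels = [0, 2, 1]  # assumed integer class labels starting at 0

# One row per document, each of length 8 (the toy stand-in for max_words).
text_rows = [multi_hot(doc, max_words=8) for doc in docs]

# Same rule the script uses for num_classes: largest label plus one.
toy_num_classes = max(labels) + 1

print(len(text_rows), len(text_rows[0]))
print(toy_num_classes)
```

Under these assumptions, np.asarray(text_rows) would have shape (number of examples, max_words), and np.asarray(labels) would be a one-dimensional integer array suitable for a sparse cross-entropy loss.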