Skip to content

Instantly share code, notes, and snippets.

@shgidi
Last active September 23, 2017 16:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shgidi/0857388b9dea2608904195d71e360963 to your computer and use it in GitHub Desktop.
Save shgidi/0857388b9dea2608904195d71e360963 to your computer and use it in GitHub Desktop.
from keras.applications.vgg16 import VGG16
from keras.layers import Conv2D
from keras.models import Sequential
from keras.layers import BatchNormalization
from keras.optimizers import Adam
label_count = 17  # number of units in the final softmax layer of the head
p = 0.4  # base dropout rate handed to get_bn_layers, which scales it per layer
vgg = VGG16()  # VGG16 built with its default constructor arguments
def split_at(model, layer_type):
    """Split ``model.layers`` just after the last layer of exactly ``layer_type``.

    Matching is an exact type check (``type(layer) is layer_type``), so
    instances of subclasses of ``layer_type`` do not count as matches.

    Returns:
        A ``(head, tail)`` pair: ``head`` is every layer up to and including
        the last match, ``tail`` is every layer after it.

    Raises:
        IndexError: if no layer has exactly the type ``layer_type``.
    """
    all_layers = model.layers
    matching_positions = [
        pos for pos, lyr in enumerate(all_layers) if type(lyr) is layer_type
    ]
    # Indexing the last match raises IndexError when there are no matches.
    cut = matching_positions[-1] + 1
    head, tail = all_layers[:cut], all_layers[cut:]
    return head, tail
# Cut VGG16 right after its final Conv2D layer: the convolutional feature
# extractor goes in front (wrapped as its own model), the remaining layers
# are kept separately as fc_layers.
conv_layers, fc_layers = split_at(vgg, Conv2D)
conv_model = Sequential(layers=conv_layers)
def get_bn_layers(p):
    """Build the layer list for the batch-normalised classifier head.

    The head is placed on top of the output of the module-level
    ``conv_layers``. It applies max-pooling, batch-norm and dropout to the
    feature maps, then flattens them. Next come two 512-unit ReLU dense
    layers, each followed by batch-norm and dropout. It ends in a
    ``label_count``-way softmax (``label_count`` is module-level).

    Args:
        p: Base dropout rate. The layers use p/4 after the pooled feature
            maps, p after the first dense layer and p/2 after the second.

    Returns:
        A list of Keras layers, ready to pass to ``Sequential``.
    """
    # These layer classes are missing from the file's top-level imports;
    # importing them here keeps the function runnable.
    from keras import backend as K
    from keras.layers import Dense, Dropout, Flatten, MaxPooling2D

    # Batch-normalise the feature maps per channel. axis=1 is the channel axis
    # only under 'channels_first'; under 'channels_last' (the Keras 2 default)
    # the channels are the last axis.
    channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

    return [
        MaxPooling2D(input_shape=conv_layers[-1].output_shape[1:]),
        BatchNormalization(axis=channel_axis),
        Dropout(p / 4),
        Flatten(),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(p),
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(p / 2),
        # NOTE(review): a softmax output assumes exactly one label per sample;
        # confirm that matches the dataset.
        Dense(label_count, activation='softmax'),
    ]
# Train the batch-normalised head on precomputed convolutional features.
# NOTE(review): conv_feat / conv_val_feat, trn_labels / val_labels and
# batch_size are not defined in this file; presumably conv_feat is the output
# of conv_model on the training images -- confirm where they are created.
from keras import backend as K

bn_model = Sequential(get_bn_layers(p))
bn_model.compile(Adam(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

# First pass: 3 epochs at lr=1e-3. The head's input shape is the conv feature
# shape, so it is fed the conv features, the same data as the second pass.
bn_model.fit(conv_feat, trn_labels, batch_size=64, epochs=3,
             validation_data=(conv_val_feat, val_labels))

# The optimizer's learning rate is a backend variable captured by the compiled
# training function. Rebinding the attribute to a float would leave training
# at the old rate, so update the variable's value in place instead.
K.set_value(bn_model.optimizer.lr, 1e-4)
bn_model.fit(conv_feat, trn_labels, batch_size=batch_size, epochs=7,
             validation_data=(conv_val_feat, val_labels))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment