Freeze, Pre-train and Fine-tune (FPT): train a new classifier head on top of a frozen, ImageNet-pretrained InceptionV3, then unfreeze the top inception blocks and fine-tune them with a small learning rate.
import os.path
from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
# Helper: save a checkpoint whenever val_loss improves
# (save_best_only keeps only the best model so far).
checkpointer = ModelCheckpoint(
    filepath='inceptionV3.{epoch:03d}-{val_loss:.2f}.hdf5',
    verbose=1,
    save_best_only=True)
# Helper: Stop when we stop learning.
# patience: number of epochs with no improvement after which training will be stopped.
early_stopper = EarlyStopping(patience=10)
# Helper: TensorBoard
tensorboard = TensorBoard(log_dir='logs')
def get_generators():
    # Augment the training set; the validation set is only rescaled.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        horizontal_flip=True,
        rotation_range=10.,
        width_shift_range=0.2,
        height_shift_range=0.2)

    test_datagen = ImageDataGenerator(rescale=1./255)

    # Classes are inferred from the subdirectory names of 'train'/'test';
    # there should be 230 of them to match the softmax layer below.
    train_generator = train_datagen.flow_from_directory(
        'train',
        target_size=(299, 299),
        batch_size=64,
        class_mode='categorical')

    validation_generator = test_datagen.flow_from_directory(
        'test',
        target_size=(299, 299),
        batch_size=64,
        class_mode='categorical')

    return train_generator, validation_generator
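
# Usage sketch (assumption, not part of the original gist): pull one batch
# to verify that the generators yield (64, 299, 299, 3) images and 230-way
# one-hot labels before committing to a long training run.
#
#   train_gen, val_gen = get_generators()
#   x_batch, y_batch = next(train_gen)
#   print(x_batch.shape)  # expected: (64, 299, 299, 3)
#   print(y_batch.shape)  # expected: (64, 230)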
def get_model(weights='imagenet'):
    '''
    Freeze and Pre-train: replace InceptionV3's top with a pooling
    layer, a fully-connected layer and a new output layer, then
    freeze all the pretrained layers and train only the new head.
    '''
    # Create the base pre-trained model.
    base_model = InceptionV3(weights=weights, include_top=False)

    # Add a new classifier head: global pooling -> fc -> 230-way softmax.
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(230, activation='softmax')(x)

    # This is the model we will train.
    model = Model(inputs=base_model.input, outputs=predictions)

    # Freeze all InceptionV3 layers so only the new head learns.
    for layer in base_model.layers:
        layer.trainable = False

    # Compile the model (must be done *after* setting layers to non-trainable).
    model.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['accuracy', 'top_k_categorical_accuracy'])

    return model
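
# Sanity-check sketch (assumption, not in the original gist): confirm the
# freeze worked and only the new head is trainable.
#
#   m = get_model()
#   n_trainable = sum(1 for l in m.layers if l.trainable)
#   print('%d trainable of %d layers' % (n_trainable, len(m.layers)))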
def fine_tune_inception_layer(model):
    '''
    Fine-tune: with the head already trained, unfreeze the top
    inception blocks and train them together with the head at a
    smaller learning rate.
    '''
    # Train the top 2 inception blocks: freeze the first 172 layers
    # and unfreeze the rest. (The right cutoff index depends on the
    # Keras version; see the sketch after this function.)
    for layer in model.layers[:172]:
        layer.trainable = False
    for layer in model.layers[172:]:
        layer.trainable = True

    # Recompile for these modifications to take effect; use SGD with a
    # low learning rate so the pretrained weights change only slowly.
    model.compile(
        optimizer=SGD(lr=0.0001, momentum=0.9),
        loss='categorical_crossentropy',
        metrics=['accuracy', 'top_k_categorical_accuracy'])

    return model
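
# Sketch for choosing the cutoff (assumption, not in the original gist):
# print layer indices to locate the top inception blocks. The index 172
# matches older Keras releases; the Keras 2 applications docs use 249 for
# the same "top 2 blocks" split, so verify against your installed version.
#
#   for i, layer in enumerate(InceptionV3(weights=None, include_top=False).layers):
#       print(i, layer.name)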
def train_model(model, nb_epoch, generators, callbacks=[]):
    train_generator, validation_generator = generators
    # flow_from_directory iterators define a length, so recent Keras
    # versions infer steps_per_epoch; on older versions pass it explicitly.
    model.fit_generator(
        train_generator,
        validation_data=validation_generator,
        epochs=nb_epoch,
        callbacks=callbacks)
    return model
def main(weights_file):
    model = get_model()
    generators = get_generators()

    if weights_file is None:
        # Stage 1: train only the new top layers on frozen features.
        print("Training top layers.")
        model = train_model(model, 8, generators)
    else:
        print("Loading saved model: %s." % weights_file)
        model.load_weights(weights_file)

    # Stage 2: unfreeze and fine-tune the top inception blocks.
    model = fine_tune_inception_layer(model)
    model = train_model(
        model, 64, generators,
        [checkpointer, early_stopper, tensorboard])
if __name__ == '__main__':
    weights_file = 'inceptionv3_fpf_1.h5'
    if not os.path.isfile(weights_file):
        weights_file = None
    main(weights_file)
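
# Inference sketch (assumption, not part of the original gist): load the
# best checkpoint written by ModelCheckpoint and classify a single image.
# The checkpoint filename and image path below are hypothetical.
#
#   import numpy as np
#   from keras.preprocessing import image
#
#   model = get_model()
#   model.load_weights('inceptionV3.010-0.85.hdf5')  # hypothetical checkpoint
#   img = image.load_img('example.jpg', target_size=(299, 299))
#   x = image.img_to_array(img)[None] / 255.  # match the generators' rescale
#   preds = model.predict(x)
#   print(np.argmax(preds[0]))  # index of the predicted class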