Last active
April 11, 2018 18:19
-
-
Save Murgio/29c5c14c8ff113e7d70dd4b2370362ff to your computer and use it in GitHub Desktop.
Freeze, Pre-train and Fine-tune (FPT)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os.path | |
from keras.applications.inception_v3 import InceptionV3 | |
from keras.optimizers import SGD | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.models import Model | |
from keras.layers import Dense, GlobalAveragePooling2D | |
from keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping | |
# Checkpoint callback: with save_best_only=True a file is written only when
# the monitored quantity improves (the monitor is left at the Keras default,
# which the original comment describes as val_loss). The epoch number and
# val_loss are embedded in the file name.
checkpointer = ModelCheckpoint(
    save_best_only=True,
    verbose=1,
    filepath='inceptionV3.{epoch:03d}-{val_loss:.2f}.hdf5',
)

# Early-stopping callback: abort training once `patience` consecutive epochs
# pass without an improvement.
early_stopper = EarlyStopping(
    patience=10,
)

# TensorBoard callback: event files are written below the 'logs' directory.
tensorboard = TensorBoard(
    log_dir='logs',
)
def get_generators(classes=None, target_size=(299, 299), batch_size=64):
    """Create the training and validation image generators.

    Training images are read from the ``train`` directory and augmented
    (shear, horizontal flip, rotation, width/height shifts); validation
    images are read from the ``test`` directory and are only rescaled.
    Both pipelines multiply pixel values by 1/255 and use
    ``class_mode='categorical'``.

    Args:
        classes: Optional list of class sub-directory names. When ``None``
            (the default) Keras infers the classes from the sub-directories
            of ``train``, and the same label order is then reused for the
            validation generator. The original code read ``data.classes``
            here, but no ``data`` object is imported or defined in this
            file, so the call failed with ``NameError``.
        target_size: ``(height, width)`` every image is resized to.
        batch_size: Number of images per yielded batch.

    Returns:
        Tuple ``(train_generator, validation_generator)``.
    """
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        horizontal_flip=True,
        rotation_range=10.,
        width_shift_range=0.2,
        height_shift_range=0.2)

    # No augmentation for evaluation data: only the same pixel rescaling.
    test_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        'train',
        target_size=target_size,
        batch_size=batch_size,
        classes=classes,
        class_mode='categorical')

    if classes is None:
        # Reuse the label order discovered in the training directory so both
        # generators agree on the class-name -> index mapping.
        index_of = train_generator.class_indices
        classes = sorted(index_of, key=index_of.get)

    validation_generator = test_datagen.flow_from_directory(
        'test',
        target_size=target_size,
        batch_size=batch_size,
        classes=classes,
        class_mode='categorical')

    return train_generator, validation_generator
def get_model(weights='imagenet', num_classes=230):
    """Build an InceptionV3-based classifier with a new, trainable head.

    Freeze and pre-train step: the InceptionV3 base is created without its
    top layers, then a global-average-pooling layer, a 1024-unit ReLU dense
    layer and a softmax output layer are stacked on top. Every layer of the
    base is frozen so that only the new head is trained.

    Args:
        weights: Passed straight to ``InceptionV3(weights=...)``.
        num_classes: Number of units in the softmax output layer. Defaults
            to 230, the value that was previously hard-coded.

    Returns:
        The compiled Keras ``Model`` (optimizer ``'rmsprop'``, loss
        ``'categorical_crossentropy'``, metrics accuracy and top-k
        categorical accuracy).
    """
    # Create the base pre-trained model without its classification top.
    base_model = InceptionV3(weights=weights, include_top=False)

    # New classification head on top of the convolutional features.
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    # This is the model we will train.
    model = Model(inputs=base_model.input, outputs=predictions)

    # Freeze the whole pre-trained base; only the new head stays trainable.
    for layer in base_model.layers:
        layer.trainable = False

    # Compile the model (should be done *after* setting layers to
    # non-trainable).
    model.compile(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['accuracy', 'top_k_categorical_accuracy'])
    return model
def fine_tune_inception_layer(model):
    """Prepare ``model`` for the fine-tuning phase.

    The first 172 layers of ``model.layers`` are frozen and every layer
    from index 172 onwards is made trainable (the original comment describes
    this as training the top 2 inception blocks). The model is then
    recompiled with SGD at a low learning rate so the new ``trainable``
    flags take effect.

    Args:
        model: The Keras model to modify in place.

    Returns:
        The same ``model`` object, recompiled.
    """
    first_trainable = 172

    # A layer is trainable exactly when it sits at or after the cut-off.
    for position, layer in enumerate(model.layers):
        layer.trainable = position >= first_trainable

    # Recompiling is required for the changed flags to be honoured; a small
    # learning rate is used for this phase.
    low_lr_sgd = SGD(lr=0.0001, momentum=0.9)
    tracked_metrics = ['accuracy', 'top_k_categorical_accuracy']
    model.compile(
        optimizer=low_lr_sgd,
        loss='categorical_crossentropy',
        metrics=tracked_metrics)
    return model
def train_model(model, nb_epoch, generators, callbacks=None):
    """Train ``model`` from a pair of data generators.

    Args:
        model: Object exposing ``fit_generator`` (a compiled Keras model).
        nb_epoch: Number of epochs, forwarded as ``epochs``.
        generators: Two-item sequence ``(train_generator,
            validation_generator)``.
        callbacks: Optional iterable of callbacks. ``None`` (the default)
            means no callbacks. A fresh list is built on every call, so no
            list object is shared between calls (the previous ``callbacks=[]``
            default was a single list reused by every call).

    Returns:
        The same ``model`` object after training.
    """
    train_generator, validation_generator = generators
    # Copy into a new list so neither a shared default nor the caller's own
    # list is handed to the training loop.
    callbacks = [] if callbacks is None else list(callbacks)
    model.fit_generator(
        train_generator,
        validation_data=validation_generator,
        epochs=nb_epoch,
        callbacks=callbacks)
    return model
def main(weights_file):
    """Run the freeze / pre-train / fine-tune pipeline.

    When ``weights_file`` is ``None`` the freshly built model has its new
    top layers trained for 8 epochs; otherwise the weights are loaded from
    ``weights_file`` instead. In both cases the model is then switched to
    fine-tuning mode and trained for up to 64 epochs with the checkpoint,
    early-stopping and TensorBoard callbacks.

    Args:
        weights_file: Path to a saved weights file, or ``None``.
    """
    net = get_model()
    data_flows = get_generators()

    if weights_file is not None:
        # Resume from previously saved weights and skip the top-layer phase.
        print("Loading saved model: %s." % weights_file)
        net.load_weights(weights_file)
    else:
        # No saved weights: train the new top layers first.
        print("Training Top layers.")
        net = train_model(net, 8, data_flows)

    # Get and train the mid layers.
    net = fine_tune_inception_layer(net)
    fine_tune_callbacks = [checkpointer, early_stopper, tensorboard]
    net = train_model(net, 64, data_flows, fine_tune_callbacks)
if __name__ == '__main__':
    # Use the saved weights only when the file is actually present;
    # otherwise signal "start from scratch" with None.
    weights_file = 'inceptionv3_fpf_1.h5'
    weights_file = weights_file if os.path.isfile(weights_file) else None
    main(weights_file)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment