Fine-tuning the pre-trained NASNet with Keras: retraining the top N layers on my own dataset
#from keras.applications.inception_v3 import InceptionV3
from keras.applications.nasnet import NASNetMobile
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.callbacks import EarlyStopping
from keras.callbacks import TensorBoard
import numpy as np
from sklearn.utils import class_weight
import os.path
import fnmatch
import itertools
import functools
import tensorflow as tf
# This custom callback allows me to see the 'training' (orange) and 'validation' (blue) lines plotted together
class TrainValTensorBoard(TensorBoard):
    def __init__(self, log_dir='./logs', **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, 'training')
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)
        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')

    def set_model(self, model):
        # Setup writer for validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with
        # `self.val_writer`. Also rename the keys so that they can
        # be plotted on the same figure with the training metrics
        logs = logs or {}
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()
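# To see both curves on one chart, point TensorBoard at the parent log directory;
# the 'training' and 'validation' runs then show up side by side:
#   tensorboard --logdir=./logs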
# create the base pre-trained model
base_model = NASNetMobile(include_top=False, weights='imagenet', pooling='avg')

# dimensions of our images
img_width, img_height = 224, 224

fine_tuned_checkpoint_path = 'cp.fine_tuned.best.hdf5'
new_extended_inception_weights = 'final_weights.hdf5'

train_data_dir = 'data/train'
validation_data_dir = 'data/validation'

# Dynamically get the count of samples in the training and validation directories
nb_train_samples = len(fnmatch.filter(os.listdir(train_data_dir + '/' + 'y'), '*')) + len(fnmatch.filter(os.listdir(train_data_dir + '/' + 'n'), '*'))
nb_validation_samples = len(fnmatch.filter(os.listdir(validation_data_dir + '/' + 'y'), '*')) + len(fnmatch.filter(os.listdir(validation_data_dir + '/' + 'n'), '*'))

top_epochs = 5
fit_epochs = 50
batch_size = 16
# to compensate for the imbalanced classes
class_weight = {0: 1., 1: 2.}
# the base model already ends in global average pooling (pooling='avg' above),
# so its output is a flat feature vector
x = base_model.output
# let's add a fully-connected layer
x = Dense(1024, activation='relu')(x)
# and a logistic layer -- let's say we have 2 classes
predictions = Dense(2, activation='softmax')(x)
# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)
# first: train only the top layers (which were randomly initialized),
# i.e. freeze all convolutional NASNet layers
for layer in base_model.layers:
    layer.trainable = False
# compile the model (should be done *after* setting layers to non-trainable)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
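# Optional: print the assembled architecture to double-check the new head
# before the first training phase.
# model.summary()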
# prepare data augmentation configuration
train_datagen = ImageDataGenerator(
    rescale=1./255,
    vertical_flip=True,
    fill_mode='nearest'
)
validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical')

validation_generator = validation_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical')
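# Alternative (a sketch, not used in the original run): derive the class weights
# from the actual label counts instead of hard-coding them above. Note that the
# `class_weight` dict defined earlier shadows the imported sklearn module name,
# so the function would need its own import:
#   from sklearn.utils.class_weight import compute_class_weight
#   weights = compute_class_weight('balanced', classes=np.unique(train_generator.classes), y=train_generator.classes)
#   class_weight = dict(enumerate(weights))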
# Checkpoint the best model (by validation accuracy), checked after every epoch.
mc_top = ModelCheckpoint(fine_tuned_checkpoint_path, monitor='val_acc', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', period=1)
# Save the TensorBoard logs.
tb = TrainValTensorBoard()
early_stopping = EarlyStopping(monitor='val_loss', patience=15, verbose=1, mode='auto')

# train the model on the new data for a few epochs
hist_top = model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=top_epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size,
    class_weight=class_weight,
    callbacks=[mc_top])
# at this point, the top layers are well trained and we can start fine-tuning
# convolutional layers from NASNet. We will freeze the bottom N layers
# and train the remaining top layers.

# let's visualize layer names and layer indices to see how many layers we should
# freeze for the 2nd phase of training, which unfreezes more layers:
for i, layer in enumerate(base_model.layers):
    print(i, layer.name)
#import sys
#sys.exit()

# Now we start the 2nd phase: retraining the TOP N layers with our own dataset.
# Checkpoint the best model (by validation accuracy), checked after every epoch.
mc_fit = ModelCheckpoint(fine_tuned_checkpoint_path, monitor='val_acc', verbose=0, save_best_only=True, save_weights_only=False, mode='auto', period=1)

if os.path.exists(fine_tuned_checkpoint_path):
    model.load_weights(fine_tuned_checkpoint_path)
    print("Checkpoint '" + fine_tuned_checkpoint_path + "' loaded.")

# we chose to fine-tune the top part of NASNet, i.e. we will freeze
# the first 713 layers and unfreeze the rest:
for layer in model.layers[:713]:
    layer.trainable = False
for layer in model.layers[713:]:
    layer.trainable = True
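# Optional sanity check (not in the original run): confirm how much of the
# network is actually trainable after the freeze/unfreeze loops above.
print('Trainable layers:', sum(1 for l in model.layers if l.trainable), 'of', len(model.layers))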
# we need to recompile the model for these modifications to take effect;
# we use SGD with a low learning rate
from keras.optimizers import SGD
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])

# we train our model again
hist_fit = model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=fit_epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size,
    class_weight=class_weight,
    callbacks=[mc_fit, tb, early_stopping])

model.save_weights(new_extended_inception_weights)
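# Example inference with the fine-tuned model (a sketch; 'some_image.jpg' is a
# placeholder path, and the 1./255 rescale matches the generators above):
#   img = image.load_img('some_image.jpg', target_size=(img_height, img_width))
#   x = np.expand_dims(image.img_to_array(img) / 255., axis=0)
#   probs = model.predict(x)[0]  # class order follows flow_from_directory's alphabetical folder order ('n', 'y')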
Hello,
I am also looking forward to fine-tuning NASNet Large on some data.
Before trying it, I just want to confirm what accuracy/loss this hyperparameter + freeze-layer configuration got for you.
Is it worth trying?