Skip to content

Instantly share code, notes, and snippets.

@SohanChy
Last active April 7, 2018 15:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SohanChy/fb1a2bb485387addf8de694bc5f11c1b to your computer and use it in GitHub Desktop.
Save SohanChy/fb1a2bb485387addf8de694bc5f11c1b to your computer and use it in GitHub Desktop.
Transfer Learning
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 26 20:40:38 2017
@author: dhaval
Updated to KERAS 2 API on Sat Apr 7 9:53:00 2018
@contrib: sohanchy
"""
import os
import sys
import glob
import argparse
import matplotlib
#matplotlib.use('agg')
import matplotlib.pyplot as plt
from keras import backend as K
from keras import __version__
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.models import Model
from keras.layers import Dense, AveragePooling2D, GlobalAveragePooling2D, Input, Flatten, Dropout
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
# Input image size: fixed size for InceptionV3.
IM_WIDTH = 299
IM_HEIGHT = 299

# Training-related defaults. NOTE(review): none of these three are referenced
# by the functions visible in this file -- confirm before relying on them.
NB_EPOCHS = 3
BAT_SIZE = 32
FC_SIZE = 1024

# Left disabled; only the disabled setup_to_finetune text refers to it.
#NB_IV3_LAYERS_TO_FREEZE = 172
def get_nb_files(directory):
    """Count the files stored in the subdirectories of ``directory``.

    The tree is walked recursively and, for every subdirectory met, the
    regular files matching ``<subdirectory>/*`` are counted. Entries that are
    themselves directories are skipped (their contents are counted when the
    walk reaches them), so nested layouts are not over-counted.

    Files lying directly inside ``directory`` are not counted, and neither are
    dot-files, because the ``*`` glob pattern does not match them.

    Args:
        directory: path of the root folder to scan.

    Returns:
        The number of files found, or 0 when ``directory`` does not exist.
    """
    if not os.path.exists(directory):
        return 0
    cnt = 0
    for root, dirs, _files in os.walk(directory):
        for dr in dirs:
            entries = glob.glob(os.path.join(root, dr, "*"))
            # Only regular files are samples; a nested directory is not.
            cnt += sum(1 for entry in entries if os.path.isfile(entry))
    return cnt
def setup_to_transfer_learn(model, base_model):
    """Lock the layers of ``base_model``, then compile ``model``.

    Every layer of ``base_model`` gets ``trainable = False``. ``model`` is
    then compiled with the 'rmsprop' optimizer, the
    'categorical_crossentropy' loss and the 'accuracy' metric.

    Args:
        model: model to compile.
        base_model: model whose layers are frozen.
    """
    for frozen_layer in base_model.layers:
        frozen_layer.trainable = False

    compile_options = {
        'optimizer': 'rmsprop',
        'loss': 'categorical_crossentropy',
        'metrics': ['accuracy'],
    }
    model.compile(**compile_options)
def add_new_last_layer(base_model, nb_classes):
    """Attach a new classification head to the convnet.

    The head applied to ``base_model.output`` is, in order: 8x8 average
    pooling (layer name 'avg_pool'), dropout with rate 0.4, flatten, and a
    softmax dense layer with ``nb_classes`` units.

    Args:
        base_model: keras model excluding top
        nb_classes: # of classes
    Returns:
        new keras model going from ``base_model.input`` to the new head
    """
    features = base_model.output
    pooled = AveragePooling2D(pool_size=(8, 8), padding='valid', name='avg_pool')(features)
    dropped = Dropout(0.4)(pooled)
    flattened = Flatten()(dropped)
    class_scores = Dense(nb_classes, activation='softmax')(flattened)
    return Model(inputs=base_model.input, outputs=class_scores)
def plot_training(history):
    """Plot the accuracy and loss curves of a training run and save them.

    Two figures are produced: training/validation accuracy, saved as
    ``accuracy.png``, and training/validation loss, saved as ``loss.png``.
    Both paths are relative, so the files land in the current working
    directory. Training values are drawn as red dots, validation values as a
    red line.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the return value of
            ``model.fit_generator``).

    Raises:
        KeyError: if the accuracy or loss series are missing from
            ``history.history``.
    """
    logs = history.history
    # Older Keras 2 releases report metrics=['accuracy'] under 'acc' /
    # 'val_acc'; Keras >= 2.3 reports 'accuracy' / 'val_accuracy'. Accept both.
    acc = logs['acc'] if 'acc' in logs else logs['accuracy']
    val_acc = logs['val_acc'] if 'val_acc' in logs else logs['val_accuracy']
    loss = logs['loss']
    val_loss = logs['val_loss']
    epochs = range(len(acc))

    # Start on a fresh figure so a repeated call does not draw the accuracy
    # curves on top of the loss figure left current by the previous call.
    plt.figure()
    plt.plot(epochs, acc, 'r.')
    plt.plot(epochs, val_acc, 'r')
    plt.title('Training and validation accuracy')
    plt.savefig('accuracy.png')

    plt.figure()
    plt.plot(epochs, loss, 'r.')
    plt.plot(epochs, val_loss, 'r-')
    plt.title('Training and validation loss')
    plt.savefig('loss.png')
"""
DISABLED fine-tuning step, kept for reference only. It lives inside this bare
string so it never runs, and it needs NB_IV3_LAYERS_TO_FREEZE, which is
commented out near the top of the file.

def setup_to_finetune(model):
    Freeze the bottom NB_IV3_LAYERS_TO_FREEZE layers and retrain the remaining top layers.
    note: NB_IV3_LAYERS_TO_FREEZE corresponds to the top 2 inception blocks in the inceptionv3 arch
    Args:
        model: keras model

    for layer in model.layers[:NB_IV3_LAYERS_TO_FREEZE]:
        layer.trainable = False
    for layer in model.layers[NB_IV3_LAYERS_TO_FREEZE:]:
        layer.trainable = True
    model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
"""
def train(output_model_file, nb_epoch=5, doPlot=False, batch_size=32,
          train_dir="train", test_dir="test", steps_per_epoch=8):
    """Train a new classifier head on top of a frozen ImageNet InceptionV3.

    Only the transfer-learning phase is run: the InceptionV3 base is frozen,
    a new softmax head is added, and the result is trained and saved. No
    fine-tuning of the base layers takes place.

    Args:
        output_model_file: path handed to ``model.save``.
        nb_epoch: number of epochs passed to ``fit_generator``.
        doPlot: when true, ``plot_training`` is called on the fit history.
        batch_size: batch size of both image generators.
        train_dir: directory of training images, one subdirectory per class.
        test_dir: directory of images used as validation data, laid out the
            same way.
        steps_per_epoch: batches drawn per epoch. Defaults to 8 (the value
            this function always used). Pass ``None`` to derive it from the
            number of files under ``train_dir`` so that one epoch covers the
            training set once.
    """
    # flow_from_directory expects (height, width).
    target_size = (IM_HEIGHT, IM_WIDTH)

    # data prep
    # NOTE(review): the ImageNet InceptionV3 weights go with
    # ``preprocess_input`` scaling rather than 1/255; the rescale is left
    # untouched so models and inference code built around it keep matching.
    # Confirm whether switching is wanted.
    train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')
    # Validation images are only rescaled: random augmentation would make the
    # validation metrics noisy and not comparable from epoch to epoch.
    validation_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        train_dir,
        batch_size=batch_size,
        target_size=target_size,
        class_mode='categorical'
    )
    validation_generator = validation_datagen.flow_from_directory(
        test_dir,
        batch_size=batch_size,
        target_size=target_size,
        class_mode='categorical'
    )

    # Ask the generator for the class count instead of globbing train_dir:
    # a stray file next to the class folders would otherwise be counted as a
    # class and the head would not match the one-hot labels.
    nb_classes = len(train_generator.class_indices)

    if steps_per_epoch is None:
        nb_train_samples = get_nb_files(train_dir)
        # Ceiling division, with at least one step per epoch.
        steps_per_epoch = max(1, (nb_train_samples + batch_size - 1) // batch_size)

    # image_data_format() is the Keras 2 replacement for image_dim_ordering()
    # ('channels_first' corresponds to the old 'th').
    if K.image_data_format() == 'channels_first':
        input_tensor = Input(shape=(3, IM_HEIGHT, IM_WIDTH))
    else:
        input_tensor = Input(shape=(IM_HEIGHT, IM_WIDTH, 3))

    # setup model
    # include_top=False excludes final FC layer
    base_model = InceptionV3(input_tensor=input_tensor, weights='imagenet', include_top=False)
    model = add_new_last_layer(base_model, nb_classes)

    # transfer learning
    setup_to_transfer_learn(model, base_model)
    history_tl = model.fit_generator(train_generator,
                                     steps_per_epoch=steps_per_epoch,
                                     epochs=nb_epoch,
                                     validation_data=validation_generator)
    model.save(output_model_file)
    if doPlot:
        plot_training(history_tl)
# EXAMPLE USE
# nb_epoch = 5
# output_model_file = "dhakaia_pola.model"
# batch_size = 32
# doPlot = True
# train(output_model_file, nb_epoch, doPlot, batch_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment