Skip to content

Instantly share code, notes, and snippets.

@ThejanW
Last active April 18, 2018 10:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ThejanW/cab478ba848cbf6f7c2dfdf140a82530 to your computer and use it in GitHub Desktop.
Save ThejanW/cab478ba848cbf6f7c2dfdf140a82530 to your computer and use it in GitHub Desktop.
This script makes use of pre-trained Inception V3 network's weights to make your own image classifier
#!/usr/bin/env python
"""
Train your own image classifier on top of a pre-trained Inception V3 network.

The ImageNet weights of Inception V3 are reused and a new classification head
is trained on top of them. Go through the script and adapt the network and the
hyper-parameters to your own needs.

Dependencies: numpy, tensorflow, keras

Usage: python3 classifier.py <n_epochs>
    <n_epochs> is the number of epochs to train, counted from the latest
    checkpoint (or from scratch if there is none).

All checkpoints are saved to the "checkpoints" folder in the current directory.
"""
import glob
import os
import re
import sys
from time import time

import numpy as np
from keras.applications.inception_v3 import InceptionV3
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.layers import Dense, Dropout, GlobalAveragePooling2D
from keras.models import Model, load_model
def _create_network(train_inception_v3=False, n_classes=150, hidden_units=250,
                    dropout_rate=0.25, hidden_activation=None):
    """Build an InceptionV3-based classifier with a new classification head.

    The ImageNet-weighted InceptionV3 base (``include_top=False``) is followed
    by global average pooling, a dense hidden layer, dropout and a softmax
    output layer. Every head layer's name starts with ``final_``.

    Args:
        train_inception_v3: if False, only layers whose name starts with
            ``final_`` (the new head) stay trainable; the base is frozen.
        n_classes: number of softmax output units.
        hidden_units: width of the hidden dense layer.
        dropout_rate: dropout rate applied after the hidden layer.
        hidden_activation: activation of the hidden dense layer. The default
            ``None`` keeps the original linear layer.

    Returns:
        An uncompiled keras ``Model``.
    """
    base = InceptionV3(weights='imagenet', include_top=False)

    y = GlobalAveragePooling2D(name='final_avg_pool_0')(base.output)
    # NOTE(review): with hidden_activation=None this Dense layer is linear, so
    # Dense -> Dense reduces to a single linear map (apart from the bottleneck
    # and dropout). Consider passing hidden_activation='relu'.
    y = Dense(hidden_units, activation=hidden_activation, name='final_dense_0')(y)
    y = Dropout(dropout_rate, name='final_drop_0')(y)
    y = Dense(n_classes, activation='softmax', name='final_dense_1')(y)
    model = Model(inputs=base.input, outputs=y)

    if not train_inception_v3:
        # Freeze the pre-trained base; train only the 'final_' head layers.
        for layer in model.layers:
            layer.trainable = layer.name.startswith('final_')
    return model
def _get_fine_tune_network(weight_file_path):
    """Load the model saved at *weight_file_path* with every layer unfrozen.

    NOTE(review): per the Keras docs, changes to ``trainable`` only take
    effect once the model is compiled again. This function does not call
    ``compile``, so the caller should recompile before training.
    """
    network = load_model(weight_file_path)
    all_layers = network.layers
    for lyr in all_layers:
        lyr.trainable = True
    return network
def _find_latest_checkpoint(directory="checkpoints"):
    """Return ``(path, epoch)`` of the most advanced checkpoint in *directory*.

    The epoch is parsed from file names of the form
    ``ckpt-epoch-<epoch>-train_acc-<acc>.hdf5`` (the ModelCheckpoint file
    pattern used by this script). Files whose names do not match are ignored.
    Ties on the epoch are broken by modification time. Returns ``(None, 0)``
    when no checkpoint is found.
    """
    candidates = []
    for path in glob.glob(os.path.join(directory, "*.hdf5")):
        match = re.search(r"-epoch-(\d+)-train_acc-", os.path.basename(path))
        if match is not None:
            candidates.append((int(match.group(1)), os.path.getmtime(path), path))
    if not candidates:
        return None, 0
    epoch, _, path = max(candidates)
    return path, epoch


n_classes = 150
# Choose by the epoch encoded in the file name rather than os.path.getctime:
# on Unix ctime is the inode-change time, so copying or chmod-ing a checkpoint
# could make an older one look "latest".
latest_ckpt, initial_epoch = _find_latest_checkpoint("checkpoints")
if latest_ckpt is not None:
    print("loading checkpoint:", latest_ckpt)
    # Checkpoints are written as full models (save_weights_only=False), so the
    # loaded model is expected to be compiled already.
    model = load_model(latest_ckpt)
else:
    # No checkpoint yet: build a fresh network with a frozen InceptionV3 base.
    model = _create_network(train_inception_v3=False, n_classes=n_classes)
    model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])
os.makedirs("checkpoints", exist_ok=True)
# uncomment to see the model architecture
# model.summary(line_length=100)

# NOTE(review): assumes this Keras version fills {epoch} with the number of
# completed epochs (so it can be reused as initial_epoch on resume) and logs
# training accuracy under 'acc'. Newer Keras logs 'accuracy'; if so, update
# both the {acc} placeholder and `monitor`.
file_path = "checkpoints/ckpt-epoch-{epoch:05d}-train_acc-{acc:.4f}.hdf5"
checkpoint = ModelCheckpoint(
    file_path,
    monitor='acc',  # a single metric name, not a list
    verbose=1,
    save_best_only=False,
    save_weights_only=False,  # full model, so training can resume from it
    mode='auto',
    period=2,  # write a checkpoint every 2 epochs
)
# One TensorBoard log directory per run, keyed by the start timestamp.
tensorboard = TensorBoard(log_dir="logs/{}".format(time()), histogram_freq=0)
# Placeholder data only, to show the expected format: images are float arrays
# of shape (n_samples, 299, 299, 3) and labels are one-hot rows of shape
# (n_samples, n_classes), as categorical_crossentropy expects.
train_size, h, w = 100, 299, 299
X = np.random.random((train_size, h, w, 3))
Y = np.eye(n_classes)[np.random.randint(n_classes, size=train_size)]

usage = "Usage: python3 classifier.py <n_epochs>  (a positive integer)"
try:
    n_epochs_to_train = int(sys.argv[1])
except (IndexError, ValueError):
    sys.exit(usage)
if n_epochs_to_train < 1:
    sys.exit(usage)

# Keras treats `epochs` as the index of the final epoch when `initial_epoch`
# is given, so the requested count is always added on top of the resume
# point. Previously a value above initial_epoch was silently treated as an
# absolute target instead of a count.
model.fit(X, Y,
          epochs=initial_epoch + n_epochs_to_train,
          initial_epoch=initial_epoch,
          batch_size=10,
          callbacks=[checkpoint, tensorboard])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment