Skip to content

Instantly share code, notes, and snippets.

@NickShargan
Last active October 20, 2017 22:04
Show Gist options
  • Save NickShargan/e9d334319d794f74fd09892aade30a7b to your computer and use it in GitHub Desktop.
Save NickShargan/e9d334319d794f74fd09892aade30a7b to your computer and use it in GitHub Desktop.
from keras.applications.inception_v3 import InceptionV3
# from keras.applications.resnet50 import ResNet50
from keras.applications.mobilenet import MobileNet
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
from keras import backend as K
import matplotlib.pyplot as plt
# Side length (in pixels) of the square RGB images fed to the network.
img_size = 224

# Create the base pre-trained model: MobileNet with width multiplier
# alpha=0.5 and without its classification head (include_top=False).
# base_model = ResNet50(weights='imagenet', include_top=False)
base_model = MobileNet(alpha=0.5, include_top=False, input_shape=(img_size, img_size, 3))

# New classification head stacked on the base model's output:
# global spatial average pooling -> fully-connected layer -> softmax.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
# Output layer: 2 units, i.e. a two-class softmax classifier.
predictions = Dense(2, activation='softmax')(x)

# This is the model we will train.
model = Model(inputs=base_model.input, outputs=predictions)

# First: train only the top layers (which were randomly initialized),
# i.e. freeze every layer of the MobileNet base.
for layer in base_model.layers:
    layer.trainable = False

# Compile the model (should be done *after* setting layers to non-trainable).
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Input pipeline: stream images from ./data/train and ./data/val in batches.
batch_size = 64

# Options shared by the training and validation directory iterators.
flow_options = dict(
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode='categorical',
)

# Training images are rescaled by 1/255 and randomly flipped horizontally
# (shear_range / zoom_range augmentation is left disabled).
train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True)
train_generator = train_datagen.flow_from_directory('./data/train', **flow_options)

# Validation images only get the same 1/255 rescaling, no augmentation.
test_datagen = ImageDataGenerator(rescale=1. / 255)
validation_generator = test_datagen.flow_from_directory('./data/val', **flow_options)
# Write a checkpoint only when validation accuracy improves on the best value
# seen so far; the epoch number and val_acc are embedded in the file name.
filepath = "./weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]

# Train the model on the new data for a few epochs.
# NOTE(review): the step counts assume about 3200 training and 800 validation
# images -- confirm against the contents of ./data.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=3200 // batch_size,
    epochs=10,
    callbacks=callbacks_list,
    validation_data=validation_generator,
    # Keras 2 name for the number of validation batches per epoch; the
    # Keras 1 spelling `nb_val_samples` is deprecated.
    validation_steps=800 // batch_size,
)
print("history")
# list all data in history
print(history.history.keys())

# Summarize the training history: one figure for accuracy, one for loss.
# Each entry is (history key, axis label / title suffix, output file).
for metric, label, out_path in (
    ('acc', 'accuracy', "fig_acc.png"),
    ('loss', 'loss', "fig_loss.png"),
):
    # Start from a fresh figure so the curves of the previous metric cannot
    # leak into this plot on backends where plt.show() keeps the figure open.
    plt.figure()
    plt.plot(history.history[metric])
    plt.plot(history.history['val_' + metric])
    plt.title('model ' + label)
    plt.ylabel(label)
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    # Save before show(): the figure may be closed once its window is dismissed.
    plt.savefig(out_path)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment