Skip to content

Instantly share code, notes, and snippets.

@muety
Created July 27, 2019 15:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save muety/c3b9e9401f178807c91ad890a6c67e18 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import math\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from keras.engine import Model\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"from keras.optimizers import Adam, SGD\n",
"from keras.callbacks import ModelCheckpoint, Callback\n",
"from keras.preprocessing.image import load_img, img_to_array, array_to_img, ImageDataGenerator\n",
"from keras_vggface.vggface import VGGFace"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# --- Data ---\n",
"DATA_DIR = '../data/preprocessed/'  # read as DATA_DIR + 'train' / 'validation' / 'test'\n",
"GRAYSCALE = False\n",
"# (height, width, channels): one channel when GRAYSCALE, otherwise three (RGB)\n",
"INPUT_DIM = (128, 128) + ((1,) if GRAYSCALE else (3,))\n",
"\n",
"# --- Training ---\n",
"BATCH_SIZE = 32\n",
"EPOCHS = 100\n",
"RANDOM_SEED = 123\n",
"AUGMENTATION_FACTOR = 3  # NOTE(review): not referenced by any visible cell\n",
"LOAD_WEIGHTS = False  # when True, get_model() restores the head from 'top_model_weights.hdf5'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Augment only the training images; validation/test images are just rescaled by 1/255.\n",
"train_datagen = ImageDataGenerator(\n",
"    rotation_range=10,\n",
"    rescale=1./255,\n",
"    shear_range=0.2,\n",
"    zoom_range=0.2,\n",
"    horizontal_flip=True)\n",
"\n",
"datagen = ImageDataGenerator(rescale=1./255)\n",
"\n",
"# Settings shared by all three directory iterators.\n",
"generator_base_params = {\n",
"    'target_size': INPUT_DIM[:2],\n",
"    'class_mode': 'categorical',\n",
"    'color_mode': 'grayscale' if GRAYSCALE else 'rgb',\n",
"    'batch_size': BATCH_SIZE,\n",
"    'seed': RANDOM_SEED\n",
"}\n",
"\n",
"# Shuffle only the training stream. Validation and test data are read in a fixed\n",
"# order so every evaluation (and every checkpoint compared later) sees the same samples.\n",
"train_generator = train_datagen.flow_from_directory(DATA_DIR + 'train', shuffle=True, **generator_base_params)\n",
"validation_generator = datagen.flow_from_directory(DATA_DIR + 'validation', shuffle=False, **generator_base_params)\n",
"test_generator = datagen.flow_from_directory(DATA_DIR + 'test', shuffle=False, **generator_base_params)\n",
"\n",
"n_train = train_generator.n\n",
"n_validation = validation_generator.n\n",
"n_test = test_generator.n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_model(n_classes=4):\n",
"    \"\"\"Build and compile the VGGFace transfer-learning classifier.\n",
"\n",
"    A VGGFace convolutional base (no top, global max pooling) feeds a small\n",
"    dense head. All base layers except the last three are frozen, and the\n",
"    model is compiled with SGD (lr=1e-4, momentum=0.9) for slow fine-tuning.\n",
"\n",
"    Args:\n",
"        n_classes: number of output classes (size of the softmax layer).\n",
"\n",
"    Returns:\n",
"        A compiled keras Sequential model (VGGFace base followed by the head).\n",
"    \"\"\"\n",
"    vgg_model = VGGFace(include_top=False, input_shape=INPUT_DIM, pooling='max')\n",
"\n",
"    top_model = Sequential(name='top')\n",
"    top_model.add(Dense(128, activation='relu', input_shape=vgg_model.output_shape[1:]))\n",
"    top_model.add(Dropout(0.5))\n",
"    top_model.add(Dense(n_classes, activation='softmax'))\n",
"\n",
"    if LOAD_WEIGHTS:\n",
"        # The saved file must come from a head with this exact layer layout.\n",
"        # NOTE(review): an earlier TODO described a single 64-unit hidden layer; a file\n",
"        # saved from that head will not load into this 128-unit head.\n",
"        # TODO retrain with two dense layers (256, 128) instead of one (128)\n",
"        top_model.load_weights('top_model_weights.hdf5')\n",
"\n",
"    # Freeze everything except the last three layers of the convolutional base.\n",
"    for layer in vgg_model.layers[:-3]:\n",
"        layer.trainable = False\n",
"\n",
"    model = Sequential()\n",
"    model.add(vgg_model)\n",
"    model.add(top_model)\n",
"\n",
"    # Pass the SGD instance itself so its small learning rate actually takes effect.\n",
"    opt = SGD(lr=1e-4, momentum=0.9)\n",
"    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['acc'])\n",
"\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = get_model()\n",
"\n",
"# Save after every epoch (save_best_only=False). The file name carries both the epoch\n",
"# and the validation accuracy (e.g. 'final-22-0.546.hdf5'), so checkpoints can be\n",
"# told apart and picked later without re-running validation.\n",
"callbacks = [\n",
"    ModelCheckpoint('final-{epoch:02d}-{val_acc:.3f}.hdf5', monitor='val_acc', verbose=1, save_best_only=False, mode='max')\n",
"]"
]
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Round step counts up so each epoch covers every image, including the final\n",
"# partial batch, and so they never become 0 for sets smaller than BATCH_SIZE.\n",
"history = model.fit_generator(\n",
"    train_generator,\n",
"    steps_per_epoch=math.ceil(n_train / BATCH_SIZE),\n",
"    epochs=EPOCHS,\n",
"    validation_data=validation_generator,\n",
"    validation_steps=math.ceil(n_validation / BATCH_SIZE),\n",
"    callbacks=callbacks)"
]
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Per-epoch training vs. validation accuracy ('val_acc' is computed on validation_generator).\n",
"plt.plot(history.history['acc'])\n",
"plt.plot(history.history['val_acc'])\n",
"plt.title('Model accuracy')\n",
"plt.ylabel('Accuracy')\n",
"plt.xlabel('Epoch')\n",
"plt.legend(['Train', 'Validation'], loc='upper left')\n",
"plt.show()"
]
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Score every saved checkpoint on the test set, in file-name order\n",
"# (os.listdir order is arbitrary, so sort for reproducible output).\n",
"modelfiles = sorted(f for f in os.listdir('.') if f.endswith('.hdf5') and f.startswith('final'))\n",
"for f in modelfiles:\n",
"    model.load_weights(f)\n",
"    # Round up so the final partial batch of test images is evaluated too.\n",
"    result = model.evaluate_generator(\n",
"        test_generator,\n",
"        steps=math.ceil(n_test / BATCH_SIZE)\n",
"    )\n",
"    # result is [loss, acc] for a model compiled with metrics=['acc'].\n",
"    print(f'{f}: {result[1]}')"
]
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Restore the chosen checkpoint (reported score: 0.546).\n",
"model.load_weights(filepath='final-22-0.546.hdf5')"
]
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment