Skip to content

Instantly share code, notes, and snippets.

@drsxr
Created October 30, 2017 17:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save drsxr/79a9b5192abaf49321e7bce9e359f68b to your computer and use it in GitHub Desktop.
Save drsxr/79a9b5192abaf49321e7bce9e359f68b to your computer and use it in GitHub Desktop.
CIFAR 10 Small VGG for testing purposes
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"'''\n",
"Cifar-10 classification - VGG-net adaptation\n",
"Original dataset and further information: https://www.cs.toronto.edu/~kriz/cifar.html\n",
"Code base attributed to Giuseppe Bonaccorso, with modifications\n",
"See: https://www.bonaccorso.eu/2016/08/06/cifar-10-image-classification-with-keras-convnet/ for further information\n",
"'''\n",
" \n",
"from __future__ import print_function\n",
"\n",
"import numpy as np\n",
"import timeit \n",
"\n",
"from keras.callbacks import EarlyStopping\n",
"from keras.datasets import cifar10\n",
"from keras.models import Sequential\n",
"from keras.layers.core import Dense, Dropout, Flatten\n",
"from keras.layers.convolutional import Conv2D\n",
"from keras.optimizers import Adam\n",
"from keras.layers.pooling import MaxPooling2D\n",
"from keras.utils import to_categorical"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Seed NumPy's global RNG so repeated runs start from the same random state.\n",
"# NOTE(review): seeding NumPy alone may not make training fully reproducible\n",
"# (backend / GPU nondeterminism) -- confirm if exact repeatability matters.\n",
"np.random.seed(1000)\n",
"\n",
"# The later cells use X_train, Y_train, X_test, Y_test and model\n",
"# unconditionally, so they are defined here at cell level rather than inside\n",
"# an `if __name__ == '__main__':` guard.\n",
"\n",
"# Load the dataset\n",
"(X_train, Y_train), (X_test, Y_test) = cifar10.load_data()\n",
"\n",
"# Create the model: two convolutional stages followed by a dense classifier.\n",
"model = Sequential()\n",
"\n",
"# Stage 1: 32- and 64-filter 3x3 convolutions on 32x32x3 inputs, then 2x2 max pooling.\n",
"model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(32, 32, 3)))\n",
"model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Dropout(0.5))  # we will adjust this dropout\n",
"\n",
"# Stage 2: two 128-filter 3x3 convolutions, each followed by 2x2 max pooling.\n",
"model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Conv2D(128, kernel_size=(3, 3), activation='relu'))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Dropout(0.5))  # we will adjust this dropout\n",
"\n",
"# Classifier head: flatten, 1024-unit ReLU layer, 10-way softmax output.\n",
"model.add(Flatten())\n",
"model.add(Dense(1024, activation='relu'))\n",
"model.add(Dropout(0.5))  # this dropout will remain constant\n",
"model.add(Dense(10, activation='softmax'))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Compile the model: Adam optimizer, categorical cross-entropy loss,\n",
"# with accuracy reported as the metric.\n",
"model.compile(optimizer=Adam(lr=0.0001, decay=1e-6),\n",
"              loss='categorical_crossentropy',\n",
"              metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Prepare the inputs before starting the timer so the measurement covers\n",
"# model.fit only.\n",
"# Casting to float32 before scaling avoids building float64 copies of the image\n",
"# arrays (dividing an integer array by 255.0 yields float64).\n",
"x_train_scaled = X_train.astype('float32') / 255.0\n",
"x_test_scaled = X_test.astype('float32') / 255.0\n",
"# num_classes is given explicitly so the one-hot width always matches the\n",
"# 10-unit softmax output layer.\n",
"y_train_onehot = to_categorical(Y_train, num_classes=10)\n",
"y_test_onehot = to_categorical(Y_test, num_classes=10)\n",
"\n",
"# NOTE(review): the test split doubles as validation data, so early stopping\n",
"# is tuned on the same data that is later reported as the test score.\n",
"early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3)\n",
"\n",
"# Train the model\n",
"tic = timeit.default_timer()\n",
"model.fit(x_train_scaled, y_train_onehot,\n",
"          batch_size=256,  # we will adjust this batch size\n",
"          shuffle=True,\n",
"          epochs=250,\n",
"          validation_data=(x_test_scaled, y_test_onehot),\n",
"          callbacks=[early_stopping])\n",
"toc = timeit.default_timer()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Evaluate the model on the test split.\n",
"# Cast to float32 before scaling so no float64 copy of the images is created.\n",
"x_eval = X_test.astype('float32') / 255.0\n",
"y_eval = to_categorical(Y_test, num_classes=10)\n",
"scores = model.evaluate(x_eval, y_eval)\n",
"\n",
"# tic/toc were recorded around model.fit in the training cell.\n",
"secs = toc - tic\n",
"print('Loss: %.3f' % scores[0])\n",
"print('Accuracy: %.3f' % scores[1])\n",
"print('Time: %.1f' % secs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"celltoolbar": "Edit Metadata",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment