Skip to content

Instantly share code, notes, and snippets.

@pbloem
Last active June 7, 2017 20:55
Show Gist options
  • Save pbloem/9531ec25ae6df6e2d6407a854a2f5538 to your computer and use it in GitHub Desktop.
Save pbloem/9531ec25ae6df6e2d6407a854a2f5538 to your computer and use it in GitHub Desktop.
Simple classifier
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from keras.layers import Activation, Input, Dense, Conv2D, Conv2DTranspose, MaxPooling2D\n",
"from keras.layers import GlobalMaxPooling2D, Flatten, Reshape, BatchNormalization, Dropout\n",
"from keras.models import Model, Sequential\n",
"from keras import optimizers, metrics\n",
"from keras.utils import np_utils\n",
"\n",
"\n",
"from keras.datasets import mnist, cifar100, cifar10\n",
"import numpy as np\n",
"from keras.callbacks import TensorBoard\n"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input_15 (InputLayer) (None, 32, 32, 3) 0 \n",
"_________________________________________________________________\n",
"conv2d_118 (Conv2D) (None, 16, 16, 8) 224 \n",
"_________________________________________________________________\n",
"conv2d_119 (Conv2D) (None, 8, 8, 16) 1168 \n",
"_________________________________________________________________\n",
"conv2d_120 (Conv2D) (None, 4, 4, 32) 4640 \n",
"_________________________________________________________________\n",
"conv2d_121 (Conv2D) (None, 2, 2, 64) 18496 \n",
"_________________________________________________________________\n",
"flatten_4 (Flatten) (None, 256) 0 \n",
"_________________________________________________________________\n",
"dense_6 (Dense) (None, 128) 32896 \n",
"_________________________________________________________________\n",
"dense_7 (Dense) (None, 10) 1290 \n",
"=================================================================\n",
"Total params: 58,714\n",
"Trainable params: 58,714\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# We start with color images at 32x32 resolution\n",
"input_layer = Input(shape=(32,32,3))\n",
"\n",
"# Four convolution layers. We don't use maxpooling, just 3x3 convolutions with stride 2, \n",
"# which halves the resolution each layer\n",
"#\n",
"x = Conv2D(8, (3, 3), strides=(2,2), activation='relu', padding='same')(input_layer)\n",
"x = Conv2D(16, (3, 3), strides=(2,2), activation='relu', padding='same')(x)\n",
"x = Conv2D(32, (3, 3), strides=(2,2), activation='relu', padding='same')(x)\n",
"x = Conv2D(64, (3, 3), strides=(2,2), activation='relu', padding='same')(x)\n",
"\n",
"x = Flatten()(x)\n",
"x = Dense(128, activation='sigmoid')(x)\n",
"\n",
"# This is the classification layer. It produces a 10D probability vector.\n",
"# Each element in the vector represents a class and its value is the \n",
"# probability that that class belongs to the input image\n",
"x = Dense(10, activation='softmax')(x) # (note the use of softmax)\n",
"\n",
"model = Model(input_layer, x)\n",
"\n",
"# Compile the model. Adam is a good general-purpose optimizer. We don't optimize for \n",
"# accuracy (crossentropy works better), but we are ultimately interested in accuracy,\n",
"# so we add it as a metric, so we can monitor the accuracy.\n",
"#\n",
"# Since the 10 classes are mutually exclusive and the output is a softmax,\n",
"# we need categorical_crossentropy here. binary_crossentropy would treat each\n",
"# of the 10 outputs as an independent yes/no problem and report a misleadingly\n",
"# low loss (it still trains, but the loss values are not comparable).\n",
"model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=[metrics.categorical_accuracy])\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(50000, 32, 32, 3)\n",
"(10000, 32, 32, 3)\n",
"[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] (10,)\n"
]
}
],
"source": [
"# Load the dataset\n",
"(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
"\n",
"# (try mnist.load_data() to get images with shape (28, 28, 1))\n",
"\n",
"# Rescale to [0.0, 1.0]\n",
"x_train = x_train.astype('float32') / 255.\n",
"x_test = x_test.astype('float32') / 255.\n",
"\n",
"x_train = x_train.reshape((len(x_train), 32, 32, 3))\n",
"x_test = x_test.reshape((len(x_test), 32, 32, 3))\n",
"print(x_train.shape)\n",
"print(x_test.shape)\n",
"\n",
"# For the labels (classes), the given data is a list of integers.\n",
"# We need them as 10D one-hot vectors, since that's what the network \n",
"# will output\n",
"y_test = np_utils.to_categorical(y_test)\n",
"y_train = np_utils.to_categorical(y_train)\n",
"\n",
"# Print one example label to check the one-hot encoding\n",
"print(y_test[14], y_test[1].shape)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 50000 samples, validate on 10000 samples\n",
"Epoch 1/10\n",
"50000/50000 [==============================] - 14s - loss: 0.2838 - categorical_accuracy: 0.2868 - val_loss: 0.2489 - val_categorical_accuracy: 0.4004\n",
"Epoch 2/10\n",
"50000/50000 [==============================] - 14s - loss: 0.2404 - categorical_accuracy: 0.4230 - val_loss: 0.2305 - val_categorical_accuracy: 0.4574\n",
"Epoch 3/10\n",
"50000/50000 [==============================] - 14s - loss: 0.2268 - categorical_accuracy: 0.4656 - val_loss: 0.2221 - val_categorical_accuracy: 0.4806\n",
"Epoch 4/10\n",
"50000/50000 [==============================] - 14s - loss: 0.2172 - categorical_accuracy: 0.4945 - val_loss: 0.2122 - val_categorical_accuracy: 0.5115\n",
"Epoch 5/10\n",
"50000/50000 [==============================] - 16s - loss: 0.2100 - categorical_accuracy: 0.5156 - val_loss: 0.2087 - val_categorical_accuracy: 0.5205\n",
"Epoch 6/10\n",
"50000/50000 [==============================] - 19s - loss: 0.2045 - categorical_accuracy: 0.5301 - val_loss: 0.2030 - val_categorical_accuracy: 0.5386\n",
"Epoch 7/10\n",
"50000/50000 [==============================] - 16s - loss: 0.1986 - categorical_accuracy: 0.5475 - val_loss: 0.2003 - val_categorical_accuracy: 0.5449\n",
"Epoch 8/10\n",
"50000/50000 [==============================] - 15s - loss: 0.1938 - categorical_accuracy: 0.5609 - val_loss: 0.1973 - val_categorical_accuracy: 0.5523\n",
"Epoch 9/10\n",
"50000/50000 [==============================] - 15s - loss: 0.1886 - categorical_accuracy: 0.5727 - val_loss: 0.1924 - val_categorical_accuracy: 0.5631\n",
"Epoch 10/10\n",
"50000/50000 [==============================] - 15s - loss: 0.1841 - categorical_accuracy: 0.5867 - val_loss: 0.1946 - val_categorical_accuracy: 0.5550\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x129e2e240>"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train the model\n",
"\n",
"model.fit(x_train, y_train,\n",
" epochs=10, \n",
" batch_size=256,\n",
" validation_data=(x_test, y_test),\n",
" shuffle=True)\n",
"\n",
"# This takes about 10-20s per epoch on my laptop. The accuracy hits around 60% after 10 epochs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment