Last active
June 7, 2017 20:55
-
-
Save pbloem/9531ec25ae6df6e2d6407a854a2f5538 to your computer and use it in GitHub Desktop.
Simple classifier
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.layers import Activation, Input, Dense, Conv2D, Conv2DTranspose, MaxPooling2D\n", | |
"from keras.layers import GlobalMaxPooling2D, Flatten, Reshape, BatchNormalization, Dropout\n", | |
"from keras.models import Model, Sequential\n", | |
"from keras import optimizers, metrics\n", | |
"from keras.utils import np_utils\n", | |
"\n", | |
"\n", | |
"from keras.datasets import mnist, cifar100, cifar10\n", | |
"import numpy as np\n", | |
"from keras.callbacks import TensorBoard\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_15 (InputLayer) (None, 32, 32, 3) 0 \n", | |
"_________________________________________________________________\n", | |
"conv2d_118 (Conv2D) (None, 16, 16, 8) 224 \n", | |
"_________________________________________________________________\n", | |
"conv2d_119 (Conv2D) (None, 8, 8, 16) 1168 \n", | |
"_________________________________________________________________\n", | |
"conv2d_120 (Conv2D) (None, 4, 4, 32) 4640 \n", | |
"_________________________________________________________________\n", | |
"conv2d_121 (Conv2D) (None, 2, 2, 64) 18496 \n", | |
"_________________________________________________________________\n", | |
"flatten_4 (Flatten) (None, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_6 (Dense) (None, 128) 32896 \n", | |
"_________________________________________________________________\n", | |
"dense_7 (Dense) (None, 10) 1290 \n", | |
"=================================================================\n", | |
"Total params: 58,714\n", | |
"Trainable params: 58,714\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"# We start with color images of 32x32 resolution and 3 color channels\n", | |
"inputs = Input(shape=(32, 32, 3))\n", | |
"\n", | |
"# Four convolution layers. We don't use maxpooling, just 3x3 convolutions with stride 2,\n", | |
"# which halves the resolution each layer (32 -> 16 -> 8 -> 4 -> 2)\n", | |
"x = Conv2D(8, (3, 3), strides=(2, 2), activation='relu', padding='same')(inputs)\n", | |
"x = Conv2D(16, (3, 3), strides=(2, 2), activation='relu', padding='same')(x)\n", | |
"x = Conv2D(32, (3, 3), strides=(2, 2), activation='relu', padding='same')(x)\n", | |
"x = Conv2D(64, (3, 3), strides=(2, 2), activation='relu', padding='same')(x)\n", | |
"\n", | |
"# Flatten the 2x2x64 feature maps into a 256D vector, then add a fully connected hidden layer\n", | |
"x = Flatten()(x)\n", | |
"x = Dense(128, activation='sigmoid')(x)\n", | |
"\n", | |
"# This is the classification layer. It produces a 10D probability vector.\n", | |
"# Each element in the vector represents a class and its value is the\n", | |
"# probability that the input image belongs to that class\n", | |
"x = Dense(10, activation='softmax')(x) # (note the use of softmax)\n", | |
"\n", | |
"model = Model(inputs, x)\n", | |
"\n", | |
"# Compile the model. Adam is a good general-purpose optimizer. We don't optimize for\n", | |
"# accuracy directly (crossentropy works better as a training loss), but we are ultimately\n", | |
"# interested in accuracy, so we add it as a metric, so we can monitor it.\n", | |
"# The labels are one-hot vectors over 10 mutually exclusive classes and the output is a\n", | |
"# softmax, so the matching loss is *categorical* crossentropy. (Binary crossentropy would\n", | |
"# score each of the 10 outputs as an independent yes/no prediction instead.)\n", | |
"model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=[metrics.categorical_accuracy])\n", | |
"\n", | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(50000, 32, 32, 3)\n", | |
"(10000, 32, 32, 3)\n", | |
"[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] (10,)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Load the dataset\n", | |
"(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n", | |
"\n", | |
"# (mnist.load_data() gives grayscale images of shape (28, 28) instead; to use it, reshape\n", | |
"# them to (28, 28, 1) below and change the model's Input shape to match)\n", | |
"\n", | |
"# Rescale to [0.0, 1.0]\n", | |
"x_train = x_train.astype('float32') / 255.\n", | |
"x_test = x_test.astype('float32') / 255.\n", | |
"\n", | |
"# CIFAR-10 already loads as (samples, 32, 32, 3), so these reshapes don't change anything here.\n", | |
"# Note that reshape only regroups values; it can't convert a channels-first array.\n", | |
"x_train = x_train.reshape((len(x_train), 32, 32, 3))\n", | |
"x_test = x_test.reshape((len(x_test), 32, 32, 3))\n", | |
"print(x_train.shape)\n", | |
"print(x_test.shape)\n", | |
"\n", | |
"# For the labels (classes), the given data is a list of integers.\n", | |
"# We need them as 10D one-hot vectors, since that's what the network\n", | |
"# will output. The number of classes is passed explicitly so the encoding always\n", | |
"# matches the 10-way output layer, instead of being inferred from the largest label.\n", | |
"num_classes = 10\n", | |
"y_test = np_utils.to_categorical(y_test, num_classes)\n", | |
"y_train = np_utils.to_categorical(y_train, num_classes)\n", | |
"\n", | |
"# Print one example label (test image 14) and the shape of a single label vector\n", | |
"print(y_test[14], y_test[14].shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Train on 50000 samples, validate on 10000 samples\n", | |
"Epoch 1/10\n", | |
"50000/50000 [==============================] - 14s - loss: 0.2838 - categorical_accuracy: 0.2868 - val_loss: 0.2489 - val_categorical_accuracy: 0.4004\n", | |
"Epoch 2/10\n", | |
"50000/50000 [==============================] - 14s - loss: 0.2404 - categorical_accuracy: 0.4230 - val_loss: 0.2305 - val_categorical_accuracy: 0.4574\n", | |
"Epoch 3/10\n", | |
"50000/50000 [==============================] - 14s - loss: 0.2268 - categorical_accuracy: 0.4656 - val_loss: 0.2221 - val_categorical_accuracy: 0.4806\n", | |
"Epoch 4/10\n", | |
"50000/50000 [==============================] - 14s - loss: 0.2172 - categorical_accuracy: 0.4945 - val_loss: 0.2122 - val_categorical_accuracy: 0.5115\n", | |
"Epoch 5/10\n", | |
"50000/50000 [==============================] - 16s - loss: 0.2100 - categorical_accuracy: 0.5156 - val_loss: 0.2087 - val_categorical_accuracy: 0.5205\n", | |
"Epoch 6/10\n", | |
"50000/50000 [==============================] - 19s - loss: 0.2045 - categorical_accuracy: 0.5301 - val_loss: 0.2030 - val_categorical_accuracy: 0.5386\n", | |
"Epoch 7/10\n", | |
"50000/50000 [==============================] - 16s - loss: 0.1986 - categorical_accuracy: 0.5475 - val_loss: 0.2003 - val_categorical_accuracy: 0.5449\n", | |
"Epoch 8/10\n", | |
"50000/50000 [==============================] - 15s - loss: 0.1938 - categorical_accuracy: 0.5609 - val_loss: 0.1973 - val_categorical_accuracy: 0.5523\n", | |
"Epoch 9/10\n", | |
"50000/50000 [==============================] - 15s - loss: 0.1886 - categorical_accuracy: 0.5727 - val_loss: 0.1924 - val_categorical_accuracy: 0.5631\n", | |
"Epoch 10/10\n", | |
"50000/50000 [==============================] - 15s - loss: 0.1841 - categorical_accuracy: 0.5867 - val_loss: 0.1946 - val_categorical_accuracy: 0.5550\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.History at 0x129e2e240>" | |
] | |
}, | |
"execution_count": 39, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Train the model\n", | |
"#\n", | |
"# NOTE: the test set doubles as validation data here. That's fine for monitoring progress,\n", | |
"# but if you use the validation score to pick hyperparameters or the number of epochs,\n", | |
"# hold out part of the training data instead (e.g. validation_split=0.1) so that the\n", | |
"# final test accuracy remains an unbiased estimate.\n", | |
"\n", | |
"model.fit(x_train, y_train,\n", | |
"          epochs=10,\n", | |
"          batch_size=256,\n", | |
"          validation_data=(x_test, y_test),\n", | |
"          shuffle=True)\n", | |
"\n", | |
"# This takes about 10-20s per epoch on my laptop. After 10 epochs the accuracy ends up\n", | |
"# roughly in the 55-60% range (exact numbers vary between runs)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment