Skip to content

Instantly share code, notes, and snippets.

@stnk20 stnk20/tight_MNIST.ipynb
Last active Feb 22, 2018

Embed
What would you like to do?
An experiment in reusing a convolutional layer.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import keras\n",
"import os\n",
"\n",
"batch_size = 128\n",
"epochs = 30\n",
"lr = 0.002\n",
"\n",
"num_classes = 10\n",
"img_rows, img_cols = 28, 28\n",
"\n",
"save_dir = os.path.join(os.getcwd(), 'saved_models')\n",
"model_name = 'keras_mnist_trained_model.h5'\n",
"model_path = os.path.join(save_dir, model_name)\n",
"\n",
"if not os.path.isdir(save_dir):\n",
" os.makedirs(save_dir)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train shape: (60000, 28, 28, 1)\n",
"60000 train samples\n",
"10000 test samples\n"
]
}
],
"source": [
"# the data, shuffled and split between train and test sets\n",
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
"\n",
"x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n",
"x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n",
"input_shape = (img_rows, img_cols, 1)\n",
"\n",
"x_train = x_train.astype('float32')/255\n",
"x_test = x_test.astype('float32')/255\n",
"print('x_train shape:', x_train.shape)\n",
"print(x_train.shape[0], 'train samples')\n",
"print(x_test.shape[0], 'test samples')\n",
"\n",
"# convert class vectors to binary class matrices\n",
"y_train = keras.utils.to_categorical(y_train, num_classes)\n",
"y_test = keras.utils.to_categorical(y_test, num_classes)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_1 (InputLayer) (None, 28, 28, 1) 0 \n",
"__________________________________________________________________________________________________\n",
"conv2d_1 (Conv2D) (None, 12, 12, 20) 520 input_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_1 (Activation) (None, 12, 12, 20) 0 conv2d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"separable_conv2d_1 (SeparableCo (None, 12, 12, 20) 600 activation_1[0][0] \n",
" activation_2[0][0] \n",
" activation_3[0][0] \n",
" activation_4[0][0] \n",
" activation_5[0][0] \n",
" activation_6[0][0] \n",
" activation_7[0][0] \n",
" activation_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_1 (Add) (None, 12, 12, 20) 0 conv2d_1[0][0] \n",
" separable_conv2d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_2 (Activation) (None, 12, 12, 20) 0 add_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_2 (Add) (None, 12, 12, 20) 0 add_1[0][0] \n",
" separable_conv2d_1[1][0] \n",
"__________________________________________________________________________________________________\n",
"activation_3 (Activation) (None, 12, 12, 20) 0 add_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_3 (Add) (None, 12, 12, 20) 0 add_2[0][0] \n",
" separable_conv2d_1[2][0] \n",
"__________________________________________________________________________________________________\n",
"activation_4 (Activation) (None, 12, 12, 20) 0 add_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_4 (Add) (None, 12, 12, 20) 0 add_3[0][0] \n",
" separable_conv2d_1[3][0] \n",
"__________________________________________________________________________________________________\n",
"activation_5 (Activation) (None, 12, 12, 20) 0 add_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_5 (Add) (None, 12, 12, 20) 0 add_4[0][0] \n",
" separable_conv2d_1[4][0] \n",
"__________________________________________________________________________________________________\n",
"activation_6 (Activation) (None, 12, 12, 20) 0 add_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_6 (Add) (None, 12, 12, 20) 0 add_5[0][0] \n",
" separable_conv2d_1[5][0] \n",
"__________________________________________________________________________________________________\n",
"activation_7 (Activation) (None, 12, 12, 20) 0 add_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_7 (Add) (None, 12, 12, 20) 0 add_6[0][0] \n",
" separable_conv2d_1[6][0] \n",
"__________________________________________________________________________________________________\n",
"activation_8 (Activation) (None, 12, 12, 20) 0 add_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"add_8 (Add) (None, 12, 12, 20) 0 add_7[0][0] \n",
" separable_conv2d_1[7][0] \n",
"__________________________________________________________________________________________________\n",
"average_pooling2d_1 (AveragePoo (None, 4, 4, 20) 0 add_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"cropping2d_1 (Cropping2D) (None, 2, 2, 20) 0 average_pooling2d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"flatten_1 (Flatten) (None, 80) 0 cropping2d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_1 (Dropout) (None, 80) 0 flatten_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_1 (Dense) (None, 10) 810 dropout_1[0][0] \n",
"==================================================================================================\n",
"Total params: 1,930\n",
"Trainable params: 1,930\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"from keras.models import Model, Input\n",
"from keras.layers import Dense, Dropout, Flatten, Activation, Lambda, LeakyReLU\n",
"from keras.layers import Conv2D, SeparableConv2D, MaxPooling2D, AveragePooling2D, Cropping2D\n",
"from keras.layers.merge import Add\n",
"from keras import backend as K\n",
"\n",
"def TightResNet(dim,loop,num_classes,input_shape,dropout=0.1):\n",
" x = Input(shape=input_shape)\n",
" h = Conv2D(dim,(5,5),strides=(2,2),padding=\"valid\")(x)\n",
"\n",
" common_conv = SeparableConv2D(dim,(3,3),padding=\"same\")\n",
" for i in range(loop):\n",
" b = h\n",
" b = Activation(\"relu\")(b)\n",
" b = common_conv(b)\n",
" h = Add()([h,b])\n",
"\n",
" h = AveragePooling2D((3,3))(h)\n",
" h = Cropping2D(1)(h)\n",
" h = Flatten()(h)\n",
" h = Dropout(dropout)(h)\n",
" y = Dense(num_classes, activation='softmax')(h)\n",
"\n",
" return Model(inputs=x,outputs=y)\n",
"\n",
"\n",
"# construct model\n",
"model = TightResNet(20,8,10,input_shape,dropout=0.1) \n",
"\n",
"model.compile(loss=keras.losses.categorical_crossentropy,\n",
" optimizer=keras.optimizers.Nadam(lr),\n",
" metrics=['accuracy'])\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 60000 samples, validate on 10000 samples\n",
"Epoch 1/30\n",
"60000/60000 [==============================] - 7s 125us/step - loss: 0.4598 - acc: 0.8519 - val_loss: 0.1477 - val_acc: 0.9532\n",
"Epoch 2/30\n",
"60000/60000 [==============================] - 6s 108us/step - loss: 0.1474 - acc: 0.9538 - val_loss: 0.0856 - val_acc: 0.9735\n",
"Epoch 3/30\n",
"60000/60000 [==============================] - 6s 107us/step - loss: 0.1094 - acc: 0.9663 - val_loss: 0.0784 - val_acc: 0.9750\n",
"Epoch 4/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0915 - acc: 0.9715 - val_loss: 0.0654 - val_acc: 0.9798\n",
"Epoch 5/30\n",
"60000/60000 [==============================] - 6s 107us/step - loss: 0.0819 - acc: 0.9745 - val_loss: 0.0534 - val_acc: 0.9826\n",
"Epoch 6/30\n",
"60000/60000 [==============================] - 6s 108us/step - loss: 0.0755 - acc: 0.9762 - val_loss: 0.0511 - val_acc: 0.9833\n",
"Epoch 7/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0703 - acc: 0.9777 - val_loss: 0.0566 - val_acc: 0.9821\n",
"Epoch 8/30\n",
"60000/60000 [==============================] - 6s 108us/step - loss: 0.0667 - acc: 0.9795 - val_loss: 0.0424 - val_acc: 0.9860\n",
"Epoch 9/30\n",
"60000/60000 [==============================] - 7s 110us/step - loss: 0.0634 - acc: 0.9793 - val_loss: 0.0484 - val_acc: 0.9850\n",
"Epoch 10/30\n",
"60000/60000 [==============================] - 6s 108us/step - loss: 0.0608 - acc: 0.9806 - val_loss: 0.0429 - val_acc: 0.9856\n",
"Epoch 11/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0576 - acc: 0.9817 - val_loss: 0.0425 - val_acc: 0.9858\n",
"Epoch 12/30\n",
"60000/60000 [==============================] - 6s 108us/step - loss: 0.0554 - acc: 0.9824 - val_loss: 0.0379 - val_acc: 0.9870\n",
"Epoch 13/30\n",
"60000/60000 [==============================] - 7s 108us/step - loss: 0.0541 - acc: 0.9831 - val_loss: 0.0349 - val_acc: 0.9888\n",
"Epoch 14/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0542 - acc: 0.9830 - val_loss: 0.0397 - val_acc: 0.9867\n",
"Epoch 15/30\n",
"60000/60000 [==============================] - 6s 105us/step - loss: 0.0513 - acc: 0.9842 - val_loss: 0.0362 - val_acc: 0.9883\n",
"Epoch 16/30\n",
"60000/60000 [==============================] - 6s 108us/step - loss: 0.0513 - acc: 0.9837 - val_loss: 0.0432 - val_acc: 0.9857\n",
"Epoch 17/30\n",
"59520/60000 [============================>.] - ETA: 0s - loss: 0.0486 - acc: 0.9841\n",
"Epoch 00017: reducing learning rate to 0.0004000000189989805.\n",
"60000/60000 [==============================] - 6s 107us/step - loss: 0.0486 - acc: 0.9841 - val_loss: 0.0365 - val_acc: 0.9875\n",
"Epoch 18/30\n",
"60000/60000 [==============================] - 6s 105us/step - loss: 0.0394 - acc: 0.9870 - val_loss: 0.0316 - val_acc: 0.9907\n",
"Epoch 19/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0381 - acc: 0.9881 - val_loss: 0.0359 - val_acc: 0.9874\n",
"Epoch 20/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0377 - acc: 0.9880 - val_loss: 0.0328 - val_acc: 0.9893\n",
"Epoch 21/30\n",
"60000/60000 [==============================] - 6s 105us/step - loss: 0.0374 - acc: 0.9882 - val_loss: 0.0329 - val_acc: 0.9895\n",
"Epoch 22/30\n",
"60000/60000 [==============================] - 6s 105us/step - loss: 0.0374 - acc: 0.9883 - val_loss: 0.0309 - val_acc: 0.9905\n",
"Epoch 23/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0372 - acc: 0.9887 - val_loss: 0.0306 - val_acc: 0.9907\n",
"Epoch 24/30\n",
"60000/60000 [==============================] - 6s 103us/step - loss: 0.0370 - acc: 0.9880 - val_loss: 0.0310 - val_acc: 0.9902\n",
"Epoch 25/30\n",
"60000/60000 [==============================] - 6s 105us/step - loss: 0.0367 - acc: 0.9882 - val_loss: 0.0327 - val_acc: 0.9894\n",
"Epoch 26/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0355 - acc: 0.9885 - val_loss: 0.0326 - val_acc: 0.9899\n",
"Epoch 27/30\n",
"59392/60000 [============================>.] - ETA: 0s - loss: 0.0357 - acc: 0.9883\n",
"Epoch 00027: reducing learning rate to 8.000000379979611e-05.\n",
"60000/60000 [==============================] - 6s 105us/step - loss: 0.0357 - acc: 0.9883 - val_loss: 0.0307 - val_acc: 0.9907\n",
"Epoch 28/30\n",
"60000/60000 [==============================] - 6s 105us/step - loss: 0.0330 - acc: 0.9897 - val_loss: 0.0298 - val_acc: 0.9909\n",
"Epoch 29/30\n",
"60000/60000 [==============================] - 6s 106us/step - loss: 0.0324 - acc: 0.9896 - val_loss: 0.0295 - val_acc: 0.9908\n",
"Epoch 30/30\n",
"60000/60000 [==============================] - 6s 105us/step - loss: 0.0323 - acc: 0.9894 - val_loss: 0.0299 - val_acc: 0.9906\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f3e661b3d68>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# train\n",
"checkpoint = keras.callbacks.ModelCheckpoint(model_path,save_best_only=True)\n",
"reducelr = keras.callbacks.ReduceLROnPlateau(factor=0.2,patience=3,cooldown=2,verbose=1)\n",
"\n",
"model.fit(x_train, y_train,\n",
" batch_size=batch_size,\n",
" epochs=epochs,\n",
" verbose=1,\n",
" validation_data=(x_test, y_test),\n",
" callbacks=[checkpoint,reducelr])\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test loss: 0.029511989441869082\n",
"Test accuracy: 0.9908\n"
]
}
],
"source": [
"# evaluate\n",
"model.load_weights(model_path)\n",
"\n",
"score = model.evaluate(x_test, y_test, verbose=0)\n",
"print('Test loss:', score[0])\n",
"print('Test accuracy:', score[1])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2+"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.