ResNet vs VanillaCNN on CIFAR10
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import keras\nfrom keras.layers import *\nfrom keras.models import *\nfrom keras.optimizers import *\nfrom keras.callbacks import *\nfrom keras.preprocessing.image import *\n\nimport numpy as np\nimport matplotlib.pyplot as plt\n%matplotlib inline",
"execution_count": 163,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# display model as svg\ndef model_as_svg(m):\n from IPython.display import SVG\n from keras.utils.vis_utils import model_to_dot\n return SVG(model_to_dot(m, show_shapes=True)\\\n .create(prog='dot', format='svg'))",
"execution_count": 165,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "N_CLASSES = 10",
"execution_count": 166,
"outputs": []
},
{
"metadata": {
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "# load data\n(X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()\n\nX_train = X_train.astype(np.float32)\nX_test = X_test.astype(np.float32)\n\n# convert ys to one-hots\ny_train = keras.utils.to_categorical(y_train, num_classes=N_CLASSES)\ny_test = keras.utils.to_categorical(y_test, num_classes=N_CLASSES)\n\n# normalize data\n\nmean_X = np.mean(X_train, axis=0)\nX_train -= mean_X\nX_test -= mean_X\n\nstd_X = np.std(X_test)\nX_train /= std_X\nX_test /= std_X", | |
"execution_count": 167, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def build_plain_cnn_1(input_shape=(32, 32, 3), n_clases=10):\n\n i = Input(shape=input_shape, name='input')\n\n x = BatchNormalization()(i)\n\n x = Conv2D(64, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(64, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = MaxPooling2D((2, 2))(x)\n\n x = Conv2D(128, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(128, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = MaxPooling2D((2, 2))(x)\n\n x = Conv2D(256, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(256, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = MaxPooling2D((2, 2))(x)\n\n x = Conv2D(512, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(512, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = MaxPooling2D((2, 2))(x)\n\n x = Flatten()(x)\n x = Dense(1024, activation='relu')(x)\n x = Dense(1024, activation='relu')(x)\n x = Dense(n_clases, activation='softmax')(x)\n\n return Model(inputs=[i], outputs=[x])", | |
"execution_count": 172, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def build_resnet18(input_size=(32, 32, 3), n_classes=10):\n\n i = Input(shape=input_size, name='input')\n\n x = Conv2D(64, (7, 7), strides=(2, 2), padding='same', name='conv_0')(i)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = MaxPooling2D((2, 2))(x)\n\n x = Conv2D(64, (3, 3), padding='same', name='conv_64_1a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(64, (3, 3), padding='same', name='conv_64_1b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(64, (3, 3), padding='same', name='conv_64_2a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(64, (3, 3), padding='same', name='conv_64_2b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(\n 128, (3, 3), strides=(2, 2), padding='same', name='conv_128_1a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(128, (3, 3), padding='same', name='conv_128_1b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(y)\n y = BatchNormalization()(y)\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(128, (3, 3), padding='same', name='conv_128_2a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(128, (3, 3), padding='same', name='conv_128_2b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(\n 256, (3, 3), strides=(2, 2), padding='same', name='conv_256_1a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(256, (3, 3), padding='same', name='conv_256_1b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Conv2D(256, (3, 3), strides=(2, 2), padding='same')(y)\n y = BatchNormalization()(y)\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(256, (3, 3), padding='same', name='conv_256_2a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(256, (3, 3), padding='same', name='conv_256_2b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(\n 512, (3, 3), strides=(2, 2), padding='same', name='conv_512_1a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(512, (3, 3), padding='same', name='conv_512_1b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Conv2D(512, (3, 3), strides=(2, 2), padding='same')(y)\n y = BatchNormalization()(y)\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(512, (3, 3), padding='same', name='conv_512_2a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(512, (3, 3), padding='same', name='conv_512_2b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n # add average poling as needed to bring filter size to 1x1\n # (on ImageNet 224x224 input filter size herer is 7x7, on 32x32 for CIFAR it's 1x1 already)\n y_filter_size = y.shape.as_list()[1:3]\n if max(y_filter_size) > 1:\n x = AveragePooling2D(tuple(y_filter_size))(y)\n\n x = Flatten()(y)\n x = Dense(n_classes, activation='softmax')(x)\n\n return Model(inputs=[i], outputs=[x])", | |
"execution_count": 173, | |
"outputs": [] | |
}, | |
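{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# (added sketch, not part of the original run) compare the sizes of the two\n# architectures defined above; count_params() is a standard Keras Model method,\n# everything else reuses the builders from the earlier cells\nresnet = build_resnet18()\nvanilla = build_plain_cnn_1()\nprint('ResNet18 params:   ', resnet.count_params())\nprint('Vanilla CNN params:', vanilla.count_params())",
"execution_count": null,
"outputs": []
},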
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "model = build_resnet18()\n# model = build_plain_cnn_1()\n\nmodel.compile(\n loss='categorical_crossentropy',\n optimizer=RMSprop(lr=1e-3),\n metrics=['accuracy'])\n\nAUGMENTATION = True\nBATCH_SIZE = 256\nEPOCHS = 100\nSAMPLES_PER_EPOCH = len(X_train) // BATCH_SIZE * BATCH_SIZE\nprint('SAMPLES_PER_EPOCH:', SAMPLES_PER_EPOCH, 'of', len(X_train))\n\ncallbacks = [\n ReduceLROnPlateau(\n factor=0.333,\n cooldown=0,\n patience=2,\n min_lr=1e-9,\n verbose=True,\n monitor='val_acc'),\n EarlyStopping(\n min_delta=0.001, patience=10, monitor='val_acc', verbose=True)\n]\n\nif AUGMENTATION:\n\n g = ImageDataGenerator(\n featurewise_center=False,\n samplewise_center=False,\n featurewise_std_normalization=False,\n samplewise_std_normalization=False,\n zca_whitening=False,\n rotation_range=0.,\n width_shift_range=0.25,\n height_shift_range=0.25,\n horizontal_flip=True,\n vertical_flip=False)\n\n g.fit(X_train)\n\n model.fit_generator(\n g.flow(X_train, y_train, batch_size=BATCH_SIZE, shuffle=True),\n steps_per_epoch=SAMPLES_PER_EPOCH // BATCH_SIZE,\n validation_data=(X_test, y_test),\n epochs=EPOCHS,\n verbose=1,\n max_queue_size=100,\n callbacks=callbacks)\n\nelse:\n\n model.fit(\n X_train,\n y_train,\n batch_size=BATCH_SIZE,\n epochs=EPOCHS,\n validation_data=(X_test, y_test),\n shuffle=True,\n callbacks=callbacks)", | |
"execution_count": 194, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "SAMPLES_PER_EPOCH: 49920 of 50000\nEpoch 1/100\n195/195 [==============================] - 13s - loss: 1.9017 - acc: 0.3465 - val_loss: 2.6356 - val_acc: 0.2595\nEpoch 2/100\n195/195 [==============================] - 10s - loss: 1.4769 - acc: 0.4786 - val_loss: 1.7127 - val_acc: 0.4994\nEpoch 3/100\n195/195 [==============================] - 10s - loss: 1.2804 - acc: 0.5462 - val_loss: 1.3760 - val_acc: 0.5210\nEpoch 4/100\n195/195 [==============================] - 10s - loss: 1.1823 - acc: 0.5852 - val_loss: 2.3478 - val_acc: 0.4452\nEpoch 5/100\n195/195 [==============================] - 10s - loss: 1.0856 - acc: 0.6204 - val_loss: 1.2669 - val_acc: 0.5871\nEpoch 6/100\n195/195 [==============================] - 10s - loss: 0.9964 - acc: 0.6505 - val_loss: 1.1076 - val_acc: 0.6321\nEpoch 7/100\n195/195 [==============================] - 10s - loss: 0.9235 - acc: 0.6753 - val_loss: 1.4291 - val_acc: 0.5585\nEpoch 8/100\n195/195 [==============================] - 10s - loss: 0.9127 - acc: 0.6860 - val_loss: 1.1149 - val_acc: 0.6114\nEpoch 9/100\n195/195 [==============================] - 10s - loss: 0.8369 - acc: 0.7057 - val_loss: 0.9764 - val_acc: 0.6821\nEpoch 10/100\n195/195 [==============================] - 10s - loss: 0.7910 - acc: 0.7234 - val_loss: 0.8064 - val_acc: 0.7233\nEpoch 11/100\n195/195 [==============================] - 10s - loss: 0.7447 - acc: 0.7383 - val_loss: 0.8988 - val_acc: 0.6991\nEpoch 12/100\n195/195 [==============================] - 10s - loss: 0.7097 - acc: 0.7501 - val_loss: 0.9078 - val_acc: 0.6902\nEpoch 13/100\n195/195 [==============================] - 10s - loss: 0.6814 - acc: 0.7606 - val_loss: 0.7662 - val_acc: 0.7334\nEpoch 14/100\n195/195 [==============================] - 10s - loss: 0.6645 - acc: 0.7666 - val_loss: 0.7401 - val_acc: 0.7481\nEpoch 15/100\n195/195 [==============================] - 10s - loss: 0.6416 - acc: 0.7745 - val_loss: 0.8701 - val_acc: 0.7083\nEpoch 16/100\n195/195 [==============================] - 10s - loss: 0.6122 - acc: 0.7861 - val_loss: 0.8371 - val_acc: 0.7322\nEpoch 17/100\n195/195 [==============================] - 10s - loss: 0.6017 - acc: 0.7892 - val_loss: 0.6550 - val_acc: 0.7817\nEpoch 18/100\n195/195 [==============================] - 10s - loss: 0.5798 - acc: 0.7954 - val_loss: 0.7455 - val_acc: 0.7517\nEpoch 19/100\n195/195 [==============================] - 10s - loss: 0.5672 - acc: 0.8020 - val_loss: 0.7141 - val_acc: 0.7617\nEpoch 20/100\n194/195 [============================>.] 
- ETA: 0s - loss: 0.5487 - acc: 0.8073\nEpoch 00019: reducing learning rate to 0.0003330000158166513.\n195/195 [==============================] - 13s - loss: 0.5489 - acc: 0.8073 - val_loss: 0.6659 - val_acc: 0.7687\nEpoch 21/100\n195/195 [==============================] - 10s - loss: 0.4617 - acc: 0.8380 - val_loss: 0.5565 - val_acc: 0.8157\nEpoch 22/100\n195/195 [==============================] - 10s - loss: 0.4401 - acc: 0.8460 - val_loss: 0.5346 - val_acc: 0.8210\nEpoch 23/100\n195/195 [==============================] - 10s - loss: 0.4284 - acc: 0.8492 - val_loss: 0.5175 - val_acc: 0.8232\nEpoch 24/100\n195/195 [==============================] - 10s - loss: 0.4196 - acc: 0.8533 - val_loss: 0.5118 - val_acc: 0.8288\nEpoch 25/100\n195/195 [==============================] - 10s - loss: 0.4144 - acc: 0.8535 - val_loss: 0.5120 - val_acc: 0.8273\nEpoch 26/100\n195/195 [==============================] - 10s - loss: 0.4006 - acc: 0.8586 - val_loss: 0.5601 - val_acc: 0.8213\nEpoch 27/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3995 - acc: 0.8595\nEpoch 00026: reducing learning rate to 0.00011088900119648315.\n195/195 [==============================] - 10s - loss: 0.3994 - acc: 0.8594 - val_loss: 0.5371 - val_acc: 0.8207\nEpoch 28/100\n195/195 [==============================] - 10s - loss: 0.3717 - acc: 0.8690 - val_loss: 0.4783 - val_acc: 0.8407\nEpoch 29/100\n195/195 [==============================] - 10s - loss: 0.3634 - acc: 0.8730 - val_loss: 0.4830 - val_acc: 0.8402\nEpoch 30/100\n195/195 [==============================] - 10s - loss: 0.3576 - acc: 0.8744 - val_loss: 0.4864 - val_acc: 0.8382\nEpoch 31/100\n195/195 [==============================] - 10s - loss: 0.3527 - acc: 0.8751 - val_loss: 0.4828 - val_acc: 0.8430\nEpoch 32/100\n195/195 [==============================] - 10s - loss: 0.3462 - acc: 0.8762 - val_loss: 0.4816 - val_acc: 0.8436\nEpoch 33/100\n195/195 [==============================] - 10s - loss: 0.3495 - acc: 0.8770 - val_loss: 0.4905 - val_acc: 0.8405\nEpoch 34/100\n195/195 [==============================] - 10s - loss: 0.3410 - acc: 0.8783 - val_loss: 0.4828 - val_acc: 0.8440\nEpoch 35/100\n195/195 [==============================] - 10s - loss: 0.3416 - acc: 0.8793 - val_loss: 0.4821 - val_acc: 0.8460\nEpoch 36/100\n195/195 [==============================] - 10s - loss: 0.3361 - acc: 0.8809 - val_loss: 0.4853 - val_acc: 0.8464\nEpoch 37/100\n195/195 [==============================] - 10s - loss: 0.3270 - acc: 0.8840 - val_loss: 0.4817 - val_acc: 0.8470\nEpoch 38/100\n195/195 [==============================] - 10s - loss: 0.3292 - acc: 0.8808 - val_loss: 0.4879 - val_acc: 0.8443\nEpoch 39/100\n195/195 [==============================] - 10s - loss: 0.3263 - acc: 0.8843 - val_loss: 0.4947 - val_acc: 0.8460\nEpoch 40/100\n194/195 [============================>.] 
- ETA: 0s - loss: 0.3240 - acc: 0.8852\nEpoch 00039: reducing learning rate to 3.692603672971018e-05.\n195/195 [==============================] - 10s - loss: 0.3238 - acc: 0.8853 - val_loss: 0.4858 - val_acc: 0.8439\nEpoch 41/100\n195/195 [==============================] - 10s - loss: 0.3151 - acc: 0.8886 - val_loss: 0.4793 - val_acc: 0.8480\nEpoch 42/100\n195/195 [==============================] - 10s - loss: 0.3135 - acc: 0.8894 - val_loss: 0.4799 - val_acc: 0.8479\nEpoch 43/100\n195/195 [==============================] - 10s - loss: 0.3164 - acc: 0.8887 - val_loss: 0.4753 - val_acc: 0.8476\nEpoch 44/100\n195/195 [==============================] - 10s - loss: 0.3103 - acc: 0.8893 - val_loss: 0.4752 - val_acc: 0.8494\nEpoch 45/100\n195/195 [==============================] - 10s - loss: 0.3075 - acc: 0.8911 - val_loss: 0.4767 - val_acc: 0.8488\nEpoch 46/100\n195/195 [==============================] - 10s - loss: 0.3075 - acc: 0.8905 - val_loss: 0.4770 - val_acc: 0.8490\nEpoch 47/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3062 - acc: 0.8915\nEpoch 00046: reducing learning rate to 1.2296370608964936e-05.\n195/195 [==============================] - 10s - loss: 0.3059 - acc: 0.8915 - val_loss: 0.4735 - val_acc: 0.8478\nEpoch 48/100\n195/195 [==============================] - 10s - loss: 0.3034 - acc: 0.8924 - val_loss: 0.4731 - val_acc: 0.8490\nEpoch 49/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3032 - acc: 0.8929\nEpoch 00048: reducing learning rate to 4.094691272257478e-06.\n195/195 [==============================] - 10s - loss: 0.3032 - acc: 0.8930 - val_loss: 0.4739 - val_acc: 0.8492\nEpoch 50/100\n195/195 [==============================] - 10s - loss: 0.3021 - acc: 0.8937 - val_loss: 0.4723 - val_acc: 0.8511\nEpoch 51/100\n195/195 [==============================] - 10s - loss: 0.3055 - acc: 0.8918 - val_loss: 0.4718 - val_acc: 0.8504\nEpoch 52/100\n195/195 [==============================] - 10s - loss: 0.3020 - acc: 0.8920 - val_loss: 0.4715 - val_acc: 0.8507\nEpoch 53/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3053 - acc: 0.8924\nEpoch 00052: reducing learning rate to 1.3635321433866921e-06.\n195/195 [==============================] - 10s - loss: 0.3052 - acc: 0.8924 - val_loss: 0.4716 - val_acc: 0.8507\nEpoch 54/100\n195/195 [==============================] - 10s - loss: 0.2992 - acc: 0.8936 - val_loss: 0.4714 - val_acc: 0.8515\nEpoch 55/100\n195/195 [==============================] - 10s - loss: 0.2984 - acc: 0.8924 - val_loss: 0.4711 - val_acc: 0.8506\nEpoch 56/100\n195/195 [==============================] - 10s - loss: 0.3030 - acc: 0.8921 - val_loss: 0.4710 - val_acc: 0.8508\nEpoch 57/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3008 - acc: 0.8938\nEpoch 00056: reducing learning rate to 4.540562199508713e-07.\n", | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": "195/195 [==============================] - 10s - loss: 0.3006 - acc: 0.8939 - val_loss: 0.4707 - val_acc: 0.8506\nEpoch 58/100\n195/195 [==============================] - 10s - loss: 0.2978 - acc: 0.8947 - val_loss: 0.4710 - val_acc: 0.8507\nEpoch 59/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3039 - acc: 0.8919\nEpoch 00058: reducing learning rate to 1.512007213193556e-07.\n195/195 [==============================] - 10s - loss: 0.3036 - acc: 0.8919 - val_loss: 0.4706 - val_acc: 0.8507\nEpoch 60/100\n195/195 [==============================] - 10s - loss: 0.3015 - acc: 0.8933 - val_loss: 0.4710 - val_acc: 0.8512\nEpoch 61/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3033 - acc: 0.8933\nEpoch 00060: reducing learning rate to 5.03498407766756e-08.\n195/195 [==============================] - 10s - loss: 0.3033 - acc: 0.8934 - val_loss: 0.4708 - val_acc: 0.8505\nEpoch 00060: early stopping\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
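{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# (added sketch, not part of the original run) report final test-set metrics for\n# the trained model on the normalized X_test/y_test prepared above;\n# model.evaluate returns [loss, accuracy] for the compiled metrics\ntest_loss, test_acc = model.evaluate(X_test, y_test, batch_size=256, verbose=0)\nprint('test loss: %.4f, test accuracy: %.4f' % (test_loss, test_acc))",
"execution_count": null,
"outputs": []
},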
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"file_extension": ".py",
"pygments_lexer": "ipython3",
"name": "python",
"version": "3.5.2",
"nbconvert_exporter": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"mimetype": "text/x-python"
},
"gist": {
"id": "fbc8f9f42fe04c2ac37cb40e363decd1",
"data": {
"description": "ResNet vs VanillaCNN on CIFAR10",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/fbc8f9f42fe04c2ac37cb40e363decd1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}