Skip to content

Instantly share code, notes, and snippets.

@myurasov
Last active December 10, 2017 04:44
Show Gist options
  • Save myurasov/fbc8f9f42fe04c2ac37cb40e363decd1 to your computer and use it in GitHub Desktop.
Save myurasov/fbc8f9f42fe04c2ac37cb40e363decd1 to your computer and use it in GitHub Desktop.
ResNet vs VanillaCNN on CIFAR10
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import keras\nfrom keras.layers import *\nfrom keras.models import *\nfrom keras.optimizers import *\nfrom keras.callbacks import *\nfrom keras.preprocessing.image import *\n\nimport numpy as np\nimport matplotlib.pyplot as plt\n%matplotlib inline",
"execution_count": 163,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# display model as svg\ndef model_as_svg(m):\n from IPython.display import SVG\n from keras.utils.vis_utils import model_to_dot\n return SVG(model_to_dot(m, show_shapes=True)\\\n .create(prog='dot', format='svg'))",
"execution_count": 165,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "N_CLASSES = 10",
"execution_count": 166,
"outputs": []
},
{
"metadata": {
"trusted": true,
"scrolled": true
},
"cell_type": "code",
"source": "# load data\n(X_train, y_train), (X_test, y_test) = keras.datasets.cifar10.load_data()\n\nX_train = X_train.astype(np.float32)\nX_test = X_test.astype(np.float32)\n\n# convert ys to one-hots\ny_train = keras.utils.to_categorical(y_train, num_classes=N_CLASSES)\ny_test = keras.utils.to_categorical(y_test, num_classes=N_CLASSES)\n\n# normalize data\n\nmean_X = np.mean(X_train, axis=0)\nX_train -= mean_X\nX_test -= mean_X\n\nstd_X = np.std(X_test)\nX_train /= std_X\nX_test /= std_X",
"execution_count": 167,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def build_plain_cnn_1(input_shape=(32, 32, 3), n_clases=10):\n\n i = Input(shape=input_shape, name='input')\n\n x = BatchNormalization()(i)\n\n x = Conv2D(64, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(64, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = MaxPooling2D((2, 2))(x)\n\n x = Conv2D(128, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(128, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = MaxPooling2D((2, 2))(x)\n\n x = Conv2D(256, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(256, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = MaxPooling2D((2, 2))(x)\n\n x = Conv2D(512, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(512, (3, 3), padding='same')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = MaxPooling2D((2, 2))(x)\n\n x = Flatten()(x)\n x = Dense(1024, activation='relu')(x)\n x = Dense(1024, activation='relu')(x)\n x = Dense(n_clases, activation='softmax')(x)\n\n return Model(inputs=[i], outputs=[x])",
"execution_count": 172,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def build_resnet18(input_size=(32, 32, 3), n_classes=10):\n\n i = Input(shape=input_size, name='input')\n\n x = Conv2D(64, (7, 7), strides=(2, 2), padding='same', name='conv_0')(i)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = MaxPooling2D((2, 2))(x)\n\n x = Conv2D(64, (3, 3), padding='same', name='conv_64_1a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(64, (3, 3), padding='same', name='conv_64_1b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(64, (3, 3), padding='same', name='conv_64_2a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(64, (3, 3), padding='same', name='conv_64_2b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(\n 128, (3, 3), strides=(2, 2), padding='same', name='conv_128_1a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(128, (3, 3), padding='same', name='conv_128_1b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(y)\n y = BatchNormalization()(y)\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(128, (3, 3), padding='same', name='conv_128_2a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(128, (3, 3), padding='same', name='conv_128_2b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(\n 256, (3, 3), strides=(2, 2), padding='same', name='conv_256_1a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(256, (3, 3), padding='same', name='conv_256_1b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Conv2D(256, (3, 3), strides=(2, 2), padding='same')(y)\n y = BatchNormalization()(y)\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(256, (3, 3), padding='same', name='conv_256_2a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(256, (3, 3), padding='same', name='conv_256_2b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(\n 512, (3, 3), strides=(2, 2), padding='same', name='conv_512_1a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(512, (3, 3), padding='same', name='conv_512_1b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Conv2D(512, (3, 3), strides=(2, 2), padding='same')(y)\n y = BatchNormalization()(y)\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n x = Conv2D(512, (3, 3), padding='same', name='conv_512_2a')(y)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n x = Conv2D(512, (3, 3), padding='same', name='conv_512_2b')(x)\n x = BatchNormalization()(x)\n x = Activation('relu')(x)\n\n y = Add()([x, y])\n y = Activation('relu')(y)\n\n # add average poling as needed to bring filter size to 1x1\n # (on ImageNet 224x224 input filter size herer is 7x7, on 32x32 for CIFAR it's 1x1 already)\n y_filter_size = y.shape.as_list()[1:3]\n if max(y_filter_size) > 1:\n x = AveragePooling2D(tuple(y_filter_size))(y)\n\n x = Flatten()(y)\n x = Dense(n_classes, activation='softmax')(x)\n\n return Model(inputs=[i], outputs=[x])",
"execution_count": 173,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "model = build_resnet18()\n# model = build_plain_cnn_1()\n\nmodel.compile(\n loss='categorical_crossentropy',\n optimizer=RMSprop(lr=1e-3),\n metrics=['accuracy'])\n\nAUGMENTATION = True\nBATCH_SIZE = 256\nEPOCHS = 100\nSAMPLES_PER_EPOCH = len(X_train) // BATCH_SIZE * BATCH_SIZE\nprint('SAMPLES_PER_EPOCH:', SAMPLES_PER_EPOCH, 'of', len(X_train))\n\ncallbacks = [\n ReduceLROnPlateau(\n factor=0.333,\n cooldown=0,\n patience=2,\n min_lr=1e-9,\n verbose=True,\n monitor='val_acc'),\n EarlyStopping(\n min_delta=0.001, patience=10, monitor='val_acc', verbose=True)\n]\n\nif AUGMENTATION:\n\n g = ImageDataGenerator(\n featurewise_center=False,\n samplewise_center=False,\n featurewise_std_normalization=False,\n samplewise_std_normalization=False,\n zca_whitening=False,\n rotation_range=0.,\n width_shift_range=0.25,\n height_shift_range=0.25,\n horizontal_flip=True,\n vertical_flip=False)\n\n g.fit(X_train)\n\n model.fit_generator(\n g.flow(X_train, y_train, batch_size=BATCH_SIZE, shuffle=True),\n steps_per_epoch=SAMPLES_PER_EPOCH // BATCH_SIZE,\n validation_data=(X_test, y_test),\n epochs=EPOCHS,\n verbose=1,\n max_queue_size=100,\n callbacks=callbacks)\n\nelse:\n\n model.fit(\n X_train,\n y_train,\n batch_size=BATCH_SIZE,\n epochs=EPOCHS,\n validation_data=(X_test, y_test),\n shuffle=True,\n callbacks=callbacks)",
"execution_count": 194,
"outputs": [
{
"output_type": "stream",
"text": "SAMPLES_PER_EPOCH: 49920 of 50000\nEpoch 1/100\n195/195 [==============================] - 13s - loss: 1.9017 - acc: 0.3465 - val_loss: 2.6356 - val_acc: 0.2595\nEpoch 2/100\n195/195 [==============================] - 10s - loss: 1.4769 - acc: 0.4786 - val_loss: 1.7127 - val_acc: 0.4994\nEpoch 3/100\n195/195 [==============================] - 10s - loss: 1.2804 - acc: 0.5462 - val_loss: 1.3760 - val_acc: 0.5210\nEpoch 4/100\n195/195 [==============================] - 10s - loss: 1.1823 - acc: 0.5852 - val_loss: 2.3478 - val_acc: 0.4452\nEpoch 5/100\n195/195 [==============================] - 10s - loss: 1.0856 - acc: 0.6204 - val_loss: 1.2669 - val_acc: 0.5871\nEpoch 6/100\n195/195 [==============================] - 10s - loss: 0.9964 - acc: 0.6505 - val_loss: 1.1076 - val_acc: 0.6321\nEpoch 7/100\n195/195 [==============================] - 10s - loss: 0.9235 - acc: 0.6753 - val_loss: 1.4291 - val_acc: 0.5585\nEpoch 8/100\n195/195 [==============================] - 10s - loss: 0.9127 - acc: 0.6860 - val_loss: 1.1149 - val_acc: 0.6114\nEpoch 9/100\n195/195 [==============================] - 10s - loss: 0.8369 - acc: 0.7057 - val_loss: 0.9764 - val_acc: 0.6821\nEpoch 10/100\n195/195 [==============================] - 10s - loss: 0.7910 - acc: 0.7234 - val_loss: 0.8064 - val_acc: 0.7233\nEpoch 11/100\n195/195 [==============================] - 10s - loss: 0.7447 - acc: 0.7383 - val_loss: 0.8988 - val_acc: 0.6991\nEpoch 12/100\n195/195 [==============================] - 10s - loss: 0.7097 - acc: 0.7501 - val_loss: 0.9078 - val_acc: 0.6902\nEpoch 13/100\n195/195 [==============================] - 10s - loss: 0.6814 - acc: 0.7606 - val_loss: 0.7662 - val_acc: 0.7334\nEpoch 14/100\n195/195 [==============================] - 10s - loss: 0.6645 - acc: 0.7666 - val_loss: 0.7401 - val_acc: 0.7481\nEpoch 15/100\n195/195 [==============================] - 10s - loss: 0.6416 - acc: 0.7745 - val_loss: 0.8701 - val_acc: 0.7083\nEpoch 16/100\n195/195 [==============================] - 10s - loss: 0.6122 - acc: 0.7861 - val_loss: 0.8371 - val_acc: 0.7322\nEpoch 17/100\n195/195 [==============================] - 10s - loss: 0.6017 - acc: 0.7892 - val_loss: 0.6550 - val_acc: 0.7817\nEpoch 18/100\n195/195 [==============================] - 10s - loss: 0.5798 - acc: 0.7954 - val_loss: 0.7455 - val_acc: 0.7517\nEpoch 19/100\n195/195 [==============================] - 10s - loss: 0.5672 - acc: 0.8020 - val_loss: 0.7141 - val_acc: 0.7617\nEpoch 20/100\n194/195 [============================>.] - ETA: 0s - loss: 0.5487 - acc: 0.8073\nEpoch 00019: reducing learning rate to 0.0003330000158166513.\n195/195 [==============================] - 13s - loss: 0.5489 - acc: 0.8073 - val_loss: 0.6659 - val_acc: 0.7687\nEpoch 21/100\n195/195 [==============================] - 10s - loss: 0.4617 - acc: 0.8380 - val_loss: 0.5565 - val_acc: 0.8157\nEpoch 22/100\n195/195 [==============================] - 10s - loss: 0.4401 - acc: 0.8460 - val_loss: 0.5346 - val_acc: 0.8210\nEpoch 23/100\n195/195 [==============================] - 10s - loss: 0.4284 - acc: 0.8492 - val_loss: 0.5175 - val_acc: 0.8232\nEpoch 24/100\n195/195 [==============================] - 10s - loss: 0.4196 - acc: 0.8533 - val_loss: 0.5118 - val_acc: 0.8288\nEpoch 25/100\n195/195 [==============================] - 10s - loss: 0.4144 - acc: 0.8535 - val_loss: 0.5120 - val_acc: 0.8273\nEpoch 26/100\n195/195 [==============================] - 10s - loss: 0.4006 - acc: 0.8586 - val_loss: 0.5601 - val_acc: 0.8213\nEpoch 27/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3995 - acc: 0.8595\nEpoch 00026: reducing learning rate to 0.00011088900119648315.\n195/195 [==============================] - 10s - loss: 0.3994 - acc: 0.8594 - val_loss: 0.5371 - val_acc: 0.8207\nEpoch 28/100\n195/195 [==============================] - 10s - loss: 0.3717 - acc: 0.8690 - val_loss: 0.4783 - val_acc: 0.8407\nEpoch 29/100\n195/195 [==============================] - 10s - loss: 0.3634 - acc: 0.8730 - val_loss: 0.4830 - val_acc: 0.8402\nEpoch 30/100\n195/195 [==============================] - 10s - loss: 0.3576 - acc: 0.8744 - val_loss: 0.4864 - val_acc: 0.8382\nEpoch 31/100\n195/195 [==============================] - 10s - loss: 0.3527 - acc: 0.8751 - val_loss: 0.4828 - val_acc: 0.8430\nEpoch 32/100\n195/195 [==============================] - 10s - loss: 0.3462 - acc: 0.8762 - val_loss: 0.4816 - val_acc: 0.8436\nEpoch 33/100\n195/195 [==============================] - 10s - loss: 0.3495 - acc: 0.8770 - val_loss: 0.4905 - val_acc: 0.8405\nEpoch 34/100\n195/195 [==============================] - 10s - loss: 0.3410 - acc: 0.8783 - val_loss: 0.4828 - val_acc: 0.8440\nEpoch 35/100\n195/195 [==============================] - 10s - loss: 0.3416 - acc: 0.8793 - val_loss: 0.4821 - val_acc: 0.8460\nEpoch 36/100\n195/195 [==============================] - 10s - loss: 0.3361 - acc: 0.8809 - val_loss: 0.4853 - val_acc: 0.8464\nEpoch 37/100\n195/195 [==============================] - 10s - loss: 0.3270 - acc: 0.8840 - val_loss: 0.4817 - val_acc: 0.8470\nEpoch 38/100\n195/195 [==============================] - 10s - loss: 0.3292 - acc: 0.8808 - val_loss: 0.4879 - val_acc: 0.8443\nEpoch 39/100\n195/195 [==============================] - 10s - loss: 0.3263 - acc: 0.8843 - val_loss: 0.4947 - val_acc: 0.8460\nEpoch 40/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3240 - acc: 0.8852\nEpoch 00039: reducing learning rate to 3.692603672971018e-05.\n195/195 [==============================] - 10s - loss: 0.3238 - acc: 0.8853 - val_loss: 0.4858 - val_acc: 0.8439\nEpoch 41/100\n195/195 [==============================] - 10s - loss: 0.3151 - acc: 0.8886 - val_loss: 0.4793 - val_acc: 0.8480\nEpoch 42/100\n195/195 [==============================] - 10s - loss: 0.3135 - acc: 0.8894 - val_loss: 0.4799 - val_acc: 0.8479\nEpoch 43/100\n195/195 [==============================] - 10s - loss: 0.3164 - acc: 0.8887 - val_loss: 0.4753 - val_acc: 0.8476\nEpoch 44/100\n195/195 [==============================] - 10s - loss: 0.3103 - acc: 0.8893 - val_loss: 0.4752 - val_acc: 0.8494\nEpoch 45/100\n195/195 [==============================] - 10s - loss: 0.3075 - acc: 0.8911 - val_loss: 0.4767 - val_acc: 0.8488\nEpoch 46/100\n195/195 [==============================] - 10s - loss: 0.3075 - acc: 0.8905 - val_loss: 0.4770 - val_acc: 0.8490\nEpoch 47/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3062 - acc: 0.8915\nEpoch 00046: reducing learning rate to 1.2296370608964936e-05.\n195/195 [==============================] - 10s - loss: 0.3059 - acc: 0.8915 - val_loss: 0.4735 - val_acc: 0.8478\nEpoch 48/100\n195/195 [==============================] - 10s - loss: 0.3034 - acc: 0.8924 - val_loss: 0.4731 - val_acc: 0.8490\nEpoch 49/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3032 - acc: 0.8929\nEpoch 00048: reducing learning rate to 4.094691272257478e-06.\n195/195 [==============================] - 10s - loss: 0.3032 - acc: 0.8930 - val_loss: 0.4739 - val_acc: 0.8492\nEpoch 50/100\n195/195 [==============================] - 10s - loss: 0.3021 - acc: 0.8937 - val_loss: 0.4723 - val_acc: 0.8511\nEpoch 51/100\n195/195 [==============================] - 10s - loss: 0.3055 - acc: 0.8918 - val_loss: 0.4718 - val_acc: 0.8504\nEpoch 52/100\n195/195 [==============================] - 10s - loss: 0.3020 - acc: 0.8920 - val_loss: 0.4715 - val_acc: 0.8507\nEpoch 53/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3053 - acc: 0.8924\nEpoch 00052: reducing learning rate to 1.3635321433866921e-06.\n195/195 [==============================] - 10s - loss: 0.3052 - acc: 0.8924 - val_loss: 0.4716 - val_acc: 0.8507\nEpoch 54/100\n195/195 [==============================] - 10s - loss: 0.2992 - acc: 0.8936 - val_loss: 0.4714 - val_acc: 0.8515\nEpoch 55/100\n195/195 [==============================] - 10s - loss: 0.2984 - acc: 0.8924 - val_loss: 0.4711 - val_acc: 0.8506\nEpoch 56/100\n195/195 [==============================] - 10s - loss: 0.3030 - acc: 0.8921 - val_loss: 0.4710 - val_acc: 0.8508\nEpoch 57/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3008 - acc: 0.8938\nEpoch 00056: reducing learning rate to 4.540562199508713e-07.\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "195/195 [==============================] - 10s - loss: 0.3006 - acc: 0.8939 - val_loss: 0.4707 - val_acc: 0.8506\nEpoch 58/100\n195/195 [==============================] - 10s - loss: 0.2978 - acc: 0.8947 - val_loss: 0.4710 - val_acc: 0.8507\nEpoch 59/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3039 - acc: 0.8919\nEpoch 00058: reducing learning rate to 1.512007213193556e-07.\n195/195 [==============================] - 10s - loss: 0.3036 - acc: 0.8919 - val_loss: 0.4706 - val_acc: 0.8507\nEpoch 60/100\n195/195 [==============================] - 10s - loss: 0.3015 - acc: 0.8933 - val_loss: 0.4710 - val_acc: 0.8512\nEpoch 61/100\n194/195 [============================>.] - ETA: 0s - loss: 0.3033 - acc: 0.8933\nEpoch 00060: reducing learning rate to 5.03498407766756e-08.\n195/195 [==============================] - 10s - loss: 0.3033 - acc: 0.8934 - val_loss: 0.4708 - val_acc: 0.8505\nEpoch 00060: early stopping\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"file_extension": ".py",
"pygments_lexer": "ipython3",
"name": "python",
"version": "3.5.2",
"nbconvert_exporter": "python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"mimetype": "text/x-python"
},
"gist": {
"id": "fbc8f9f42fe04c2ac37cb40e363decd1",
"data": {
"description": "ResNet vs VanillaCNN on CIFAR10",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/fbc8f9f42fe04c2ac37cb40e363decd1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment