Skip to content

Instantly share code, notes, and snippets.

@heaven00
Created October 26, 2018 17:10
Show Gist options
  • Save heaven00/0a8b1fff3d6aff6f96347d2786f90278 to your computer and use it in GitHub Desktop.
Save heaven00/0a8b1fff3d6aff6f96347d2786f90278 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "K70hAckqg0EA",
"outputId": "2fc6e1dd-b97c-46c9-9b74-1738093a5255"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mYou are using pip version 18.0, however version 18.1 is available.\r\n",
"You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\r\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jayant/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"Using TensorFlow backend.\n",
"/home/jayant/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"# https://keras.io/\n",
"!pip install -q keras\n",
"import keras"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "wVIx_KIigxPV"
},
"outputs": [],
"source": [
"import keras\n",
"from keras.datasets import cifar10\n",
"from keras.models import Model, Sequential\n",
"from keras.layers import Dense, Dropout, Flatten, Input, AveragePooling2D, merge, Activation\n",
"from keras.layers import Conv2D, MaxPooling2D, BatchNormalization\n",
"from keras.layers import Concatenate\n",
"from keras.optimizers import Adam, SGD\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"\n",
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"np.set_printoptions(suppress=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "UNHw6luQg3gc"
},
"outputs": [],
"source": [
"# Prevent TensorFlow from allocating all of the available GPU memory up front\n",
"# by configuring the session used by the Keras (TensorFlow) backend.\n",
"import tensorflow as tf\n",
"from keras import backend as k\n",
"\n",
"# Don't pre-allocate memory; allocate as-needed\n",
"config = tf.ConfigProto()\n",
"config.gpu_options.allow_growth = True\n",
"\n",
"# Create a session with the above options specified.\n",
"k.tensorflow_backend.set_session(tf.Session(config=config))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "dsO_yGxcg5D8"
},
"outputs": [],
"source": [
"# Hyperparameters\n",
"batch_size = 200\n",
"num_classes = 10\n",
"epochs = 1\n",
"l = 40\n",
"num_filter = 12\n",
"compression = 0.5\n",
"dropout_rate = 0.2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"colab_type": "code",
"id": "mB7o3zu1g6eT",
"outputId": "c1cea922-a38d-45da-f9d8-7977ab9c2dd2"
},
"outputs": [],
"source": [
"# Load CIFAR10 Data\n",
"(x_train, train_y), (x_test, test_y) = cifar10.load_data()\n",
"img_height, img_width, channel = x_train.shape[1],x_train.shape[2],x_train.shape[3]\n",
"\n",
"# convert labels to one-hot encoding\n",
"y_train = keras.utils.to_categorical(train_y, num_classes)\n",
"y_test = keras.utils.to_categorical(test_y, num_classes)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ee-sge5Kg7vr"
},
"outputs": [],
"source": [
"# Dense Block\n",
"def add_denseblock(input, num_filter = 12, dropout_rate = 0.2):\n",
" global compression\n",
" temp = input\n",
" for _ in range(l):\n",
" BatchNorm = BatchNormalization()(temp)\n",
" relu = Activation('relu')(BatchNorm)\n",
" Conv2D_3_3 = Conv2D(int(num_filter*compression), (3,3), use_bias=False ,padding='same')(relu)\n",
" if dropout_rate>0:\n",
" Conv2D_3_3 = Dropout(dropout_rate)(Conv2D_3_3)\n",
" concat = Concatenate(axis=-1)([temp,Conv2D_3_3])\n",
" \n",
" temp = concat\n",
" \n",
" return temp"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "OOP6IPsGhBwb"
},
"outputs": [],
"source": [
"def add_transition(input, num_filter = 12, dropout_rate = 0.2):\n",
" global compression\n",
" BatchNorm = BatchNormalization()(input)\n",
" relu = Activation('relu')(BatchNorm)\n",
" Conv2D_BottleNeck = Conv2D(int(num_filter*compression), (1,1), use_bias=False ,padding='same')(relu)\n",
" if dropout_rate>0:\n",
" Conv2D_BottleNeck = Dropout(dropout_rate)(Conv2D_BottleNeck)\n",
" avg = AveragePooling2D(pool_size=(2,2))(Conv2D_BottleNeck)\n",
" \n",
" return avg"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "0RaKFpubhDIC"
},
"outputs": [],
"source": [
"def output_layer(input):\n",
" global compression\n",
" BatchNorm = BatchNormalization()(input)\n",
" relu = Activation('relu')(BatchNorm)\n",
" AvgPooling = AveragePooling2D(pool_size=(2,2))(relu)\n",
" flat = Flatten()(AvgPooling)\n",
" output = Dense(num_classes, activation=tf.nn.softmax)(flat)\n",
" \n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def plot_confusion_matrix(y_pred, y_test):\n",
" sns.heatmap(confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1)))\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def rand_by_mask(mask, n=4):\n",
" return np.random.choice(np.where(mask)[0], n, replace=False)\n",
"\n",
"def plot_imgs(mask, x_test, y_test, pred, figsize=(12,6)):\n",
" f = plt.figure(figsize=figsize)\n",
" for index, idx in enumerate(mask):\n",
" sp = f.add_subplot(1, len(mask), index + 1)\n",
" sp.axis('Off')\n",
" title = 'actual:{0}, pred:{1}'.format(y_test[idx], pred[idx])\n",
" sp.set_title(title)\n",
" plt.imshow(x_test[idx])\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameters (these override the values set in the earlier hyperparameter cell)\n",
"batch_size = 250\n",
"num_classes = 10\n",
"epochs = 25\n",
"l = 12\n",
"num_filter = 12\n",
"compression = 1\n",
"dropout_rate = 0"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "anPCpQWhhGb7"
},
"outputs": [],
"source": [
"input = Input(shape=(img_height, img_width, channel,))\n",
"First_Conv2D = Conv2D(num_filter, (3,3), use_bias=False ,padding='same')(input)\n",
"\n",
"First_Block = add_denseblock(First_Conv2D, num_filter, dropout_rate)\n",
"First_Transition = add_transition(First_Block, num_filter, dropout_rate)\n",
"\n",
"Second_Block = add_denseblock(First_Transition, num_filter, dropout_rate)\n",
"Second_Transition = add_transition(Second_Block, num_filter, dropout_rate)\n",
"\n",
"Third_Block = add_denseblock(Second_Transition, num_filter, dropout_rate)\n",
"Third_Transition = add_transition(Third_Block, num_filter, dropout_rate)\n",
"\n",
"Last_Block = add_denseblock(Third_Transition, num_filter, dropout_rate)\n",
"output = output_layer(Last_Block)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 9860
},
"colab_type": "code",
"id": "1kFh7pdxhNtT",
"outputId": "160abc05-9e09-4454-d453-0e33a7d95796"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_1 (InputLayer) (None, 32, 32, 3) 0 \n",
"__________________________________________________________________________________________________\n",
"conv2d_1 (Conv2D) (None, 32, 32, 12) 324 input_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_1 (BatchNor (None, 32, 32, 12) 48 conv2d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_1 (Activation) (None, 32, 32, 12) 0 batch_normalization_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 32, 32, 12) 1296 activation_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_1 (Concatenate) (None, 32, 32, 24) 0 conv2d_1[0][0] \n",
" conv2d_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_2 (BatchNor (None, 32, 32, 24) 96 concatenate_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_2 (Activation) (None, 32, 32, 24) 0 batch_normalization_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_3 (Conv2D) (None, 32, 32, 12) 2592 activation_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_2 (Concatenate) (None, 32, 32, 36) 0 concatenate_1[0][0] \n",
" conv2d_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_3 (BatchNor (None, 32, 32, 36) 144 concatenate_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_3 (Activation) (None, 32, 32, 36) 0 batch_normalization_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_4 (Conv2D) (None, 32, 32, 12) 3888 activation_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_3 (Concatenate) (None, 32, 32, 48) 0 concatenate_2[0][0] \n",
" conv2d_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_4 (BatchNor (None, 32, 32, 48) 192 concatenate_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_4 (Activation) (None, 32, 32, 48) 0 batch_normalization_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_5 (Conv2D) (None, 32, 32, 12) 5184 activation_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_4 (Concatenate) (None, 32, 32, 60) 0 concatenate_3[0][0] \n",
" conv2d_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_5 (BatchNor (None, 32, 32, 60) 240 concatenate_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_5 (Activation) (None, 32, 32, 60) 0 batch_normalization_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_6 (Conv2D) (None, 32, 32, 12) 6480 activation_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_5 (Concatenate) (None, 32, 32, 72) 0 concatenate_4[0][0] \n",
" conv2d_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_6 (BatchNor (None, 32, 32, 72) 288 concatenate_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_6 (Activation) (None, 32, 32, 72) 0 batch_normalization_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_7 (Conv2D) (None, 32, 32, 12) 7776 activation_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_6 (Concatenate) (None, 32, 32, 84) 0 concatenate_5[0][0] \n",
" conv2d_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_7 (BatchNor (None, 32, 32, 84) 336 concatenate_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_7 (Activation) (None, 32, 32, 84) 0 batch_normalization_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_8 (Conv2D) (None, 32, 32, 12) 9072 activation_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_7 (Concatenate) (None, 32, 32, 96) 0 concatenate_6[0][0] \n",
" conv2d_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_8 (BatchNor (None, 32, 32, 96) 384 concatenate_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_8 (Activation) (None, 32, 32, 96) 0 batch_normalization_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_9 (Conv2D) (None, 32, 32, 12) 10368 activation_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_8 (Concatenate) (None, 32, 32, 108) 0 concatenate_7[0][0] \n",
" conv2d_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_9 (BatchNor (None, 32, 32, 108) 432 concatenate_8[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_9 (Activation) (None, 32, 32, 108) 0 batch_normalization_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_10 (Conv2D) (None, 32, 32, 12) 11664 activation_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_9 (Concatenate) (None, 32, 32, 120) 0 concatenate_8[0][0] \n",
" conv2d_10[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_10 (BatchNo (None, 32, 32, 120) 480 concatenate_9[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_10 (Activation) (None, 32, 32, 120) 0 batch_normalization_10[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_11 (Conv2D) (None, 32, 32, 12) 12960 activation_10[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_10 (Concatenate) (None, 32, 32, 132) 0 concatenate_9[0][0] \n",
" conv2d_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_11 (BatchNo (None, 32, 32, 132) 528 concatenate_10[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_11 (Activation) (None, 32, 32, 132) 0 batch_normalization_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_12 (Conv2D) (None, 32, 32, 12) 14256 activation_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_11 (Concatenate) (None, 32, 32, 144) 0 concatenate_10[0][0] \n",
" conv2d_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_12 (BatchNo (None, 32, 32, 144) 576 concatenate_11[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_12 (Activation) (None, 32, 32, 144) 0 batch_normalization_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_13 (Conv2D) (None, 32, 32, 12) 15552 activation_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_12 (Concatenate) (None, 32, 32, 156) 0 concatenate_11[0][0] \n",
" conv2d_13[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_13 (BatchNo (None, 32, 32, 156) 624 concatenate_12[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_13 (Activation) (None, 32, 32, 156) 0 batch_normalization_13[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_14 (Conv2D) (None, 32, 32, 12) 1872 activation_13[0][0] \n",
"__________________________________________________________________________________________________\n",
"average_pooling2d_1 (AveragePoo (None, 16, 16, 12) 0 conv2d_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_14 (BatchNo (None, 16, 16, 12) 48 average_pooling2d_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_14 (Activation) (None, 16, 16, 12) 0 batch_normalization_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_15 (Conv2D) (None, 16, 16, 12) 1296 activation_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_13 (Concatenate) (None, 16, 16, 24) 0 average_pooling2d_1[0][0] \n",
" conv2d_15[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_15 (BatchNo (None, 16, 16, 24) 96 concatenate_13[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_15 (Activation) (None, 16, 16, 24) 0 batch_normalization_15[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_16 (Conv2D) (None, 16, 16, 12) 2592 activation_15[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_14 (Concatenate) (None, 16, 16, 36) 0 concatenate_13[0][0] \n",
" conv2d_16[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_16 (BatchNo (None, 16, 16, 36) 144 concatenate_14[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_16 (Activation) (None, 16, 16, 36) 0 batch_normalization_16[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_17 (Conv2D) (None, 16, 16, 12) 3888 activation_16[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_15 (Concatenate) (None, 16, 16, 48) 0 concatenate_14[0][0] \n",
" conv2d_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_17 (BatchNo (None, 16, 16, 48) 192 concatenate_15[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_17 (Activation) (None, 16, 16, 48) 0 batch_normalization_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_18 (Conv2D) (None, 16, 16, 12) 5184 activation_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_16 (Concatenate) (None, 16, 16, 60) 0 concatenate_15[0][0] \n",
" conv2d_18[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_18 (BatchNo (None, 16, 16, 60) 240 concatenate_16[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_18 (Activation) (None, 16, 16, 60) 0 batch_normalization_18[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_19 (Conv2D) (None, 16, 16, 12) 6480 activation_18[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_17 (Concatenate) (None, 16, 16, 72) 0 concatenate_16[0][0] \n",
" conv2d_19[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_19 (BatchNo (None, 16, 16, 72) 288 concatenate_17[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_19 (Activation) (None, 16, 16, 72) 0 batch_normalization_19[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_20 (Conv2D) (None, 16, 16, 12) 7776 activation_19[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_18 (Concatenate) (None, 16, 16, 84) 0 concatenate_17[0][0] \n",
" conv2d_20[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_20 (BatchNo (None, 16, 16, 84) 336 concatenate_18[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_20 (Activation) (None, 16, 16, 84) 0 batch_normalization_20[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_21 (Conv2D) (None, 16, 16, 12) 9072 activation_20[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_19 (Concatenate) (None, 16, 16, 96) 0 concatenate_18[0][0] \n",
" conv2d_21[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_21 (BatchNo (None, 16, 16, 96) 384 concatenate_19[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_21 (Activation) (None, 16, 16, 96) 0 batch_normalization_21[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_22 (Conv2D) (None, 16, 16, 12) 10368 activation_21[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_20 (Concatenate) (None, 16, 16, 108) 0 concatenate_19[0][0] \n",
" conv2d_22[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_22 (BatchNo (None, 16, 16, 108) 432 concatenate_20[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_22 (Activation) (None, 16, 16, 108) 0 batch_normalization_22[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_23 (Conv2D) (None, 16, 16, 12) 11664 activation_22[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_21 (Concatenate) (None, 16, 16, 120) 0 concatenate_20[0][0] \n",
" conv2d_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_23 (BatchNo (None, 16, 16, 120) 480 concatenate_21[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_23 (Activation) (None, 16, 16, 120) 0 batch_normalization_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_24 (Conv2D) (None, 16, 16, 12) 12960 activation_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_22 (Concatenate) (None, 16, 16, 132) 0 concatenate_21[0][0] \n",
" conv2d_24[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_24 (BatchNo (None, 16, 16, 132) 528 concatenate_22[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_24 (Activation) (None, 16, 16, 132) 0 batch_normalization_24[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_25 (Conv2D) (None, 16, 16, 12) 14256 activation_24[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_23 (Concatenate) (None, 16, 16, 144) 0 concatenate_22[0][0] \n",
" conv2d_25[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_25 (BatchNo (None, 16, 16, 144) 576 concatenate_23[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_25 (Activation) (None, 16, 16, 144) 0 batch_normalization_25[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_26 (Conv2D) (None, 16, 16, 12) 15552 activation_25[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_24 (Concatenate) (None, 16, 16, 156) 0 concatenate_23[0][0] \n",
" conv2d_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_26 (BatchNo (None, 16, 16, 156) 624 concatenate_24[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_26 (Activation) (None, 16, 16, 156) 0 batch_normalization_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_27 (Conv2D) (None, 16, 16, 12) 1872 activation_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"average_pooling2d_2 (AveragePoo (None, 8, 8, 12) 0 conv2d_27[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_27 (BatchNo (None, 8, 8, 12) 48 average_pooling2d_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_27 (Activation) (None, 8, 8, 12) 0 batch_normalization_27[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_28 (Conv2D) (None, 8, 8, 12) 1296 activation_27[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_25 (Concatenate) (None, 8, 8, 24) 0 average_pooling2d_2[0][0] \n",
" conv2d_28[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_28 (BatchNo (None, 8, 8, 24) 96 concatenate_25[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_28 (Activation) (None, 8, 8, 24) 0 batch_normalization_28[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_29 (Conv2D) (None, 8, 8, 12) 2592 activation_28[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_26 (Concatenate) (None, 8, 8, 36) 0 concatenate_25[0][0] \n",
" conv2d_29[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_29 (BatchNo (None, 8, 8, 36) 144 concatenate_26[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_29 (Activation) (None, 8, 8, 36) 0 batch_normalization_29[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_30 (Conv2D) (None, 8, 8, 12) 3888 activation_29[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_27 (Concatenate) (None, 8, 8, 48) 0 concatenate_26[0][0] \n",
" conv2d_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_30 (BatchNo (None, 8, 8, 48) 192 concatenate_27[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_30 (Activation) (None, 8, 8, 48) 0 batch_normalization_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_31 (Conv2D) (None, 8, 8, 12) 5184 activation_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_28 (Concatenate) (None, 8, 8, 60) 0 concatenate_27[0][0] \n",
" conv2d_31[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_31 (BatchNo (None, 8, 8, 60) 240 concatenate_28[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_31 (Activation) (None, 8, 8, 60) 0 batch_normalization_31[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_32 (Conv2D) (None, 8, 8, 12) 6480 activation_31[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_29 (Concatenate) (None, 8, 8, 72) 0 concatenate_28[0][0] \n",
" conv2d_32[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_32 (BatchNo (None, 8, 8, 72) 288 concatenate_29[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_32 (Activation) (None, 8, 8, 72) 0 batch_normalization_32[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_33 (Conv2D) (None, 8, 8, 12) 7776 activation_32[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_30 (Concatenate) (None, 8, 8, 84) 0 concatenate_29[0][0] \n",
" conv2d_33[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_33 (BatchNo (None, 8, 8, 84) 336 concatenate_30[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_33 (Activation) (None, 8, 8, 84) 0 batch_normalization_33[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_34 (Conv2D) (None, 8, 8, 12) 9072 activation_33[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_31 (Concatenate) (None, 8, 8, 96) 0 concatenate_30[0][0] \n",
" conv2d_34[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_34 (BatchNo (None, 8, 8, 96) 384 concatenate_31[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_34 (Activation) (None, 8, 8, 96) 0 batch_normalization_34[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_35 (Conv2D) (None, 8, 8, 12) 10368 activation_34[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_32 (Concatenate) (None, 8, 8, 108) 0 concatenate_31[0][0] \n",
" conv2d_35[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_35 (BatchNo (None, 8, 8, 108) 432 concatenate_32[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_35 (Activation) (None, 8, 8, 108) 0 batch_normalization_35[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_36 (Conv2D) (None, 8, 8, 12) 11664 activation_35[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_33 (Concatenate) (None, 8, 8, 120) 0 concatenate_32[0][0] \n",
" conv2d_36[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_36 (BatchNo (None, 8, 8, 120) 480 concatenate_33[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_36 (Activation) (None, 8, 8, 120) 0 batch_normalization_36[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_37 (Conv2D) (None, 8, 8, 12) 12960 activation_36[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_34 (Concatenate) (None, 8, 8, 132) 0 concatenate_33[0][0] \n",
" conv2d_37[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_37 (BatchNo (None, 8, 8, 132) 528 concatenate_34[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_37 (Activation) (None, 8, 8, 132) 0 batch_normalization_37[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_38 (Conv2D) (None, 8, 8, 12) 14256 activation_37[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_35 (Concatenate) (None, 8, 8, 144) 0 concatenate_34[0][0] \n",
" conv2d_38[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_38 (BatchNo (None, 8, 8, 144) 576 concatenate_35[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_38 (Activation) (None, 8, 8, 144) 0 batch_normalization_38[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_39 (Conv2D) (None, 8, 8, 12) 15552 activation_38[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_36 (Concatenate) (None, 8, 8, 156) 0 concatenate_35[0][0] \n",
" conv2d_39[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_39 (BatchNo (None, 8, 8, 156) 624 concatenate_36[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_39 (Activation) (None, 8, 8, 156) 0 batch_normalization_39[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_40 (Conv2D) (None, 8, 8, 12) 1872 activation_39[0][0] \n",
"__________________________________________________________________________________________________\n",
"average_pooling2d_3 (AveragePoo (None, 4, 4, 12) 0 conv2d_40[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_40 (BatchNo (None, 4, 4, 12) 48 average_pooling2d_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_40 (Activation) (None, 4, 4, 12) 0 batch_normalization_40[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_41 (Conv2D) (None, 4, 4, 12) 1296 activation_40[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_37 (Concatenate) (None, 4, 4, 24) 0 average_pooling2d_3[0][0] \n",
" conv2d_41[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_41 (BatchNo (None, 4, 4, 24) 96 concatenate_37[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_41 (Activation) (None, 4, 4, 24) 0 batch_normalization_41[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_42 (Conv2D) (None, 4, 4, 12) 2592 activation_41[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_38 (Concatenate) (None, 4, 4, 36) 0 concatenate_37[0][0] \n",
" conv2d_42[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_42 (BatchNo (None, 4, 4, 36) 144 concatenate_38[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_42 (Activation) (None, 4, 4, 36) 0 batch_normalization_42[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_43 (Conv2D) (None, 4, 4, 12) 3888 activation_42[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_39 (Concatenate) (None, 4, 4, 48) 0 concatenate_38[0][0] \n",
" conv2d_43[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_43 (BatchNo (None, 4, 4, 48) 192 concatenate_39[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_43 (Activation) (None, 4, 4, 48) 0 batch_normalization_43[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_44 (Conv2D) (None, 4, 4, 12) 5184 activation_43[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_40 (Concatenate) (None, 4, 4, 60) 0 concatenate_39[0][0] \n",
" conv2d_44[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_44 (BatchNo (None, 4, 4, 60) 240 concatenate_40[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_44 (Activation) (None, 4, 4, 60) 0 batch_normalization_44[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_45 (Conv2D) (None, 4, 4, 12) 6480 activation_44[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_41 (Concatenate) (None, 4, 4, 72) 0 concatenate_40[0][0] \n",
" conv2d_45[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_45 (BatchNo (None, 4, 4, 72) 288 concatenate_41[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_45 (Activation) (None, 4, 4, 72) 0 batch_normalization_45[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_46 (Conv2D) (None, 4, 4, 12) 7776 activation_45[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_42 (Concatenate) (None, 4, 4, 84) 0 concatenate_41[0][0] \n",
" conv2d_46[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_46 (BatchNo (None, 4, 4, 84) 336 concatenate_42[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_46 (Activation) (None, 4, 4, 84) 0 batch_normalization_46[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_47 (Conv2D) (None, 4, 4, 12) 9072 activation_46[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_43 (Concatenate) (None, 4, 4, 96) 0 concatenate_42[0][0] \n",
" conv2d_47[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_47 (BatchNo (None, 4, 4, 96) 384 concatenate_43[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_47 (Activation) (None, 4, 4, 96) 0 batch_normalization_47[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_48 (Conv2D) (None, 4, 4, 12) 10368 activation_47[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_44 (Concatenate) (None, 4, 4, 108) 0 concatenate_43[0][0] \n",
" conv2d_48[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_48 (BatchNo (None, 4, 4, 108) 432 concatenate_44[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_48 (Activation) (None, 4, 4, 108) 0 batch_normalization_48[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_49 (Conv2D) (None, 4, 4, 12) 11664 activation_48[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_45 (Concatenate) (None, 4, 4, 120) 0 concatenate_44[0][0] \n",
" conv2d_49[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_49 (BatchNo (None, 4, 4, 120) 480 concatenate_45[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_49 (Activation) (None, 4, 4, 120) 0 batch_normalization_49[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_50 (Conv2D) (None, 4, 4, 12) 12960 activation_49[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_46 (Concatenate) (None, 4, 4, 132) 0 concatenate_45[0][0] \n",
" conv2d_50[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_50 (BatchNo (None, 4, 4, 132) 528 concatenate_46[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_50 (Activation) (None, 4, 4, 132) 0 batch_normalization_50[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_51 (Conv2D) (None, 4, 4, 12) 14256 activation_50[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_47 (Concatenate) (None, 4, 4, 144) 0 concatenate_46[0][0] \n",
" conv2d_51[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_51 (BatchNo (None, 4, 4, 144) 576 concatenate_47[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_51 (Activation) (None, 4, 4, 144) 0 batch_normalization_51[0][0] \n",
"__________________________________________________________________________________________________\n",
"conv2d_52 (Conv2D) (None, 4, 4, 12) 15552 activation_51[0][0] \n",
"__________________________________________________________________________________________________\n",
"concatenate_48 (Concatenate) (None, 4, 4, 156) 0 concatenate_47[0][0] \n",
" conv2d_52[0][0] \n",
"__________________________________________________________________________________________________\n",
"batch_normalization_52 (BatchNo (None, 4, 4, 156) 624 concatenate_48[0][0] \n",
"__________________________________________________________________________________________________\n",
"activation_52 (Activation) (None, 4, 4, 156) 0 batch_normalization_52[0][0] \n",
"__________________________________________________________________________________________________\n",
"average_pooling2d_4 (AveragePoo (None, 2, 2, 156) 0 activation_52[0][0] \n",
"__________________________________________________________________________________________________\n",
"flatten_1 (Flatten) (None, 624) 0 average_pooling2d_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_1 (Dense) (None, 10) 6250 flatten_1[0][0] \n",
"==================================================================================================\n",
"Total params: 434,014\n",
"Trainable params: 425,278\n",
"Non-trainable params: 8,736\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"model = Model(inputs=[input], outputs=[output])\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "b4XOsW3ahSkL"
},
"outputs": [],
"source": [
"# Define the loss function and optimizer\n",
"sgd = SGD(lr=0.1, momentum=0.9, nesterov=True)\n",
"model.compile(loss='categorical_crossentropy',\n",
" optimizer=sgd,\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Image Augmentation\n",
"\n",
"- Rotation: the model confuses dogs and cats seen at different rotation angles (not yet enabled: the generator below sets no `rotation_range`)\n",
"- Zoom range\n",
"- Vertical (up/down) shift\n",
"- Horizontal (left/right) shift\n",
"- Shear range\n",
"- Horizontal flip"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"datagen = ImageDataGenerator(width_shift_range=0.05, height_shift_range=0.05, shear_range=0.05,\n",
" zoom_range=0.05, fill_mode='nearest', horizontal_flip=True)\n",
"datagen.fit(x_train)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1771
},
"colab_type": "code",
"id": "crhGk7kEhXAz",
"outputId": "e3e2d0d0-1492-41ab-df5b-5a7ecd705c2c",
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"200/200 [==============================] - 131s 656ms/step - loss: 1.7398 - acc: 0.3827 - val_loss: 2.3229 - val_acc: 0.2981\n",
"Epoch 2/25\n",
"200/200 [==============================] - 123s 617ms/step - loss: 1.3004 - acc: 0.5287 - val_loss: 2.4500 - val_acc: 0.3313\n",
"Epoch 3/25\n",
"200/200 [==============================] - 124s 620ms/step - loss: 1.0823 - acc: 0.6147 - val_loss: 1.2594 - val_acc: 0.5695\n",
"Epoch 4/25\n",
"200/200 [==============================] - 125s 623ms/step - loss: 0.9271 - acc: 0.6685 - val_loss: 1.3724 - val_acc: 0.5882\n",
"Epoch 5/25\n",
"200/200 [==============================] - 125s 625ms/step - loss: 0.8172 - acc: 0.7112 - val_loss: 1.2374 - val_acc: 0.5953\n",
"Epoch 6/25\n",
"200/200 [==============================] - 125s 626ms/step - loss: 0.7394 - acc: 0.7400 - val_loss: 1.2518 - val_acc: 0.6198\n",
"Epoch 7/25\n",
"200/200 [==============================] - 125s 626ms/step - loss: 0.6756 - acc: 0.7624 - val_loss: 0.9029 - val_acc: 0.6967\n",
"Epoch 8/25\n",
"200/200 [==============================] - 125s 627ms/step - loss: 0.6208 - acc: 0.7815 - val_loss: 0.9332 - val_acc: 0.6914\n",
"Epoch 9/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.5855 - acc: 0.7948 - val_loss: 0.9463 - val_acc: 0.6940\n",
"Epoch 10/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.5455 - acc: 0.8088 - val_loss: 0.9990 - val_acc: 0.6961\n",
"Epoch 11/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.5209 - acc: 0.8174 - val_loss: 0.9168 - val_acc: 0.7176\n",
"Epoch 12/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.4883 - acc: 0.8294 - val_loss: 0.6579 - val_acc: 0.7798\n",
"Epoch 13/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.4648 - acc: 0.8368 - val_loss: 0.8200 - val_acc: 0.7474\n",
"Epoch 14/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.4460 - acc: 0.8446 - val_loss: 0.7980 - val_acc: 0.7439\n",
"Epoch 15/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.4243 - acc: 0.8526 - val_loss: 0.6542 - val_acc: 0.7864\n",
"Epoch 16/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.4099 - acc: 0.8544 - val_loss: 0.5500 - val_acc: 0.8196\n",
"Epoch 17/25\n",
"200/200 [==============================] - 125s 627ms/step - loss: 0.3902 - acc: 0.8636 - val_loss: 0.5449 - val_acc: 0.8102\n",
"Epoch 18/25\n",
"200/200 [==============================] - 126s 628ms/step - loss: 0.3756 - acc: 0.8683 - val_loss: 0.9821 - val_acc: 0.7316\n",
"Epoch 19/25\n",
"200/200 [==============================] - 126s 629ms/step - loss: 0.3641 - acc: 0.8712 - val_loss: 0.5663 - val_acc: 0.8135\n",
"Epoch 20/25\n",
"200/200 [==============================] - 127s 633ms/step - loss: 0.3477 - acc: 0.8782 - val_loss: 0.6375 - val_acc: 0.7999\n",
"Epoch 21/25\n",
"200/200 [==============================] - 126s 631ms/step - loss: 0.3380 - acc: 0.8821 - val_loss: 0.8798 - val_acc: 0.7512\n",
"Epoch 22/25\n",
"200/200 [==============================] - 126s 631ms/step - loss: 0.3233 - acc: 0.8860 - val_loss: 0.7800 - val_acc: 0.7685\n",
"Epoch 23/25\n",
"200/200 [==============================] - 126s 632ms/step - loss: 0.3106 - acc: 0.8908 - val_loss: 0.5561 - val_acc: 0.8216\n",
"Epoch 24/25\n",
"200/200 [==============================] - 127s 633ms/step - loss: 0.2982 - acc: 0.8953 - val_loss: 0.5910 - val_acc: 0.8213\n",
"Epoch 25/25\n",
"200/200 [==============================] - 126s 631ms/step - loss: 0.2959 - acc: 0.8958 - val_loss: 0.6032 - val_acc: 0.8137\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f2a498d3940>"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),\n",
" epochs=epochs,\n",
" verbose=1,\n",
" validation_data=(x_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"id": "ZcWydmIVhZGr",
"outputId": "a0345aa5-79ff-4e56-eb94-50437b43c4fe"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 11s 1ms/step\n",
"Test loss: 0.6031653834104538\n",
"Test accuracy: 0.8137\n"
]
}
],
"source": [
"# Test the model\n",
"score = model.evaluate(x_test, y_test, verbose=1)\n",
"print('Test loss:', score[0])\n",
"print('Test accuracy:', score[1])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"predictions = model.predict(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f2ac69a6e10>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_confusion_matrix(predictions, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"pred_arg = predictions.argmax(axis=1)\n",
"y_argmax = y_test.argmax(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f2b2808b8d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Correct images\n",
"is_correct = rand_by_mask(pred_arg == y_argmax)\n",
"plot_imgs(is_correct, x_test, y_argmax, pred_arg)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f2a42d30cf8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# images that we got wrong\n",
"is_wrong = rand_by_mask(pred_arg != y_argmax)\n",
"plot_imgs(is_wrong, x_test, y_argmax, pred_arg)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f2a42bed278>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# images whose actual class is 5 but predicted class is 3\n",
"actual_cls = 5\n",
"predicted_cls = 3\n",
"actual_5_pred_3 = (pred_arg != y_argmax) & (pred_arg == predicted_cls) & (y_argmax == actual_cls)\n",
"is_wrong = rand_by_mask(actual_5_pred_3)\n",
"plot_imgs(is_wrong, x_test, y_argmax, pred_arg)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f2a42afb9b0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# images whose actual class is 3 but predicted class is 5\n",
"actual_cls = 3\n",
"predicted_cls = 5\n",
"actual_5_pred_3 = (pred_arg == predicted_cls) & (y_argmax == actual_cls)\n",
"\n",
"is_correct = rand_by_mask(actual_5_pred_3)\n",
"plot_imgs(is_correct, x_test, y_argmax, pred_arg)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"id": "UE3lF6EH1r_L",
"outputId": "92df862c-76a7-4a02-9533-6c164bc5984d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved model to disk\n"
]
}
],
"source": [
"# Save the trained weights in .h5 format\n",
"model.save_weights(\"DNST_model_changelog_row_15_and_more_filters.h5\")\n",
"print(\"Saved model to disk\")"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ai-yZ2ED5AK1"
},
"outputs": [],
"source": [
"from google.colab import files\n",
"\n",
"files.download('DNST_model.h5')"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Og56VCRh5j8V"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "DNST_CIFAR10_AUG.ipynb",
"provenance": [],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment