
@sujee
Created May 21, 2020 07:13
CNN for MNIST, tweaked for TensorFlow v2 GPU
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lab: Using Convolutional Neural Networks (CNNs) to identify digits\n",
"\n",
"In this lab, we are going to set up a CNN to classify the handwritten digits in the MNIST dataset."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## About MNIST Data\n",
"\n",
"MNIST is a widely used dataset of handwritten digits (0 to 9): 60,000 training images and 10,000 test images, each a 28x28 grayscale image.\n",
"\n",
"<img src=\"https://www.researchgate.net/profile/Steven_Young11/publication/306056875/figure/fig1/AS:393921575309346@1470929630835/Example-images-from-the-MNIST-dataset.png\" />"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running in Google COLAB : False\n"
]
}
],
"source": [
"## Determine if we are running on google colab\n",
"\n",
"from __future__ import absolute_import, division, print_function, unicode_literals\n",
"import time\n",
"import os,sys\n",
"\n",
"try:\n",
" import google.colab\n",
" RUNNING_IN_COLAB = True\n",
"except ImportError:\n",
" RUNNING_IN_COLAB = False\n",
"\n",
"print (\"Running in Google COLAB : \", RUNNING_IN_COLAB)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" # %tensorflow_version only exists in Colab.\n",
" %tensorflow_version 2.x\n",
"except Exception:\n",
" pass\n",
"\n",
"## disable info logs from TF\n",
"# Level | Level for Humans | Level Description \n",
"# -------|------------------|------------------------------------ \n",
"# 0 | DEBUG | [Default] Print all messages \n",
"# 1 | INFO | Filter out INFO messages \n",
"# 2 | WARNING | Filter out INFO & WARNING messages \n",
"# 3 | ERROR | Filter out all messages \n",
"\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # one of '0', '1', '2', '3' (see table above); must be set before importing TF\n",
"\n",
"import tensorflow as tf\n",
"tf.get_logger().setLevel('WARN')\n",
"from tensorflow import keras\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device mapping:\n",
"/job:localhost/replica:0/task:0/device:XLA_CPU:0 -> device: XLA_CPU device\n",
"/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce RTX 2070, pci bus id: 0000:01:00.0, compute capability: 7.5\n",
"/job:localhost/replica:0/task:0/device:XLA_GPU:0 -> device: XLA_GPU device\n",
"\n"
]
}
],
"source": [
"# Configure the GPU (TF1-style, via tf.compat.v1): grow memory on demand and log device placement\n",
"\n",
"from tensorflow.compat.v1.keras.backend import set_session\n",
"config = tf.compat.v1.ConfigProto()\n",
"config.gpu_options.allow_growth = True # dynamically grow the memory used on the GPU\n",
"config.log_device_placement = True # to log device placement (on which device the operation ran)\n",
"sess = tf.compat.v1.Session(config=config)\n",
"set_session(sess)\n"
]
},
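{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above uses the TF1-style `tf.compat.v1` session API. Below is a minimal sketch of a TF2-native alternative, assuming TensorFlow 2.1 or later (where `tf.config.list_physical_devices` is available). It is meant to replace the cell above in a fresh kernel: memory growth can only be set before a GPU has been initialized, so running it after the cell above may just print a `RuntimeError`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: TF2-native GPU setup (assumes TF >= 2.1; run in a fresh kernel instead of the cell above)\n",
"import tensorflow as tf\n",
"\n",
"# log the device each op runs on, like log_device_placement above\n",
"tf.debugging.set_log_device_placement(True)\n",
"\n",
"gpus = tf.config.list_physical_devices('GPU')\n",
"for gpu in gpus:\n",
"    try:\n",
"        # allocate GPU memory as needed instead of reserving it all up front\n",
"        tf.config.experimental.set_memory_growth(gpu, True)\n",
"    except RuntimeError as e:\n",
"        # raised if the GPU was already initialized in this process\n",
"        print(e)\n",
"\n",
"print('GPUs found:', len(gpus))"
]
},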
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Download Data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_images shape : (60000, 28, 28)\n",
"train_labels shape : (60000,)\n",
"test_images shape : (10000, 28, 28)\n",
"test_labels shape : (10000,)\n"
]
}
],
"source": [
"mnist = tf.keras.datasets.mnist\n",
"\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
"\n",
"# keep an unmodified (28, 28) copy for plotting, since Step 3 reshapes the images to (28, 28, 1)\n",
"(train_images2, train_labels2), (test_images2, test_labels2) = mnist.load_data()\n",
"\n",
"print(\"train_images shape : \", train_images.shape)\n",
"print(\"train_labels shape : \", train_labels.shape)\n",
"print(\"test_images shape : \", test_images.shape)\n",
"print(\"test_labels shape : \", test_labels.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Data Exploration"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Displaying train index = 58888 , label = 8\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f07240f0710>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAPLUlEQVR4nO3df5DU9X3H8deL80BFbCEqohKxERWrDaYX1NIfOtiMmrbAdNIJk7GkdYKtcdSOrXFsp7HTydSJQWNmDBWjIxqjyUw0Mi2ROBRrbS3htKjgqaChihAQMRXrCMfdu3/cYk+872fP/X73drnP8zFzs3vf9372+3a9F9/d/ex3P44IARj9xrS6AQAjg7ADmSDsQCYIO5AJwg5k4pCR3NlYj4tDNX4kdwlk5T39r/bGHg9VKxV22xdKulVSh6TvRMSNqdsfqvE623PK7BJAwppYVVhr+Gm87Q5Jt0m6SNLpkhbYPr3R+wPQXGVes8+StCkiXomIvZIekDS3mrYAVK1M2I+X9Nqg37fUtn2A7UW2u21392pPid0BKKNM2Id6E+BDn72NiKUR0RURXZ0aV2J3AMooE/YtkqYO+v0ESVvLtQOgWcqEfa2k6bZPsj1W0uclLa+mLQBVa3jqLSL22b5C0koNTL3dFREbKusMQKVKzbNHxApJKyrqBUAT8XFZIBOEHcgEYQcyQdiBTBB2IBOEHcgEYQcyQdiBTBB2IBOEHcgEYQcyQdiBTBB2IBOEHcgEYQcyQdiBTBB2IBOEHcgEYQcyQdiBTBB2IBOEHcgEYQcyQdiBTBB2IBOEHcgEYQcyQdiBTBB2IBOlVnHFyBgz8/Rkve+wzsLaz+Yfnh577J5k/W9n/VOy/oUJ25L1nt7ewtq8xy9Pjj3pbifr457dnKz37XwzWc9NqbDb3ixpt6Q+SfsioquKpgBUr4oj+/kRsbOC+wHQRLxmBzJRNuwh6Se2n7K9aKgb2F5ku9t2d6/Srw8BNE/Zp/GzI2Kr7WMkPWr7hYh4fPANImKppKWSdKQnRcn9AWhQqSN7RGytXe6Q9JCkWVU0BaB6DYfd9njbE/Zfl/QZSeuragxAtco8jZ8s6SHb++/nexHxSCVdjTJx7ieT9c1XpV/drDz328n6CYccVljrV39ybFn17n1GZ/FnAHrm3J4ePCddPvPeK5P1k657Mn0HmWk47BHxiqT0XzGAtsHUG5AJwg5kgrADmSDsQCYIO5AJTnGtwJjx45P1K+79frJ+wWG76+xhXLK6J4pPI31676HJsV/88WXJ+pj30qeZnvy9er0X6zt8bLLeuSN934ddlO4NH8SRHcgEYQcyQdiBTBB2IBOEHcgEYQcyQdiBTDDPPkwdv/xLhbW37j8qObbePHrP3vSJovMfS3/l8pRHik8jnfD9/0yOna41yXo9Zb56qN6Rpq9O/dgXN5XYe344sgOZIOxAJgg7kAnCDmSCsAOZIOxAJgg7kAnm2Ydp34xphbXHfu07pe77kiV/kayf8vX/KHX/gMSRHcgGYQcyQdiBTBB2IBOEHcgEYQcyQdiBTDDPPkx+8pnC2llP/kly7E/PSc/D37TozmT9mo5Lk/Vpd79SWNu37efJschH3SO77bts77C9ftC2SbYftb2xdjmxuW0CKGs4T+PvlnThAduuk7QqIqZLWlX7HUAbqxv2iHhc0q4DNs+VtKx2fZmkeRX3BaBijb5BNzkitklS7fKYohvaXmS723Z3r/Y0uDsAZTX93fiIWBoRXRHR1VlngUIAzdNo2LfbniJJtcsd1bUEoBkaDftySQtr1xdKeriadgA0iyPS3/xt+35J50k6StJ2SV+V9CNJP5D0cUmvSvpcRBz4Jt6HHOlJcbbnlGz54DN9bfrly+Ljnih1/1/5+bmFtX+7/dPJscd8t/jzA5LU/+67DfVUhTj3k8n6m2cenqz/Ykbx3/bsc55Pju2Pcq9wn3nw9G
T9uG805zsK1sQqvR27hly4vu6HaiJiQUEpv9QCBzE+LgtkgrADmSDsQCYIO5AJwg5kou7UW5VynXo75PjjkvUTf5SetfyHKauT9cM9trDWr/Ry0CvfLV6KWpJuvvIL6X3/9OVkXZOLl7N+7WvpyaDln1qarB93SHpK842+4o9nP/zOjOTYW3782WS9nhP/uTdZP+Rfnip1/0VSU28c2YFMEHYgE4QdyARhBzJB2IFMEHYgE4QdyATz7AeBzX9ffAqrJH16Tk9h7drjHkmOPbWzo6Ge9vvu21OT9Q4Xz/MvmPB6cuzu/r3J+vndX0rWJ996aHFfjz2dHHuwYp4dAGEHckHYgUwQdiAThB3IBGEHMkHYgUwwzz7Kdfzqqcn6C9eOT9Z7Lri91P7HJI4n9c61v/BP/zxZH7uyu6GeRjPm2QEQdiAXhB3IBGEHMkHYgUwQdiAThB3IRN1VXHFw69vwYrIe/b+erKfmyYej08Xny/fW+YjH1BteStZ3PvWxZL1v55vpHWSm7v9J23fZ3mF7/aBtN9h+3fa62s/FzW0TQFnD+Wf7bkkXDrH9loiYWftZUW1bAKpWN+wR8bik9PpEANpemRdkV9h+tvY0f2LRjWwvst1tu7tXxWtvAWiuRsO+RNInJM2UtE3S4qIbRsTSiOiKiK5OpRfiA9A8DYU9IrZHRF9E9Eu6Q9KsatsCULWGwm57yqBf50taX3RbAO2h7jy77fslnSfpKNtbJH1V0nm2Z0oKSZslXdbEHlHC1r/6jWR97QU3Jev9Kl77Xar/vfH/eNP8wtriv1mSHHvHx1cl62f+5ZXJ+knXPZms56Zu2CNiwRCb72xCLwCaiI/LApkg7EAmCDuQCcIOZIKwA5ngFNdRwF1nFNZWX5meWpswJj21dtqKy5P1GdduStYnvVU8/bVw1qLk2Bd+/7Zkveu3XkjWOcH1gziyA5kg7EAmCDuQCcIOZIKwA5kg7EAmCDuQCebZR4HNfzChsFZvHn3JL6Yn6zNuSn/9YN9bbyXrKSesHHJl4fft/uzeZP3PpqxO1m885Q8La30vvZwcOxpxZAcyQdiBTBB2IBOEHcgEYQcyQdiBTBB2IBPMsx8EYvbMZP2BS76ZqKb/PX/kj2en9/3ShmS9jMMfWpOsv7j4sGT97HG96fGXH11YO/lq5tkBjFKEHcgEYQcyQdiBTBB2IBOEHcgEYQcywTz7QeC9o9LnpJ85trOw1q9Iju3Y8T/J+r5ktZyOo4vnwSVpwpj0+exjVPzfLUkxMT0+N3WP7Lan2l5tu8f2BttX1bZPsv2o7Y21y4nNbxdAo4bzNH6fpGsiYoakcyR92fbpkq6TtCoipktaVfsdQJuqG/aI2BYRT9eu75bUI+l4SXMlLavdbJmkec1qEkB5H+kNOtvTJJ0laY2kyRGxTRr4B0HSMQVjFtnutt3dqz3lugXQsGGH3fYRkn4o6eqIeHu44yJiaUR0RURXp8Y10iOACgwr7LY7NRD0+yLiwdrm7ban1OpTJO1oTosAqlB36s22Jd0pqScibh5UWi5poaQba5cPN6VD6Ijndybr9+4+trC2YMLrVbdTmY3XnJysn9rZkaz3qz9ZH/ezQz9yT6PZcObZZ0u6RNJzttfVtl2vgZD/wPalkl6V9LnmtAigCnXDHhFPSCr6Nv851bYDoFn4uCyQCcIOZIKwA5kg7EAmCDuQCU5xPQj0bXwlWf+7f51bWFvwe99Ojo170nPVm/7rnGS9nvvm3VZYO7nz3+uMTp/ae9qKy5P1U7/WXVhLn/g7OnFkBzJB2IFMEHYgE4QdyARhBzJB2IFMEHYgE8yzjwIzvlX8xUHfmn1acuxDp9T5GoJTGuno/41JHE/668yjL3j54mR9xk27kvW+Xr5KejCO7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIJ59lGgb8OLhbXV501Ljl2y+HeS9Z4Lbm+kpfdd+ur5hbW1K89Ijj0xcT66JEXvGw31lC
uO7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZMIR6W/Qtj1V0j2SjpXUL2lpRNxq+wZJX5K0f7Lz+ohYkbqvIz0pzjYLvwLNsiZW6e3YNeSqy8P5UM0+SddExNO2J0h6yvajtdotEfGNqhoF0DzDWZ99m6Rtteu7bfdIOr7ZjQGo1kd6zW57mqSzJK2pbbrC9rO277I9sWDMItvdtrt7tadUswAaN+yw2z5C0g8lXR0Rb0taIukTkmZq4Mi/eKhxEbE0IroioqtT4ypoGUAjhhV2250aCPp9EfGgJEXE9ojoi4h+SXdImtW8NgGUVTfsti3pTkk9EXHzoO1TBt1svqT11bcHoCrDeTd+tqRLJD1ne11t2/WSFtieqYHVbzdLuqwpHQKoxHDejX9C0lDzdsk5dQDthU/QAZkg7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIKwA5kg7EAmCDuQCcIOZIKwA5kg7EAm6n6VdKU7s9+Q9N+DNh0laeeINfDRtGtv7dqXRG+NqrK3EyPi6KEKIxr2D+3c7o6IrpY1kNCuvbVrXxK9NWqkeuNpPJAJwg5kotVhX9ri/ae0a2/t2pdEb40akd5a+podwMhp9ZEdwAgh7EAmWhJ22xfaftH2JtvXtaKHIrY3237O9jrb3S3u5S7bO2yvH7Rtku1HbW+sXQ65xl6LervB9uu1x26d7Ytb1NtU26tt99jeYPuq2vaWPnaJvkbkcRvx1+y2OyS9JOl3JW2RtFbSgoh4fkQbKWB7s6SuiGj5BzBs/7akdyTdExFn1LZ9XdKuiLix9g/lxIj4Spv0doOkd1q9jHdttaIpg5cZlzRP0hfVwscu0dcfaQQet1Yc2WdJ2hQRr0TEXkkPSJrbgj7aXkQ8LmnXAZvnSlpWu75MA38sI66gt7YQEdsi4una9d2S9i8z3tLHLtHXiGhF2I+X9Nqg37eovdZ7D0k/sf2U7UWtbmYIkyNimzTwxyPpmBb3c6C6y3iPpAOWGW+bx66R5c/LakXYh1pKqp3m/2ZHxKckXSTpy7WnqxieYS3jPVKGWGa8LTS6/HlZrQj7FklTB/1+gqStLehjSBGxtXa5Q9JDar+lqLfvX0G3drmjxf28r52W8R5qmXG1wWPXyuXPWxH2tZKm2z7J9lhJn5e0vAV9fIjt8bU3TmR7vKTPqP2Wol4uaWHt+kJJD7ewlw9ol2W8i5YZV4sfu5Yvfx4RI/4j6WINvCP/sqS/bkUPBX39iqRnaj8bWt2bpPs18LSuVwPPiC6V9DFJqyRtrF1OaqPe7pX0nKRnNRCsKS3q7Tc18NLwWUnraj8Xt/qxS/Q1Io8bH5cFMsEn6IBMEHYgE4QdyARhBzJB2IFMEHYgE4QdyMT/AfKih+ZZtV4AAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"## Run this cell a few times to randomly display some digit data\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import random\n",
"\n",
"index = random.randrange(len(train_images))  # randint(0, n) could return n, which is out of range\n",
"# index = 10\n",
"print (\"Displaying train index = \", index, \", label = \", train_labels[index])\n",
"\n",
"# print(\"train label [{}] = {} \".format(index, train_labels[index]))\n",
"# print (\"------------ raw data for train_image[{}] -------\".format(index))\n",
"# print(train_images[index])\n",
"# print (\"--------------------\")\n",
"\n",
"plt.imshow(train_images[index])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Shape Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
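"### 3.1 - Reshape the arrays to 4 dimensions\n",
"Keras `Conv2D` layers expect 4D input of shape (batch, height, width, channels). MNIST images are grayscale, so we add a single channel dimension: (60000, 28, 28) becomes (60000, 28, 28, 1)."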
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_images shape : (60000, 28, 28, 1)\n",
"train_labels shape : (60000,)\n",
"test_images shape : (10000, 28, 28, 1)\n",
"test_labels shape : (10000,)\n"
]
}
],
"source": [
"## Reshape to add 'channel'.\n",
"train_images = train_images.reshape(( train_images.shape[0], 28, 28, 1))\n",
"test_images = test_images.reshape((test_images.shape[0], 28, 28, 1))\n",
"\n",
"print(\"train_images shape : \", train_images.shape)\n",
"print(\"train_labels shape : \", train_labels.shape)\n",
"print(\"test_images shape : \", test_images.shape)\n",
"print(\"test_labels shape : \", test_labels.shape)"
]
},
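{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hard-coding `28, 28` in `reshape` works for MNIST. Below is a minimal NumPy sketch of two equivalent, size-agnostic ways to add the channel axis, run on a small hypothetical stand-in array (`demo`) rather than the real images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical stand-in for a batch of 5 grayscale 28x28 images\n",
"demo = np.random.randint(0, 256, size=(5, 28, 28), dtype=np.uint8)\n",
"\n",
"a = demo.reshape((demo.shape[0], 28, 28, 1))  # as in the cell above\n",
"b = demo[..., np.newaxis]                     # append a length-1 axis\n",
"c = np.expand_dims(demo, axis=-1)             # same thing, more explicit\n",
"\n",
"print(a.shape, b.shape, c.shape)\n",
"print(np.array_equal(a, b) and np.array_equal(b, c))"
]
},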
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2 - Normalize Data\n",
"Each image is a 28x28 array of pixels, and each pixel is an integer from 0 to 255.\n",
"We normalize the pixel values to the range 0 to 1 by dividing by 255."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"## Normalize pixel values to be between 0 and 1\n",
"train_images, test_images = train_images / 255.0, test_images / 255.0"
]
},
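{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dividing a `uint8` array by the Python float `255.0` produces a `float64` array (8 bytes per pixel). Below is a minimal sketch, again on a hypothetical stand-in array, showing that casting to `float32` first (Keras's default float type) gives the same 0 to 1 range with half the memory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# hypothetical stand-in: 1,000 images of shape (28, 28, 1) with values 0..255\n",
"demo = np.random.randint(0, 256, size=(1000, 28, 28, 1), dtype=np.uint8)\n",
"\n",
"as_f64 = demo / 255.0                    # what the cell above does: float64\n",
"as_f32 = demo.astype('float32') / 255.0  # float32 divided by a Python float stays float32\n",
"\n",
"print(as_f64.dtype, as_f64.nbytes)\n",
"print(as_f32.dtype, as_f32.nbytes)\n",
"print(as_f32.min() >= 0.0, as_f32.max() <= 1.0)"
]
},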
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# import matplotlib.pyplot as plt\n",
"# import random\n",
"\n",
"# index = random.randrange(len(train_images))\n",
"# index = 10\n",
"# print (\"Displaying train index = \", index)\n",
"\n",
"# print(\"train label [{}] = {} \".format(index, train_labels[index]))\n",
"# print (\"------------ raw data for train_image[{}] (just printing first 2 rows) -------\".format(index))\n",
"# print(train_images[index][0:2])\n",
"# print (\"--------------------\")\n",
"\n",
"# plt.imshow(train_images[index].squeeze())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Create Model\n",
"\n",
"### Neural Net Architecture\n",
"\n",
"<img src=\"https://www.pyimagesearch.com/wp-content/uploads/2016/06/lenet_architecture-768x226.png\" style=\"width:80%\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.1 - Create a CNN\n",
"\n",
"The code below defines the convolutional base using a common pattern: a stack of [Conv2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D) and [MaxPooling2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/MaxPool2D) layers.\n",
"\n",
"As input, a CNN takes tensors of shape (image_height, image_width, color_channels), ignoring the batch size. If you are new to color channels, MNIST has one (because the images are grayscale), whereas a color image has three (R,G,B). In this example, we will configure our CNN to process inputs of shape (28, 28, 1), which is the format of MNIST images. We do this by passing the argument `input_shape` to our first layer."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d (Conv2D) (None, 26, 26, 32) 320 \n",
"_________________________________________________________________\n",
"max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_1 (Conv2D) (None, 11, 11, 64) 18496 \n",
"_________________________________________________________________\n",
"max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 3, 3, 64) 36928 \n",
"=================================================================\n",
"Total params: 55,744\n",
"Trainable params: 55,744\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = tf.keras.Sequential([\n",
"    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),\n",
"    tf.keras.layers.MaxPooling2D((2, 2)),\n",
"    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),\n",
"    tf.keras.layers.MaxPooling2D((2, 2)),\n",
"    tf.keras.layers.Conv2D(64, (3, 3), activation='relu')\n",
"])\n",
"\n",
"model.summary()  # summary() prints the table itself and returns None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Above, you can see that the output of every Conv2D and MaxPooling2D layer is a 3D tensor of shape (height, width, channels); the leading `None` in the summary is the batch dimension. The width and height shrink as we go deeper: each 3x3 convolution without padding trims one pixel from every border (28 to 26), and each 2x2 max-pool halves the size, rounding down (11 to 5). The number of output channels of each Conv2D layer is set by its first argument (e.g., 32 or 64). Typically, as the width and height shrink, we can afford (computationally) to add more output channels in each Conv2D layer."
]
},
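{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output shapes and parameter counts in the summary can be checked by hand. With the default `padding='valid'` and stride 1, a 3x3 convolution shrinks each side by 2, a 2x2 max-pool halves it (rounding down), and a Conv2D layer has `(kernel_h * kernel_w * in_channels + 1) * filters` parameters, where the `+ 1` is the bias of each filter. Below is a minimal pure-Python sketch of that arithmetic; the helper names are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical helpers that mirror the Keras defaults used above\n",
"def conv_out(size, kernel=3):\n",
"    # 'valid' padding, stride 1: the kernel must fit entirely inside the input\n",
"    return size - kernel + 1\n",
"\n",
"def pool_out(size, pool=2):\n",
"    # MaxPooling2D: stride defaults to pool_size, 'valid' padding rounds down\n",
"    return size // pool\n",
"\n",
"def conv_params(in_ch, filters, kernel=3):\n",
"    # one kernel x kernel x in_ch weight block plus one bias per filter\n",
"    return (kernel * kernel * in_ch + 1) * filters\n",
"\n",
"size, in_ch = 28, 1\n",
"shapes, params = [], []\n",
"for filters, pool in [(32, True), (64, True), (64, False)]:\n",
"    size = conv_out(size)\n",
"    params.append(conv_params(in_ch, filters))\n",
"    shapes.append((size, size, filters))\n",
"    if pool:\n",
"        size = pool_out(size)\n",
"        shapes.append((size, size, filters))\n",
"    in_ch = filters\n",
"\n",
"print(shapes)\n",
"print(params, sum(params))"
]
},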
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2 - Add Dense layers on top\n",
"To complete our model, we will feed the last output tensor from the convolutional base (of shape (3, 3, 64)) into one or more Dense layers to perform classification. Dense layers take vectors (1D) as input, while the current output is a 3D tensor, so we first flatten (unroll) it into a vector of 3 x 3 x 64 = 576 values and then add Dense layers on top. MNIST has 10 output classes, so the final Dense layer has 10 outputs and a softmax activation, which turns them into probabilities that sum to 1."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d (Conv2D) (None, 26, 26, 32) 320 \n",
"_________________________________________________________________\n",
"max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_1 (Conv2D) (None, 11, 11, 64) 18496 \n",
"_________________________________________________________________\n",
"max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0 \n",
"_________________________________________________________________\n",
"conv2d_2 (Conv2D) (None, 3, 3, 64) 36928 \n",
"_________________________________________________________________\n",
"flatten (Flatten) (None, 576) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 64) 36928 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 10) 650 \n",
"=================================================================\n",
"Total params: 93,322\n",
"Trainable params: 93,322\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.add(tf.keras.layers.Flatten())\n",
"model.add(tf.keras.layers.Dense(64, activation=tf.nn.relu))\n",
"model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))\n",
"\n",
"model.summary()"
]
},
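{
"cell_type": "markdown",
"metadata": {},
"source": [
"Likewise for the dense head: `Flatten` turns the (3, 3, 64) tensor into 3 x 3 x 64 = 576 values, and a Dense layer has `inputs * units + units` parameters (one weight per input-unit pair plus one bias per unit). Below is a minimal sketch reproducing the total of 93,322 in the summary above; the helper name is made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical helper: parameter count of a Dense layer\n",
"def dense_params(n_in, units):\n",
"    # a weight for every input-unit pair, plus one bias per unit\n",
"    return n_in * units + units\n",
"\n",
"flat = 3 * 3 * 64             # output size of Flatten\n",
"d1 = dense_params(flat, 64)   # Dense(64)\n",
"d2 = dense_params(64, 10)     # Dense(10)\n",
"conv_total = 55744            # from the convolutional base summary\n",
"\n",
"print(flat, d1, d2, conv_total + d1 + d2)"
]
},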
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.3 - Compile the Model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"model.compile(optimizer=tf.keras.optimizers.Adam(), # 'adam'\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])"
]
},
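{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using `sparse_categorical_crossentropy` lets us keep the labels as integers (0 to 9) instead of one-hot vectors. Mathematically, the loss for one image is the negative log of the probability the model assigns to the true class, which equals categorical cross-entropy on the one-hot label (Keras also clips probabilities internally to avoid `log(0)`, which this sketch ignores). Below is a minimal NumPy sketch using a made-up softmax output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# made-up softmax output for one image (10 classes, sums to 1)\n",
"probs = np.array([0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.01, 0.90, 0.01, 0.01])\n",
"label = 7  # integer label, as stored in train_labels\n",
"\n",
"# sparse form: pick the probability of the true class directly\n",
"sparse_loss = -np.log(probs[label])\n",
"\n",
"# dense form: one-hot encode the label, then sum over all classes\n",
"one_hot = np.eye(10)[label]\n",
"categorical_loss = -np.sum(one_hot * np.log(probs))\n",
"\n",
"print(sparse_loss, categorical_loss)"
]
},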
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5: Set Up TensorBoard"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving TB logs to : /tmp/tensorboard-logs/cnn-mnist/2020-05-21--00-09-25\n"
]
}
],
"source": [
"## Fairly standard boilerplate for TensorBoard logging\n",
"\n",
"import datetime\n",
"import os\n",
"\n",
"app_name = 'cnn-mnist' # you can change this, if you like\n",
"\n",
"tb_top_level_dir= '/tmp/tensorboard-logs'\n",
"tensorboard_logs_dir= os.path.join (tb_top_level_dir, app_name, \n",
" datetime.datetime.now().strftime(\"%Y-%m-%d--%H-%M-%S\"))\n",
"print (\"Saving TB logs to : \" , tensorboard_logs_dir)\n",
"\n",
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_logs_dir, histogram_freq=1)\n",
"\n",
"# Loading of tensorboard in Colab\n",
"if RUNNING_IN_COLAB:\n",
" %load_ext tensorboard\n",
" %tensorboard --logdir $tb_top_level_dir"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 6: Train"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training starting ...\n",
"Train on 48000 samples, validate on 12000 samples\n",
"Epoch 1/10\n",
"48000/48000 [==============================] - 13s 280us/sample - loss: 0.1608 - accuracy: 0.9487 - val_loss: 0.0683 - val_accuracy: 0.9810\n",
"Epoch 2/10\n",
"48000/48000 [==============================] - 4s 85us/sample - loss: 0.0492 - accuracy: 0.9843 - val_loss: 0.0448 - val_accuracy: 0.9876\n",
"Epoch 3/10\n",
"48000/48000 [==============================] - 4s 83us/sample - loss: 0.0355 - accuracy: 0.9888 - val_loss: 0.0547 - val_accuracy: 0.9842\n",
"Epoch 4/10\n",
"48000/48000 [==============================] - 4s 83us/sample - loss: 0.0268 - accuracy: 0.9911 - val_loss: 0.0346 - val_accuracy: 0.9914\n",
"Epoch 5/10\n",
"48000/48000 [==============================] - 4s 85us/sample - loss: 0.0221 - accuracy: 0.9928 - val_loss: 0.0329 - val_accuracy: 0.9903\n",
"Epoch 6/10\n",
"48000/48000 [==============================] - 4s 76us/sample - loss: 0.0164 - accuracy: 0.9948 - val_loss: 0.0410 - val_accuracy: 0.9893\n",
"Epoch 7/10\n",
"48000/48000 [==============================] - 4s 78us/sample - loss: 0.0131 - accuracy: 0.9957 - val_loss: 0.0497 - val_accuracy: 0.9887\n",
"Epoch 8/10\n",
"48000/48000 [==============================] - 4s 74us/sample - loss: 0.0133 - accuracy: 0.9957 - val_loss: 0.0376 - val_accuracy: 0.9906\n",
"Epoch 9/10\n",
"48000/48000 [==============================] - 3s 71us/sample - loss: 0.0107 - accuracy: 0.9965 - val_loss: 0.0384 - val_accuracy: 0.9908\n",
"Epoch 10/10\n",
"48000/48000 [==============================] - 3s 72us/sample - loss: 0.0101 - accuracy: 0.9966 - val_loss: 0.0497 - val_accuracy: 0.9877\n",
"training done in 47,643.77 ms\n"
]
}
],
"source": [
"epochs = 10\n",
"\n",
"print (\"training starting ...\")\n",
"t1 = time.perf_counter()\n",
"history = model.fit(train_images, train_labels, \n",
" epochs=epochs, validation_split = 0.2, verbose=1,\n",
" callbacks=[tensorboard_callback])\n",
"t2 = time.perf_counter()\n",
"print (\"training done in {:,.2f} ms\".format ((t2-t1)*1000))\n"
]
},
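{
"cell_type": "markdown",
"metadata": {},
"source": [
"`validation_split=0.2` holds out 20% of the 60,000 training images for validation, which is why the log reports 48,000 training and 12,000 validation samples. The Keras `Model.fit` documentation states that the validation data is taken from the last samples of the arrays, before shuffling. Below is a minimal NumPy sketch of that split, using a hypothetical index array in place of `train_images`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"n = 60000\n",
"validation_split = 0.2\n",
"x = np.arange(n)  # hypothetical stand-in for train_images (just the sample indices)\n",
"\n",
"# Keras-style split: the last 20% of the samples become the validation set\n",
"split_at = int(n * (1 - validation_split))\n",
"x_train, x_val = x[:split_at], x[split_at:]\n",
"\n",
"print(len(x_train), len(x_val))  # compare with 'Train on 48000 samples, validate on 12000 samples'\n",
"print(x_val[0], x_val[-1])"
]
},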
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 7: See Training History"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(history.history['accuracy'], label='train_accuracy')\n",
"plt.plot(history.history['val_accuracy'], label='val_accuracy')\n",
"plt.legend()\n",
"plt.show()"
]
},
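{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training log in Step 6 shows `val_accuracy` peaking at epoch 4 (0.9914) while training accuracy keeps rising through epoch 10, a common sign of mild overfitting. Below is a minimal sketch that finds the best epoch from a `history.history`-style dict; the values are copied from the training log so the snippet runs on its own. Keras can also track this during training with `tf.keras.callbacks.EarlyStopping` (for example with `monitor='val_accuracy'` and `restore_best_weights=True`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# val_accuracy per epoch, copied from the Step 6 training log\n",
"history_dict = {'val_accuracy': [0.9810, 0.9876, 0.9842, 0.9914, 0.9903,\n",
"                                 0.9893, 0.9887, 0.9906, 0.9908, 0.9877]}\n",
"\n",
"scores = history_dict['val_accuracy']  # same layout as history.history\n",
"best_idx = max(range(len(scores)), key=scores.__getitem__)\n",
"\n",
"# epochs are numbered from 1 in the training log\n",
"print('best epoch:', best_idx + 1, 'val_accuracy:', scores[best_idx])"
]
},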
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 8: Predict\n",
"\n",
"**Compare the prediction time with the training time: predicting on all 10,000 test images took about 0.3 s in the run below, versus about 48 s for the 10 training epochs.**"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"predicting on 10,000 images\n",
"prediction done in 323.48 ms\n"
]
}
],
"source": [
"\n",
"t1 = time.perf_counter()\n",
"print (\"predicting on {:,} images\".format(len(test_images)))\n",
"predictions = model.predict(test_images)\n",
"t2 = time.perf_counter()\n",
"print (\"prediction done in {:,.2f} ms\".format ((t2-t1)*1000))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"random index = 3236\n",
"test_label[3236] = 7. So the number is 7\n",
"prediction of test_image[3236] = [ 0.000 0.000 0.014 0.000 0.000 0.000 0.000 0.983 0.004 0.000]\n",
"max softmax output = 0.98250765\n",
"index of max softmax output = 7. So the prediction is same (7)\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f06c45c8590>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAM30lEQVR4nO3dbawc5XnG8evCNiZ24sinrqnjOCWlIAKRauDUtHFVkaBSh6Y1fEiFK4LTQp2WUAUpH0rdD7EqVbJKSUKlKI0JFiaiRJEIwapQg+VaIlElh4Pj2qZObUoc8EttqKvYRMH45e6HM65OzNlnj3dmdza+/z9ptbtzz+zcWp3rzOw+u/s4IgTgwndR2w0AGAzCDiRB2IEkCDuQBGEHkpg+yJ1d7JlxiWYPcpdAKm/qJ3orTniyWq2w214m6SFJ0yR9NSLWlta/RLN1g2+qs0sABVtjc8daz6fxtqdJ+pKkj0q6WtIK21f3+ngA+qvOa/Ylkl6KiJcj4i1JX5e0vJm2ADStTtgXSnp1wv391bKfYXuV7THbYyd1osbuANRRJ+yTvQnwts/eRsS6iBiNiNEZmlljdwDqqBP2/ZIWTbj/XkkH67UDoF/qhP15SVfYfr/tiyXdLmljM20BaFrPQ28Rccr2vZK+rfGht/UR8WJjnQFoVK1x9oh4RtIzDfUCoI/4uCyQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ1JrFFYNx5J4PFeuL79zZsfbK6iuL207/1xd66gk/f2qF3fY+ScclnZZ0KiJGm2gKQPOaOLJ/OCJeb+BxAPQRr9mBJOqGPSQ9a/sF26smW8H2KttjtsdO6kTN3QHoVd3T+KURcdD2fEmbbP8gIp6buEJErJO0TpLmeCRq7g9Aj2od2SPiYHV9RNJTkpY00RSA5vUcdtuzbb/r7G1JN0va1VRjAJpV5zT+UklP2T77OP8UEf/SSFcXmL2PXVdeocuLm+995IFi/d0XXdKxtmf9s8Vt/2j7nxTrC+88UKyfPnasWMfw6DnsEfGypF9rsBcAfcTQG5AEYQeSIOxAEoQdSIKwA0nwFdcBiFPl/6l7fvcrXR6h89BaN1fOuLhY3/brjxfrVz/8yWL98nv2F+un/+dosY7B4cgOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0k4YnA/HjPHI3GDbxrY/oaFr7+mWH/9ujnF+rV37yjWN+++qmNt04cfKm572fRZxfqZLt+/vWrL3cX6r97x/WIdzdoam3UsjnqyGkd2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYL3E9vLc/bseVL/1isdxtn3/LT8nftH/jjOzrWLvoOY/BNY5wdAGEHsiDsQBKEHUiCsANJEHYgCcIOJMHvxl/g3vGt7xXrN/rPi/XHvvhgsX7zrGnF+l99oPM4/LzvFDdFw7oe2W2vt33E9q4Jy0Zsb7K9t7qe2982AdQ1ldP4RyUtO2fZ/ZI2R8QVkjZX9wEMsa5hj4jnJJ07h89ySRuq2xsk3dpwXwAa1usbdJdGxCFJqq7nd1rR9irbY7bHTupEj7sDUFff342PiHURMRoRozM0s9+7A9BBr2E/bHuBJFXXR5prCUA/9Br2jZJWVrdXSnq6mXYA9EvXcXbbT0i6UdI82/slfU7SWknfsH2XpFckfbyfTaJ/Zj21tVj/s3tvL9b/+ary//n//eCZjrV5xS3RtK5hj4gVHUr8CgXwc4SPywJJEHYgCcIOJEHYgSQIO5AEX3FF0fQ7Ow+dSdIXv31lsf7ox77Ssbb2H24rbnv6pR8W6zg/HNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2VF06sDBYv3VN0eK9f
vm7ulY++EdC4rbvm8N4+xN4sgOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfB9dhT5+muK9Y+8uzxl80Vyx1p0LqEPuh7Zba+3fcT2rgnL1tg+YHt7dbmlv20CqGsqp/GPSlo2yfIvRMTi6vJMs20BaFrXsEfEc5KODqAXAH1U5w26e23vqE7z53ZayfYq22O2x07qRI3dAaij17B/WdLlkhZLOiTpwU4rRsS6iBiNiNEZmtnj7gDU1VPYI+JwRJyOiDOSHpa0pNm2ADStp7DbnvgbwLdJ2tVpXQDDoes4u+0nJN0oaZ7t/ZI+J+lG24slhaR9kj7Vxx7Roteun1Os/96sHxfrpdndHT00hJ51DXtErJhk8SN96AVAH/FxWSAJwg4kQdiBJAg7kARhB5LgK64oGr17e9stoCEc2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZUfQHI9+vtf2JONmxNv0ntR4a54kjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTg7annyjXnF+t9+dbIfJx73ngf+rel2UMCRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJw9uR/9zW8W68vesa1Yn+Y3i/UHfsy8zMOi65Hd9iLbW2zvtv2i7c9Uy0dsb7K9t7qe2/92AfRqKqfxpyR9NiI+IOk3JH3a9tWS7pe0OSKukLS5ug9gSHUNe0Qcioht1e3jknZLWihpuaQN1WobJN3aryYB1Hdeb9DZvkzStZK2Sro0Ig5J4/8QJM3vsM0q22O2x07qRL1uAfRsymG3/U5JT0q6LyKOTXW7iFgXEaMRMTpDM3vpEUADphR22zM0HvTHI+Kb1eLDthdU9QWSjvSnRQBN6Dr0ZtuSHpG0OyI+P6G0UdJKSWur66f70iFadUblobPVhxcX6/Of2FV4bAzSVMbZl0r6hKSdts9O1r1a4yH/hu27JL0i6eP9aRFAE7qGPSK+K8kdyjc12w6AfuHjskAShB1IgrADSRB2IAnCDiTBV1yTW7via7W2f/aVq4r1+cd/UOvx0RyO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsF7j/vu9Dxfrvzyr/VPRfHFxarL/nnvKPFp0qVjFIHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2S9wx645WWv75w+/r1gfObCn1uNjcDiyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjijPv217kaTHJP2SxqfUXhcRD9leI+lPJb1Wrbo6Ip4pPdYcj8QNZuJXoF+2xmYdi6OTzro8lQ/VnJL02YjYZvtdkl6wvamqfSEi/r6pRgH0z1TmZz8k6VB1+7jt3ZIW9rsxAM06r9fsti+TdK2krdWie23vsL3e9twO26yyPWZ77KRO1GoWQO+mHHbb75T0pKT7IuKYpC9LulzSYo0f+R+cbLuIWBcRoxExOkMzG2gZQC+mFHbbMzQe9Mcj4puSFBGHI+J0RJyR9LCkJf1rE0BdXcNu25IekbQ7Ij4/YfmCCavdJmlX8+0BaMpU3o1fKukTknba3l4tWy1phe3FkkLSPkmf6kuHABoxlXfjvytpsnG74pg6gOHCJ+iAJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJdP0p6UZ3Zr8m6UcTFs2T9PrAGjg/w9rbsPYl0VuvmuztlyPiFycrDDTsb9u5PRYRo601UDCsvQ1rXxK99WpQvXEaDyRB2IEk2g77upb3XzKsvQ1rXxK99WogvbX6mh3A4LR9ZAcwIIQdSKKVsNteZvs/bb9k+/42eujE9j7bO21vtz3Wci/rbR+xvWvCshHbm2zvra4nnWOvpd7W2D5QPXfbbd/SUm+LbG+xvdv2i7Y/Uy1v9bkr9DWQ523gr9ltT5O0R9LvSNov6XlJKyLiPwbaSAe290kajYjWP4
Bh+7clvSHpsYj4YLXs7yQdjYi11T/KuRHxl0PS2xpJb7Q9jXc1W9GCidOMS7pV0ifV4nNX6OsPNYDnrY0j+xJJL0XEyxHxlqSvS1reQh9DLyKek3T0nMXLJW2obm/Q+B/LwHXobShExKGI2FbdPi7p7DTjrT53hb4Goo2wL5T06oT7+zVc872HpGdtv2B7VdvNTOLSiDgkjf/xSJrfcj/n6jqN9yCdM8340Dx3vUx/XlcbYZ9sKqlhGv9bGhHXSfqopE9Xp6uYmilN4z0ok0wzPhR6nf68rjbCvl/Sogn33yvpYAt9TCoiDlbXRyQ9peGbivrw2Rl0q+sjLffz/4ZpGu/JphnXEDx3bU5/3kbYn5d0he33275Y0u2SNrbQx9vYnl29cSLbsyXdrOGbinqjpJXV7ZWSnm6xl58xLNN4d5pmXC0/d61Pfx4RA79IukXj78j/l6S/bqOHDn39iqR/ry4vtt2bpCc0flp3UuNnRHdJ+gVJmyXtra5Hhqi3r0naKWmHxoO1oKXefkvjLw13SNpeXW5p+7kr9DWQ542PywJJ8Ak6IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUji/wB2W+AteOC0yQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"## Print a sample prediction\n",
"\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import random\n",
"import numpy as np\n",
"from pprint import pprint\n",
"\n",
"np.set_printoptions(formatter={'float': '{: 0.3f}'.format})\n",
"\n",
"index = random.randrange(len(test_images))  # randint(0, n) could return n, which is out of range\n",
"\n",
"print (\"random index = \", index)\n",
"print (\"test_label[{}] = {}. So the number is {}\".format(index, test_labels[index], test_labels[index]))\n",
"print (\"prediction of test_image[{}] = {}\".format(index, predictions[index]))\n",
"print ('max softmax output = ', np.amax(predictions[index]))\n",
"print ('index of max softmax output = {}. So the prediction is same ({})'.format(np.argmax(predictions[index]), np.argmax(predictions[index])))\n",
"\n",
"plt.imshow(test_images2[index])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 9: Evaluate the Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 9.1 - Metrics"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model metrics : ['loss', 'accuracy']\n",
"Metric : loss = 0.043\n",
"Metric : accuracy = 0.987\n"
]
}
],
"source": [
"metric_names = model.metrics_names\n",
"print (\"model metrics : \" , metric_names)\n",
"\n",
"metrics = model.evaluate(test_images, test_labels, verbose=0)\n",
"\n",
"for name, value in zip(metric_names, metrics):\n",
"    print (\"Metric : {} = {:,.3f}\".format(name, value))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 9.2 - Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"predictions shape : (10000, 10)\n",
"prediction 0 : [ 0.000 0.000 0.000 0.000 0.000 0.000 0.000 1.000 0.000 0.000]\n",
"prediction 1 : [ 0.000 0.000 1.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000]\n"
]
}
],
"source": [
"## predictions is a 2D array: one row of 10 softmax probabilities per test image\n",
"print('predictions shape : ', predictions.shape)\n",
"print ('prediction 0 : ' , predictions[0])\n",
"print ('prediction 1 : ' , predictions[1])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"prediction2 0 : 7\n",
"prediction2 1 : 2\n"
]
}
],
"source": [
"## The predicted class is the index of the largest softmax probability (argmax) in each row\n",
"predictions2 = [ np.argmax(p) for p in predictions]\n",
"print ('prediction2 0 : ' , predictions2[0])\n",
"print ('prediction2 1 : ' , predictions2[1])"
]
},
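{
"cell_type": "markdown",
"metadata": {},
"source": [
"The list comprehension above calls `np.argmax` once per row. Below is a minimal sketch of the vectorized equivalent, `np.argmax(predictions, axis=1)`, shown on a small made-up array (3 images, 4 classes instead of 10 to keep it short)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# made-up softmax outputs: one row per image, one column per class\n",
"probs = np.array([[0.1, 0.7, 0.1, 0.1],\n",
"                  [0.6, 0.2, 0.1, 0.1],\n",
"                  [0.0, 0.1, 0.1, 0.8]])\n",
"\n",
"per_row = [np.argmax(p) for p in probs]  # the approach used above\n",
"vectorized = np.argmax(probs, axis=1)    # one call over the whole array\n",
"\n",
"print(per_row, vectorized)"
]
},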
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 971, 0, 3, 0, 1, 0, 3, 1, 1, 0],\n",
" [ 0, 1112, 0, 3, 0, 0, 15, 1, 4, 0],\n",
" [ 0, 1, 1014, 14, 0, 0, 1, 1, 1, 0],\n",
" [ 0, 0, 1, 1006, 0, 2, 0, 0, 1, 0],\n",
" [ 0, 0, 0, 0, 972, 0, 2, 0, 4, 4],\n",
" [ 0, 0, 0, 6, 0, 882, 2, 1, 0, 1],\n",
" [ 2, 1, 0, 0, 1, 2, 951, 0, 1, 0],\n",
" [ 0, 2, 11, 4, 0, 0, 0, 1008, 0, 3],\n",
" [ 0, 0, 0, 6, 0, 1, 0, 0, 967, 0],\n",
" [ 1, 0, 0, 3, 2, 6, 0, 0, 7, 990]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"import seaborn as sns\n",
"\n",
"cm = confusion_matrix(test_labels, predictions2, labels = [0,1,2,3,4,5,6,7,8,9])\n",
"cm"
]
},
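{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `sklearn.metrics.confusion_matrix`, rows are true labels and columns are predicted labels. The diagonal therefore holds the correct predictions, row sums give the number of test images per digit, and column sums give how often each digit was predicted. Below is a minimal NumPy sketch using the matrix printed above (copied into `cm_printed`), which reproduces the overall accuracy and the per-class recall and precision reported in 9.3."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# confusion matrix copied from the output above (rows: true label, columns: predicted label)\n",
"cm_printed = np.array([\n",
"    [971,    0,    3,    0,   1,   0,   3,    1,   1,   0],\n",
"    [  0, 1112,    0,    3,   0,   0,  15,    1,   4,   0],\n",
"    [  0,    1, 1014,   14,   0,   0,   1,    1,   1,   0],\n",
"    [  0,    0,    1, 1006,   0,   2,   0,    0,   1,   0],\n",
"    [  0,    0,    0,    0, 972,   0,   2,    0,   4,   4],\n",
"    [  0,    0,    0,    6,   0, 882,   2,    1,   0,   1],\n",
"    [  2,    1,    0,    0,   1,   2, 951,    0,   1,   0],\n",
"    [  0,    2,   11,    4,   0,   0,   0, 1008,   0,   3],\n",
"    [  0,    0,    0,    6,   0,   1,   0,    0, 967,   0],\n",
"    [  1,    0,    0,    3,   2,   6,   0,    0,   7, 990]])\n",
"\n",
"correct = np.diag(cm_printed)\n",
"accuracy = correct.sum() / cm_printed.sum()\n",
"recall = correct / cm_printed.sum(axis=1)     # fraction of each true digit that was found\n",
"precision = correct / cm_printed.sum(axis=0)  # fraction of each predicted digit that was right\n",
"\n",
"print('accuracy :', accuracy)\n",
"print('recall   :', np.round(recall, 4))\n",
"print('precision:', np.round(precision, 4))"
]
},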
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x432 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"plt.figure(figsize = (8,6))\n",
"\n",
"# colormaps : cmap=\"YlGnBu\" , cmap=\"Greens\", cmap=\"Blues\", cmap=\"Reds\"\n",
"ax = sns.heatmap(cm, annot=True, cmap=\"Reds\", fmt='d')\n",
"ax.set_xlabel('Predicted label')\n",
"ax.set_ylabel('True label')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 9.3 - Metrics Calculated from the Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'0': {'f1-score': 0.9938587512794269,\n",
" 'precision': 0.9969199178644764,\n",
" 'recall': 0.9908163265306122,\n",
" 'support': 980},\n",
" '1': {'f1-score': 0.988005330964016,\n",
" 'precision': 0.996415770609319,\n",
" 'recall': 0.9797356828193833,\n",
" 'support': 1135},\n",
" '2': {'f1-score': 0.9839883551673945,\n",
" 'precision': 0.9854227405247813,\n",
" 'recall': 0.9825581395348837,\n",
" 'support': 1032},\n",
" '3': {'f1-score': 0.9805068226120859,\n",
" 'precision': 0.9654510556621881,\n",
" 'recall': 0.996039603960396,\n",
" 'support': 1010},\n",
" '4': {'f1-score': 0.992849846782431,\n",
" 'precision': 0.9959016393442623,\n",
" 'recall': 0.9898167006109979,\n",
" 'support': 982},\n",
" '5': {'f1-score': 0.988235294117647,\n",
" 'precision': 0.9876819708846585,\n",
" 'recall': 0.9887892376681614,\n",
" 'support': 892},\n",
" '6': {'f1-score': 0.9844720496894409,\n",
" 'precision': 0.9763860369609856,\n",
" 'recall': 0.9926931106471816,\n",
" 'support': 958},\n",
" '7': {'f1-score': 0.988235294117647,\n",
" 'precision': 0.9960474308300395,\n",
" 'recall': 0.980544747081712,\n",
" 'support': 1028},\n",
" '8': {'f1-score': 0.9867346938775511,\n",
" 'precision': 0.9807302231237323,\n",
" 'recall': 0.9928131416837782,\n",
" 'support': 974},\n",
" '9': {'f1-score': 0.9865470852017938,\n",
" 'precision': 0.9919839679358717,\n",
" 'recall': 0.981169474727453,\n",
" 'support': 1009},\n",
" 'accuracy': 0.9873,\n",
" 'macro avg': {'f1-score': 0.9873433523809434,\n",
" 'precision': 0.9872940753740315,\n",
" 'recall': 0.987497616526456,\n",
" 'support': 10000},\n",
" 'weighted avg': {'f1-score': 0.9873175638923014,\n",
" 'precision': 0.9874420624726045,\n",
" 'recall': 0.9873,\n",
" 'support': 10000}}\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"from pprint import pprint\n",
"\n",
"pprint(classification_report(test_labels, predictions2, output_dict=True))"
]
},
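{
"cell_type": "markdown",
"metadata": {},
"source": [
"The F1 score is the harmonic mean of precision and recall, `F1 = 2PR / (P + R)`, which for a single class simplifies to `2TP / (2TP + FP + FN)`. The `macro avg` row is the plain mean over the 10 classes, while `weighted avg` weights each class by its `support`. Below is a minimal sketch checking digit 0 against the report above, with TP, FP and FN read off the confusion matrix in 9.2."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# digit 0, counts read off the confusion matrix in 9.2\n",
"tp = 971        # diagonal entry\n",
"fn = 980 - tp   # rest of row 0: true 0s predicted as another digit\n",
"fp = 974 - tp   # rest of column 0: other digits predicted as 0\n",
"\n",
"precision = tp / (tp + fp)\n",
"recall = tp / (tp + fn)\n",
"f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"print(precision, recall, f1)\n",
"print(abs(f1 - 2 * tp / (2 * tp + fp + fn)) < 1e-12)"
]
},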
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 10: Save the Model\n",
"\n",
"Save the trained model so it can be loaded later. The call below is commented out; uncomment it to write the model to disk."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"#model.save(\"mnist_model_28x28\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "TensorFlow-GPU",
"language": "python",
"name": "tf2-gpu"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}