@ia35
Created May 26, 2020 11:52
data_augmentation.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "data_augmentation.ipynb",
"provenance": [],
"private_outputs": true,
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ia35/c1059a6210796692a612d7cdb4825656/data_augmentation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W2nZiEKlG2Z6",
"colab_type": "text"
},
"source": [
"[![](http://bec552ebfe.url-de-test.ws/ml/buttonBackProp.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fMFnGROsG4ho",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"The BackProp logo marks each place where code has been added or significantly modified, or where a comment needs to be flagged. \n",
"\n",
"The rest of the text is either the original tutorial text or an excerpt from a website that provides explanations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xKKO5i7VG_zY",
"colab_type": "text"
},
"source": [
"## <font color=\"teal\">References</font>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CqlwSbgOHDo4",
"colab_type": "text"
},
"source": [
"- [Data augmentation](https://www.tensorflow.org/tutorials/images/data_augmentation)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "llx_cGpoAyAq"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"colab_type": "code",
"id": "5MAYU_6KA0Kt",
"colab": {}
},
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "35lZ8kr3UcsB"
},
"source": [
"# Data augmentation"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MfBg1C5NB3X0"
},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/data_augmentation\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/data_augmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/data_augmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/data_augmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "BZP72A6eFw74"
},
"source": [
"## Overview\n",
"\n",
"This tutorial demonstrates manual image manipulations and augmentation using `tf.image`.\n",
"\n",
"Data augmentation is a common technique to improve results and avoid overfitting; see [Overfitting and Underfitting](../keras/overfit_and_underfit.ipynb) for other techniques."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8sZIVqk7HvnC"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "3TS4vBJBd8jY",
"colab": {}
},
"source": [
"!pip install git+https://github.com/tensorflow/docs"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "D6RJSN7NHj3k",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VWwMUiLoHf_R",
"colab_type": "text"
},
"source": [
"The [tf.data API](https://www.tensorflow.org/guide/data_performance) provides the tf.data.Dataset.prefetch transformation. \n",
"\n",
"It can be used to decouple the time when data is produced from the time when data is consumed. \n",
"\n",
"In particular, the transformation uses a background thread and an internal buffer to prefetch elements from the input dataset ahead of the time they are requested. \n",
"\n",
"The number of elements to prefetch should be equal to (or possibly greater than) the number of batches consumed by a single training step. \n",
"\n",
"You could either manually tune this value, or set it to tf.data.experimental.AUTOTUNE which will prompt the tf.data runtime to tune the value dynamically at runtime."
]
},
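{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The decoupling of producer and consumer described above can be pictured with a small pure-Python sketch. It only illustrates the idea (a background thread filling a bounded buffer) and is not how `tf.data` is implemented; the `prefetch` helper and its `buffer_size` argument are hypothetical names chosen for this example.\n",
"\n",
"```python\n",
"import queue\n",
"import threading\n",
"\n",
"\n",
"def prefetch(iterable, buffer_size):\n",
"    # Hypothetical helper: yields the items of `iterable` while a background\n",
"    # thread prepares up to `buffer_size` of them ahead of the consumer.\n",
"    buffer = queue.Queue(maxsize=buffer_size)\n",
"    done = object()  # sentinel marking the end of the input\n",
"\n",
"    def producer():\n",
"        for item in iterable:\n",
"            buffer.put(item)  # blocks while the buffer is full\n",
"        buffer.put(done)\n",
"\n",
"    threading.Thread(target=producer, daemon=True).start()\n",
"    while True:\n",
"        item = buffer.get()  # blocks until the producer has an item ready\n",
"        if item is done:\n",
"            return\n",
"        yield item\n",
"\n",
"\n",
"# The consumer sees the items in their original order.\n",
"batches = list(prefetch(range(5), buffer_size=2))\n",
"```\n",
"\n",
"In the pipelines later in this notebook, `.prefetch(AUTOTUNE)` plays the same role, with the buffer size chosen by the `tf.data` runtime instead of being fixed by hand."
]
},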
{
"cell_type": "code",
"metadata": {
"id": "rj4UISoeHwnO",
"colab_type": "code",
"colab": {}
},
"source": [
"import urllib\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.datasets import mnist\n",
"from tensorflow.keras import layers\n",
"AUTOTUNE = tf.data.experimental.AUTOTUNE"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "rdP8EQbPsyRA",
"colab": {}
},
"source": [
"import tensorflow_docs as tfdocs\n",
"import tensorflow_docs.plots\n",
"\n",
"import tensorflow_datasets as tfds\n",
"\n",
"import PIL.Image\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"mpl.rcParams['figure.figsize'] = (12, 5)\n",
"\n",
"import numpy as np"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "HS4gMHvPIOxb",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yb4CRNIPKxam",
"colab_type": "text"
},
"source": [
"Instead of displaying the image directly, you can also log it to TensorBoard."
]
},
{
"cell_type": "code",
"metadata": {
"id": "okTglBr9INWb",
"colab_type": "code",
"colab": {}
},
"source": [
"%load_ext tensorboard"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "__EuXwM-8uth"
},
"source": [
"Let's check the data augmentation features on an image and then augment a whole dataset later to train a model."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "frBSdODBLOOI"
},
"source": [
"Download [this image](https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg), by Von.grzanka, for augmentation."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "s5ThIwG8KqzI",
"colab": {}
},
"source": [
"image_path = tf.keras.utils.get_file(\"cat.jpg\", \"https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg\")\n",
"PIL.Image.open(image_path)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "JtMI80ahMr0E",
"colab_type": "text"
},
"source": [
"The [tf.image](https://www.tensorflow.org/api_docs/python/tf/image) module provides many functions for reading images, modifying them (size, format, color, cropping, bounding boxes, flipping, rotation, transposition), and encoding or decoding them."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-Ec3bGonGDCF"
},
"source": [
"Read and decode the image to tensor format."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "cdCoB8b-uZjf",
"colab": {}
},
"source": [
"image_string = tf.io.read_file(image_path)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9sQaimBwOLss",
"colab_type": "code",
"colab": {}
},
"source": [
"image = tf.image.decode_jpeg(image_string, channels=3)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "XxR0P5D6MQZi",
"colab_type": "text"
},
"source": [
"### <font color=\"orange\">TensorBoard</font>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oowsToynKJec",
"colab_type": "text"
},
"source": [
"The image must be rank 4 to be displayed in TensorBoard: (batch_size, height, width, channels).\n",
"\n",
"TensorFlow offers a simple way to increase the rank: [tf.expand_dims](https://www.tensorflow.org/api_docs/python/tf/expand_dims)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "CwUqoGEWIngm",
"colab_type": "code",
"colab": {}
},
"source": [
"image.shape"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PJ0fjlrzKoSG",
"colab_type": "text"
},
"source": [
"The second argument of tf.expand_dims is the index of the axis at which the new dimension is inserted."
]
},
{
"cell_type": "code",
"metadata": {
"id": "sd3BgiAZJMFj",
"colab_type": "code",
"colab": {}
},
"source": [
"img = tf.expand_dims(image, 0)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qvZVMWb7JQPG",
"colab_type": "code",
"colab": {}
},
"source": [
"img.shape"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LcSTd3O3J5KN",
"colab_type": "code",
"colab": {}
},
"source": [
"from datetime import datetime"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "UIXdV0RJJue1",
"colab_type": "code",
"colab": {}
},
"source": [
"# Clear out any prior log data.\n",
"!rm -rf logs\n",
"\n",
"# Sets up a timestamped log directory.\n",
"logdir = \"logs/train_data/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"# Creates a file writer for the log directory.\n",
"file_writer = tf.summary.create_file_writer(logdir)\n",
"\n",
"# Using the file writer, log the reshaped image.\n",
"with file_writer.as_default():\n",
"  tf.summary.image(\"Training data\", img, step=0)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2eIN2d4EKACF",
"colab_type": "code",
"colab": {}
},
"source": [
"%tensorboard --logdir logs/train_data"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "isGwyT0386yi"
},
"source": [
"A function to visualize and compare the original and augmented image side by side."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "FKnRfw2dvyql",
"colab": {}
},
"source": [
"def visualize(original, augmented):\n",
"  fig = plt.figure()\n",
"  plt.subplot(1, 2, 1)\n",
"  plt.title('Original image')\n",
"  plt.imshow(original)\n",
"\n",
"  plt.subplot(1, 2, 2)\n",
"  plt.title('Augmented image')\n",
"  plt.imshow(augmented)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "jYLzpEOhGqWY"
},
"source": [
"## Augment a single image"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8IiXghY99Bo6"
},
"source": [
"### Flipping the image\n",
"Flip the image either vertically or horizontally."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "X14VjLlFxnvZ",
"colab": {}
},
"source": [
"flipped = tf.image.flip_left_right(image)\n",
"visualize(image, flipped)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "eJb-oQ1BTo3i",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AlRMXNdPTp4W",
"colab_type": "text"
},
"source": [
"Display in TensorBoard (see the TensorBoard panel above)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "GWFxEYQkQsvQ",
"colab_type": "code",
"colab": {}
},
"source": [
"with file_writer.as_default():\n",
"  tf.summary.image(\"Original image\", tf.expand_dims(image, 0), step=0)\n",
"  tf.summary.image(\"Augmented image\", tf.expand_dims(flipped, 0), step=0)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ObsvSmu99MfC"
},
"source": [
"### Grayscale the image\n",
"Grayscale an image."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "mnqQA2ubyo6O",
"colab": {}
},
"source": [
"grayscaled = tf.image.rgb_to_grayscale(image)\n",
"visualize(image, tf.squeeze(grayscaled))\n",
"plt.colorbar()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "juI4A4HF9gYc"
},
"source": [
"### Saturate the image\n",
"Saturate an image by providing a saturation factor."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "tiTUhw-gzCJW",
"colab": {}
},
"source": [
"saturated = tf.image.adjust_saturation(image, 3)\n",
"visualize(image, saturated)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "E82CqomP9qcR"
},
"source": [
"### Change image brightness\n",
"Change the brightness of the image by providing a brightness delta, which is added to every pixel value."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "05dA6uEtzfyd",
"colab": {}
},
"source": [
"bright = tf.image.adjust_brightness(image, 0.4)\n",
"visualize(image, bright)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "5_0kMbmS91x6"
},
"source": [
"### Rotate the image\n",
"Rotate an image counter-clockwise by 90 degrees."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "edNoQzhszxo8",
"colab": {}
},
"source": [
"rotated = tf.image.rot90(image)\n",
"visualize(image, rotated)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "bomBnFWp9895"
},
"source": [
"### Center crop the image\n",
"Crop the image from the center, keeping the fraction of the image you specify."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "fvgz_6t21dq2",
"colab": {}
},
"source": [
"cropped = tf.image.central_crop(image, central_fraction=0.5)\n",
"visualize(image, cropped)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8W5E_c7o-H96"
},
"source": [
"See the `tf.image` reference for details about available augmentation options."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "92lBGZSQ-1Tx"
},
"source": [
"## Augment a dataset and train a model with it"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "lrDez4xIX9Ss"
},
"source": [
"Train a model on an augmented dataset.\n",
"\n",
"Note: The problem solved here is somewhat artificial. It trains a densely connected network to be shift invariant by jittering the input images. It's much more efficient to use convolutional layers instead."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "mazlEonS_gTR",
"colab": {}
},
"source": [
"dataset, info = tfds.load('mnist', as_supervised=True, with_info=True)\n",
"train_dataset, test_dataset = dataset['train'], dataset['test']\n",
"\n",
"num_train_examples = info.splits['train'].num_examples"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0R9Bqpv-URBk",
"colab_type": "code",
"colab": {}
},
"source": [
"info"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "y9czinuhUOx5",
"colab_type": "code",
"colab": {}
},
"source": [
"num_train_examples"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "011caOa0YCz5"
},
"source": [
"Write a function to augment the images. Map it over the dataset. This returns a dataset that augments the data on the fly."
]
},
{
"cell_type": "code",
"metadata": {
"id": "1rqw67sfVxnL",
"colab_type": "code",
"colab": {}
},
"source": [
"def convert(image, label):\n",
"  image = tf.image.convert_image_dtype(image, tf.float32) # Cast and normalize the image to [0,1]\n",
"  return image, label"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "aNSc8slVWQ9x",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I7Rz-16TV13s",
"colab_type": "text"
},
"source": [
"There appears to be an error in the TensorFlow code.\n",
"The `convert` function casts and normalizes the image.\n",
"`augment` calls this function and then repeats the same conversion on the next line.\n",
"\n",
"For that reason, we have commented out the call to `convert` in `augment`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mXqRnCtIWgHv",
"colab_type": "text"
},
"source": [
"Here the augmentation consists of 6 pixels of padding, followed by a random crop and a random brightness change."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "3oaSV5QcDS8p",
"colab": {}
},
"source": [
"def augment(image, label):\n",
"  #image,label = convert(image, label)\n",
"  image = tf.image.convert_image_dtype(image, tf.float32) # Cast and normalize the image to [0,1]\n",
"  image = tf.image.resize_with_crop_or_pad(image, 34, 34) # Add 6 pixels of padding\n",
"  image = tf.image.random_crop(image, size=[28, 28, 1]) # Random crop back to 28x28\n",
"  image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness\n",
"\n",
"  return image, label"
],
"execution_count": 0,
"outputs": []
},
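{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The pad-then-crop step in `augment` amounts to a small random translation of the digit. The sketch below works through that arithmetic in plain Python. It assumes the 6 pixels of padding are split evenly between the two sides and that the crop offset is drawn uniformly; `random_shift` is a hypothetical helper for illustration, not a TensorFlow function.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"SIZE, PADDED = 28, 34         # image sizes used in augment()\n",
"pad = (PADDED - SIZE) // 2    # 3 pixels on each side, 6 in total per dimension\n",
"\n",
"\n",
"def random_shift(rng):\n",
"    # Top-left corner of the 28x28 crop window inside the 34x34 padded image.\n",
"    top = rng.randint(0, PADDED - SIZE)\n",
"    left = rng.randint(0, PADDED - SIZE)\n",
"    # Position of the crop window relative to the original image;\n",
"    # the digit appears shifted by the opposite amount.\n",
"    return top - pad, left - pad\n",
"\n",
"\n",
"rng = random.Random(0)\n",
"shifts = [random_shift(rng) for _ in range(1000)]\n",
"```\n",
"\n",
"Every shift stays within 3 pixels in each direction, so the jitter is small compared with the 28x28 image."
]
},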
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "xNROtM5uqSjg",
"colab": {}
},
"source": [
"BATCH_SIZE = 64\n",
"# Only use a subset of the data so it's easier to overfit, for this tutorial\n",
"NUM_EXAMPLES = 2048"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dEq0VnUIy-8l"
},
"source": [
"Create the augmented dataset."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "jWJpV6JOqN_7",
"colab": {}
},
"source": [
"augmented_train_batches = (\n",
"    train_dataset\n",
"    # Only train on a subset, so you can quickly see the effect.\n",
"    .take(NUM_EXAMPLES)\n",
"    .cache()\n",
"    .shuffle(num_train_examples//4)\n",
"    # The augmentation is added here.\n",
"    .map(augment, num_parallel_calls=AUTOTUNE)\n",
"    .batch(BATCH_SIZE)\n",
"    .prefetch(AUTOTUNE)\n",
")"
],
"execution_count": 0,
"outputs": []
},
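{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Because `.cache()` comes before `.map(augment, ...)` in the pipeline above, only the raw examples are cached and the random augmentation is re-applied on every pass over the data. The pure-Python sketch below mimics that behaviour under simplifying assumptions; `LazyMap` and `jitter` are hypothetical stand-ins and not part of `tf.data`.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"class LazyMap:\n",
"    # Hypothetical stand-in for a map placed after cache():\n",
"    # `fn` is re-applied every time the data is iterated.\n",
"    def __init__(self, fn, data):\n",
"        self.fn = fn\n",
"        self.data = data\n",
"\n",
"    def __iter__(self):\n",
"        return (self.fn(x) for x in self.data)\n",
"\n",
"\n",
"rng = random.Random(0)\n",
"\n",
"\n",
"def jitter(x):\n",
"    # Stand-in for a random augmentation such as random_brightness.\n",
"    return x + rng.uniform(-0.5, 0.5)\n",
"\n",
"\n",
"cached = [1.0, 2.0, 3.0]          # the cached, un-augmented examples\n",
"dataset = LazyMap(jitter, cached)\n",
"\n",
"epoch_1 = list(dataset)\n",
"epoch_2 = list(dataset)           # same examples, new random perturbations\n",
"```\n",
"\n",
"Each epoch therefore sees a different variant of every cached example, which is what augmenting on the fly means here."
]
},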
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PGm5X_77zE3q"
},
"source": [
"And a non-augmented one for comparison."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "td4yU920qOgU",
"colab": {}
},
"source": [
"non_augmented_train_batches = (\n",
"    train_dataset\n",
"    # Only train on a subset, so you can quickly see the effect.\n",
"    .take(NUM_EXAMPLES)\n",
"    .cache()\n",
"    .shuffle(num_train_examples//4)\n",
"    # No augmentation.\n",
"    .map(convert, num_parallel_calls=AUTOTUNE)\n",
"    .batch(BATCH_SIZE)\n",
"    .prefetch(AUTOTUNE)\n",
")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "zkLkgEudzKfc"
},
"source": [
"Set up the validation dataset. It is the same whether or not you use augmentation."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "6eqKD4zOqpvE",
"colab": {}
},
"source": [
"validation_batches = (\n",
"    test_dataset\n",
"    .map(convert, num_parallel_calls=AUTOTUNE)\n",
"    .batch(2*BATCH_SIZE)\n",
")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Yi9TIwR-ZIOi"
},
"source": [
"Create and compile the model. The model is a fully connected neural network with two hidden layers and no convolutions."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "COgj-s1fXFiX",
"colab_type": "text"
},
"source": [
"The model does not use convolutions (for teaching purposes)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "t_nk9ZzrXu8L",
"colab_type": "code",
"colab": {}
},
"source": [
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "hHhkA4Q0CsHx",
"colab": {}
},
"source": [
"def make_model():\n",
"  model = tf.keras.Sequential([\n",
"    layers.Flatten(input_shape=(28, 28, 1)),\n",
"    layers.Dense(4096, activation='relu'),\n",
"    layers.Dense(4096, activation='relu'),\n",
"    layers.Dense(10)\n",
"  ])\n",
"  model.compile(optimizer='adam',\n",
"                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"                metrics=['accuracy'])\n",
"  return model"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "P0rciou3ZWwy"
},
"source": [
"Train the model, **without** augmentation:"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "z8X8CpqvNhG9",
"colab": {}
},
"source": [
"model_without_aug = make_model()\n",
"no_aug_history = model_without_aug.fit(non_augmented_train_batches, epochs=50, validation_data=validation_batches, callbacks=[tensorboard_callback])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "yH3LUd1prFWe"
},
"source": [
"Train it again with augmentation:"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "z4RqS9FWrEiS",
"colab": {}
},
"source": [
"model_with_aug = make_model()\n",
"aug_history = model_with_aug.fit(augmented_train_batches, epochs=50, validation_data=validation_batches, callbacks=[tensorboard_callback])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "UEqeeNsHZaC5"
},
"source": [
"## Conclusion:\n",
"\n",
"In this example, the augmented model converges to an accuracy of about 95% on the validation set. This is slightly higher (about 1 percentage point) than the model trained without data augmentation."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "ZnkliKKm0VVw",
"colab": {}
},
"source": [
"plotter = tfdocs.plots.HistoryPlotter()\n",
"plotter.plot({\"Augmented\": aug_history, \"Non-Augmented\": no_aug_history}, metric = \"accuracy\")\n",
"plt.title(\"Accuracy\")\n",
"plt.ylim([0.75,1])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cAaUjNhI3j7k"
},
"source": [
"In terms of loss, the non-augmented model is obviously in the overfitting regime. The augmented model, while a few epochs slower, is still training correctly and clearly not overfitting."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "1-dTZpc4zq0-",
"colab": {}
},
"source": [
"plotter = tfdocs.plots.HistoryPlotter()\n",
"plotter.plot({\"Augmented\": aug_history, \"Non-Augmented\": no_aug_history}, metric = \"loss\")\n",
"plt.title(\"Loss\")\n",
"plt.ylim([0,1])"
],
"execution_count": 0,
"outputs": []
},
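{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The two plots can also be reduced to a few numbers. The sketch below assumes a dict shaped like the `history` attribute of a Keras `History` object, with the keys `loss`, `val_loss` and `val_accuracy` that `metrics=['accuracy']` produces; `summarize` is a hypothetical helper, and the `toy` values are made-up placeholders, not results from this notebook.\n",
"\n",
"```python\n",
"def summarize(history):\n",
"    # `history` maps a metric name to a list with one value per epoch.\n",
"    val_loss = history['val_loss']\n",
"    return {\n",
"        'final_val_accuracy': history['val_accuracy'][-1],\n",
"        # A large positive gap (validation loss well above training loss)\n",
"        # is a common sign of overfitting.\n",
"        'loss_gap': val_loss[-1] - history['loss'][-1],\n",
"        # Index of the epoch with the lowest validation loss.\n",
"        'best_val_loss_epoch': min(range(len(val_loss)), key=val_loss.__getitem__),\n",
"    }\n",
"\n",
"\n",
"# Made-up placeholder numbers, NOT results from this notebook.\n",
"toy = {\n",
"    'loss': [0.9, 0.4, 0.1],\n",
"    'val_loss': [0.8, 0.5, 0.6],\n",
"    'accuracy': [0.70, 0.88, 0.97],\n",
"    'val_accuracy': [0.72, 0.85, 0.84],\n",
"}\n",
"summary = summarize(toy)\n",
"```\n",
"\n",
"With the real training runs, the calls would be `summarize(aug_history.history)` and `summarize(no_aug_history.history)`."
]
},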
{
"cell_type": "code",
"metadata": {
"id": "F2qyOoWTFztO",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}