{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "TfSegmentationImage.ipynb",
"provenance": [],
"private_outputs": true,
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ia35/50bd377b15954fb0df7277cbe3a022cb/tfsegmentationimage.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cZCM65CBt1CJ"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "loUSYmbcg6XF",
"colab_type": "text"
},
"source": [
"[![](http://bec552ebfe.url-de-test.ws/ml/buttonBackProp.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D7SUWShgg-cW",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"Le logo BackProp est présenté chaque fois qu'une modification importante est apportée au code ou à chaque fois qu'un commentaire doit être signalé. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8NYKv5z1ojWT",
"colab_type": "text"
},
"source": [
"## <font color=\"teal\">Inspiration</font>\n",
"\n",
"- Code d'origine est [ici](https://www.tensorflow.org/tutorials/images/segmentation)\n",
"\n",
"- TensorFlow Dataset [BackProp](https://colab.research.google.com/drive/1T8ij4QID73-SdYw3oSMvQbAW-lHeu1HG)"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "both",
"colab_type": "code",
"id": "JOgMcEajtkmg",
"colab": {}
},
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rCSP-dbMw88x"
},
"source": [
"# <font color=\"teal\">Image segmentation</font>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "NEWs8JXRuGex"
},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/images/segmentation\">\n",
" <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
" View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/segmentation.ipynb\">\n",
" <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
" Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/segmentation.ipynb\">\n",
" <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
" View source on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/segmentation.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "sMP7mglMuGT2"
},
"source": [
"This tutorial focuses on the task of image segmentation, using a modified U-Net. Voir [U-net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/) ici.\n",
"\n",
"## What is image segmentation?\n",
"So far you have seen image classification, where the task of the network is to assign a label or class to an input image. However, suppose you want to know where an object is located in the image, the shape of that object, which pixel belongs to which object, etc. In this case you will want to segment the image, i.e., each pixel of the image is given a label. Thus, the task of image segmentation is to train a neural network to output a pixel-wise mask of the image. This helps in understanding the image at a much lower level, i.e., the pixel level. \n",
"\n",
"Image segmentation has many applications in medical imaging, self-driving cars and satellite imaging to name a few.\n",
"\n",
"The dataset that will be used for this tutorial is the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/), created by Parkhi *et al*. The dataset consists of images, their corresponding labels, and pixel-wise masks. The masks are basically labels for each pixel. Each pixel is given one of three categories :\n",
"\n",
"* Class 1 : Pixel belonging to the pet.\n",
"* Class 2 : Pixel bordering the pet.\n",
"* Class 3 : None of the above/ Surrounding pixel."
]
},
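{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"To make the trimap concrete, here is a minimal sketch (not part of the original tutorial) that builds a toy 4x4 mask with the raw label values {1, 2, 3} and counts the pixels per class; only numpy is assumed."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Toy 4x4 trimap using the raw dataset labels: 1 = pet, 2 = border, 3 = background.\n",
"toy_mask = np.array([[3, 3, 3, 3],\n",
"                     [3, 2, 2, 3],\n",
"                     [2, 1, 1, 2],\n",
"                     [2, 1, 1, 2]])\n",
"\n",
"names = {1: 'pet', 2: 'border', 3: 'background'}\n",
"for value, count in zip(*np.unique(toy_mask, return_counts=True)):\n",
"    print('class {} ({}): {} pixels'.format(value, names[value], count))"
],
"execution_count": 0,
"outputs": []
},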
{
"cell_type": "markdown",
"metadata": {
"id": "Bh3oV0V2pg_s",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"[pip](https://pip.pypa.io/en/stable/reference/pip_install/) currently supports cloning over git, git+http, and git+https."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "MQmKthrSBCld",
"colab": {}
},
"source": [
"!pip install git+https://github.com/tensorflow/examples.git"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "4A2wSTlbhpEq",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "YQX7R4bhZy5h",
"colab": {}
},
"source": [
"%tensorflow_version 2.x \n",
"import tensorflow as tf\n",
"print(tf.__version__)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "AapLACqMvy_9",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"## <font color=\"orange\">Pix2Pix</font>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nm8V4fjOvnI1",
"colab_type": "text"
},
"source": [
"This [notebook](https://www.tensorflow.org/tutorials/generative/pix2pix) demonstrates image to image translation using conditional GAN's, as described in Image-to-Image Translation with Conditional Adversarial Networks. Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings."
]
},
{
"cell_type": "code",
"metadata": {
"id": "f4WNHn92vddo",
"colab_type": "code",
"colab": {}
},
"source": [
"from __future__ import absolute_import, division, print_function, unicode_literals\n",
"from tensorflow_examples.models.pix2pix import pix2pix"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "eeqpl_h-v6Tf",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"## <font color=\"orange\">tfds</font>\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "g87--n2AtyO_",
"colab": {}
},
"source": [
"import tensorflow_datasets as tfds"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "REyIWDo_u6Vt",
"colab_type": "text"
},
"source": [
"tfds (TensorFlow Dataset) est une liste de jeux de données que TensorFlow Tf met à disposition.\n",
"On y trouve tous les standards comme MNIST, fashion_mnist, iris, titanic... et d'autres jeux de données moins connus.\n",
"C'est très pratique pour se former au Deep Learning (DL) et à Tf.\n",
"La doc Tf sur Dataset est ici"
]
},
{
"cell_type": "code",
"metadata": {
"id": "BPYtKWoivV5x",
"colab_type": "code",
"colab": {}
},
"source": [
"tfds.list_builders()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FlQGkjEtvAF_",
"colab_type": "code",
"colab": {}
},
"source": [
"tfds.disable_progress_bar()\n",
"\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "oWe0_rQM4JbC"
},
"source": [
"## Download the Oxford-IIIT Pets dataset\n",
"\n",
"The dataset is already included in TensorFlow datasets, all that is needed to do is download it. The segmentation masks are included in version 3+."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "40ITeStwDwZb",
"colab": {}
},
"source": [
"dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "cyqKZ3U45UKv",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"The following annotations are available for every image in the dataset: (a) species and breed name; (b) a tight bounding box (ROI) around the head of the animal; and (c) a pixel level foreground-background segmentation (Trimap).\n",
"\n",
"[![](http://www.robots.ox.ac.uk/~vgg/data/pets/pet_annotations.jpg)](http://www.robots.ox.ac.uk/~vgg/data/pets/)"
]
},
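{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"As a quick sanity check (a sketch added here, not in the original tutorial), we can inspect the raw values present in one segmentation mask; the trimap should only contain the labels 1, 2 and 3. This assumes the dataset dict loaded above."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"# Peek at one raw training example and the values in its trimap.\n",
"for example in dataset['train'].take(1):\n",
"    mask = example['segmentation_mask'].numpy()\n",
"    print(mask.shape)       # (height, width, 1)\n",
"    print(np.unique(mask))  # typically [1 2 3]"
],
"execution_count": 0,
"outputs": []
},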
{
"cell_type": "markdown",
"metadata": {
"id": "DOJrNlq7whri",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "eH8ticC6wZc4",
"colab_type": "code",
"colab": {}
},
"source": [
"type(dataset)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "j3vL6zVNweYc",
"colab_type": "code",
"colab": {}
},
"source": [
"dataset.keys()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "4gL8oJhoqzaC",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "6VNyoKWDqrGK",
"colab_type": "code",
"colab": {}
},
"source": [
"info"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Njvk5D1kwvAF",
"colab_type": "code",
"colab": {}
},
"source": [
"info.features"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tW4op7C2wzpB",
"colab_type": "code",
"colab": {}
},
"source": [
"info.features[\"label\"].num_classes"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rJcVdj_U4vzf"
},
"source": [
"The following code performs a simple augmentation of flipping an image. In addition, image is normalized to [0,1]. \n",
"\n",
"Finally, as mentioned above the pixels in the segmentation mask are labeled either {1, 2, 3}. For the sake of convenience, let's subtract 1 from the segmentation mask, resulting in labels that are : {0, 1, 2}."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v4Ds4yrI33oQ",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"Un flip est un retournement. On peut faire un flip horizontal ou vertical. Le flip horizonal est un peu comme lire une feuille de l'autre côté."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "FD60EbcAQqov",
"colab": {}
},
"source": [
"def normalize(input_image, input_mask):\n",
" input_image = tf.cast(input_image, tf.float32) / 255.0\n",
" input_mask -= 1\n",
" return input_image, input_mask"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "89M3IOmJBw6V",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"tf.image.[flip_left_right](https://www.tensorflow.org/api_docs/python/tf/image/flip_left_right?version=stable)\n",
"\n",
"Flip an image horizontally (left to right)"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "2NPlCnBXQwb1",
"colab": {}
},
"source": [
"@tf.function\n",
"def load_image_train(datapoint):\n",
" input_image = tf.image.resize(datapoint['image'], (128, 128))\n",
" input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))\n",
"\n",
" if tf.random.uniform(()) > 0.5:\n",
" input_image = tf.image.flip_left_right(input_image)\n",
" input_mask = tf.image.flip_left_right(input_mask)\n",
"\n",
" input_image, input_mask = normalize(input_image, input_mask)\n",
"\n",
" return input_image, input_mask"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "Zf0S67hJRp3D",
"colab": {}
},
"source": [
"def load_image_test(datapoint):\n",
" input_image = tf.image.resize(datapoint['image'], (128, 128))\n",
" input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))\n",
"\n",
" input_image, input_mask = normalize(input_image, input_mask)\n",
"\n",
" return input_image, input_mask"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "65-qHTjX5VZh"
},
"source": [
"The dataset already contains the required splits of test and train and so let's continue to use the same split."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "yHwj2-8SaQli",
"colab": {}
},
"source": [
"TRAIN_LENGTH = info.splits['train'].num_examples\n",
"BATCH_SIZE = 64\n",
"BUFFER_SIZE = 1000\n",
"STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "XpoSxP3ZD5JH",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"On a déjà cette info avec info"
]
},
{
"cell_type": "code",
"metadata": {
"id": "IThCNLq6Dy4s",
"colab_type": "code",
"colab": {}
},
"source": [
"TRAIN_LENGTH"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "q4hgIo1oEj3P",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
".map applique à chaque élément du dataset une fonction\n",
"\n",
"La différence entre train et test est le flip dans le 1er cas"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BcB3msOfFUiQ",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"Module: [tf.data.experimental](https://www.tensorflow.org/api_docs/python/tf/data/experimental)\n",
"\n",
"Experimental API for building input pipelines.\n",
"\n",
"This module contains experimental Dataset sources and transformations that can be used in conjunction with the tf.data.Dataset API. Note that the tf.data.experimental API is not subject to the same backwards compatibility guarantees as tf.data, but we will provide deprecation advice in advance of removing existing functionality.\n",
"\n",
"[Choosing](https://www.tensorflow.org/guide/data_performance) the best value for the num_parallel_calls argument depends on your hardware, characteristics of your training data (such as its size and shape), the cost of your map function, and what other processing is happening on the CPU at the same time. A simple heuristic is to use the number of available CPU cores. However, as for the prefetch and interleave transformation, the map transformation supports tf.data.experimental.AUTOTUNE which will delegate the decision about what level of parallelism to use to the tf.data runtime."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "39fYScNz9lmo",
"colab": {}
},
"source": [
"train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"test = dataset['test'].map(load_image_test)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZaQ4gyVgHc0Q",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"[Caches](https://www.tensorflow.org/api_docs/python/tf/data/Dataset?version=stable) the elements in this dataset.\n",
"\n",
"The first time the dataset is iterated over, its elements will be cached either in the specified file or in memory. Subsequent iterations will use the cached data."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "DeFwFDN6EVoI",
"colab": {}
},
"source": [
"train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()\n",
"train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
"test_dataset = test.batch(BATCH_SIZE)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Xa3gMAE_9qNa"
},
"source": [
"Let's take a look at an image example and it's correponding mask from the dataset."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "3N2RPAAW9q4W",
"colab": {}
},
"source": [
"def display(display_list):\n",
" plt.figure(figsize=(15, 15))\n",
"\n",
" title = ['Input Image', 'True Mask', 'Predicted Mask']\n",
"\n",
" for i in range(len(display_list)):\n",
" plt.subplot(1, len(display_list), i+1)\n",
" plt.title(title[i])\n",
" plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))\n",
" plt.axis('off')\n",
" plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "xJXgg4kPIUhM",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"Ce code n'est pas optimum !"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "a6u_Rblkteqb",
"colab": {}
},
"source": [
"for image, mask in train.take(1):\n",
" sample_image, sample_mask = image, mask\n",
"display([sample_image, sample_mask])"
],
"execution_count": 0,
"outputs": []
},
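{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"Since the loop above only needs a single element, a more direct equivalent (a sketch with the same behavior) is to pull one example with next(iter(...)):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Equivalent, more direct way to grab one (image, mask) pair.\n",
"sample_image, sample_mask = next(iter(train))\n",
"display([sample_image, sample_mask])"
],
"execution_count": 0,
"outputs": []
},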
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "FAOe93FRMk3w"
},
"source": [
"## Define the model\n",
"The model being used here is a modified U-Net. A U-Net consists of an encoder (downsampler) and decoder (upsampler). In-order to learn robust features, and reduce the number of trainable parameters, a pretrained model can be used as the encoder. Thus, the encoder for this task will be a pretrained MobileNetV2 model, whose intermediate outputs will be used, and the decoder will be the upsample block already implemented in TensorFlow Examples in the [Pix2pix tutorial](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). \n",
"\n",
"The reason to output three channels is because there are three possible labels for each pixel. Think of this as multi-classification where each pixel is being classified into three classes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i_-wWJYSOZFs",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"The [u-net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/) is convolutional network architecture for fast and precise segmentation of images. \n",
"\n",
"Up to now it has outperformed the prior best method "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tan6N_lvPxVy",
"colab_type": "text"
},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"[U-Net](https://fr.wikipedia.org/wiki/U-Net) est un réseau de neurones à convolution développé pour la segmentation d'images biomédicales au département d'informatique de l'université de Fribourg en Allemagne1. Le réseau est basé sur le réseau entièrement convolutionnel2 et son architecture a été modifiée et étendue pour fonctionner avec moins d’images de training et pour permettre une segmentation plus précise. La segmentation d'une image 512 * 512 prend moins d'une seconde sur un GPU récent."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "c6iB4iMvMkX9",
"colab": {}
},
"source": [
"OUTPUT_CHANNELS = 3"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "W4mQle3lthit"
},
"source": [
"As mentioned, the encoder will be a pretrained MobileNetV2 model which is prepared and ready to use in [tf.keras.applications](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/applications). The encoder consists of specific outputs from intermediate layers in the model. Note that the encoder will not be trained during the training process."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "liCeLH0ctjq7",
"colab": {}
},
"source": [
"base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)\n",
"\n",
"# Use the activations of these layers\n",
"layer_names = [\n",
" 'block_1_expand_relu', # 64x64\n",
" 'block_3_expand_relu', # 32x32\n",
" 'block_6_expand_relu', # 16x16\n",
" 'block_13_expand_relu', # 8x8\n",
" 'block_16_project', # 4x4\n",
"]\n",
"layers = [base_model.get_layer(name).output for name in layer_names]\n",
"\n",
"# Create the feature extraction model\n",
"down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)\n",
"\n",
"down_stack.trainable = False"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "J9ptKhasXFMr",
"colab_type": "code",
"colab": {}
},
"source": [
"layers"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "uL623VPBXaww",
"colab_type": "code",
"colab": {}
},
"source": [
"base_model.input"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Cy1FHW0yJG3W",
"colab_type": "code",
"colab": {}
},
"source": [
"base_model.summary()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EnIM2jWwQrHE",
"colab_type": "code",
"colab": {}
},
"source": [
"down_stack.summary()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "KPw8Lzra5_T9"
},
"source": [
"The decoder/upsampler is simply a series of upsample blocks implemented in TensorFlow examples."
]
},
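{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![](https://raw.githubusercontent.com/BackProp-fr/meetup/master/images/LogoBackPropTranspSmall.png)](https://www.backprop.fr)\n",
"\n",
"For reference, each pix2pix.upsample(filters, size) block used below is roughly a Conv2DTranspose that doubles the spatial resolution, followed by batch normalization and a ReLU. The sketch below is a simplified approximation; the actual implementation lives in tensorflow_examples and may differ in details such as initializers and optional dropout."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def upsample_sketch(filters, size):\n",
"    # Roughly what pix2pix.upsample builds (simplified sketch).\n",
"    return tf.keras.Sequential([\n",
"        tf.keras.layers.Conv2DTranspose(filters, size, strides=2,\n",
"                                        padding='same', use_bias=False),\n",
"        tf.keras.layers.BatchNormalization(),\n",
"        tf.keras.layers.ReLU(),\n",
"    ])"
],
"execution_count": 0,
"outputs": []
},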
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "p0ZbfywEbZpJ",
"colab": {}
},
"source": [
"up_stack = [\n",
" pix2pix.upsample(512, 3), # 4x4 -> 8x8\n",
" pix2pix.upsample(256, 3), # 8x8 -> 16x16\n",
" pix2pix.upsample(128, 3), # 16x16 -> 32x32\n",
" pix2pix.upsample(64, 3), # 32x32 -> 64x64\n",
"]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "45HByxpVtrPF",
"colab": {}
},
"source": [
"def unet_model(output_channels):\n",
"\n",
" # This is the last layer of the model\n",
" last = tf.keras.layers.Conv2DTranspose(\n",
" output_channels, 3, strides=2,\n",
" padding='same', activation='softmax') #64x64 -> 128x128\n",
"\n",
" inputs = tf.keras.layers.Input(shape=[128, 128, 3])\n",
" x = inputs\n",
"\n",
" # Downsampling through the model\n",
" skips = down_stack(x)\n",
" x = skips[-1]\n",
" skips = reversed(skips[:-1])\n",
"\n",
" # Upsampling and establishing the skip connections\n",
" for up, skip in zip(up_stack, skips):\n",
" x = up(x)\n",
" concat = tf.keras.layers.Concatenate()\n",
" x = concat([x, skip])\n",
"\n",
" x = last(x)\n",
"\n",
" return tf.keras.Model(inputs=inputs, outputs=x)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "j0DGH_4T0VYn"
},
"source": [
"## Train the model\n",
"Now, all that is left to do is to compile and train the model. The loss being used here is losses.sparse_categorical_crossentropy. The reason to use this loss function is because the network is trying to assign each pixel a label, just like multi-class prediction. In the true segmentation mask, each pixel has either a {0,1,2}. The network here is outputting three channels. Essentially, each channel is trying to learn to predict a class, and losses.sparse_categorical_crossentropy is the recommended loss for such a scenario. Using the output of the network, the label assigned to the pixel is the channel with the highest value. This is what the create_mask function is doing."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "6he36HK5uKAc",
"colab": {}
},
"source": [
"model = unet_model(OUTPUT_CHANNELS)\n",
"model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "W8Xe7F3BJuZJ",
"colab_type": "code",
"colab": {}
},
"source": [
"model.summary()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xVMzbIZLcyEF"
},
"source": [
"Have a quick look at the resulting model architecture:"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "sw82qF1Gcovr",
"colab": {}
},
"source": [
"tf.keras.utils.plot_model(model, show_shapes=True)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Tc3MiEO2twLS"
},
"source": [
"Let's try out the model to see what it predicts before training."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "UwvIKLZPtxV_",
"colab": {}
},
"source": [
"def create_mask(pred_mask):\n",
" pred_mask = tf.argmax(pred_mask, axis=-1)\n",
" pred_mask = pred_mask[..., tf.newaxis]\n",
" return pred_mask[0]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "YLNsrynNtx4d",
"colab": {}
},
"source": [
"def show_predictions(dataset=None, num=1):\n",
" if dataset:\n",
" for image, mask in dataset.take(num):\n",
" pred_mask = model.predict(image)\n",
" display([image[0], mask[0], create_mask(pred_mask)])\n",
" else:\n",
" display([sample_image, sample_mask,\n",
" create_mask(model.predict(sample_image[tf.newaxis, ...]))])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "X_1CC0T4dho3",
"colab": {}
},
"source": [
"show_predictions()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "22AyVYWQdkgk"
},
"source": [
"Let's observe how the model improves while it is training. To accomplish this task, a callback function is defined below. "
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "wHrHsqijdmL6",
"colab": {}
},
"source": [
"class DisplayCallback(tf.keras.callbacks.Callback):\n",
" def on_epoch_end(self, epoch, logs=None):\n",
" clear_output(wait=True)\n",
" show_predictions()\n",
" print ('\\nSample Prediction after epoch {}\\n'.format(epoch+1))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "StKDH_B9t4SD",
"colab": {}
},
"source": [
"EPOCHS = 20\n",
"VAL_SUBSPLITS = 5\n",
"VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n",
"\n",
"model_history = model.fit(train_dataset, epochs=EPOCHS,\n",
" steps_per_epoch=STEPS_PER_EPOCH,\n",
" validation_steps=VALIDATION_STEPS,\n",
" validation_data=test_dataset,\n",
" callbacks=[DisplayCallback()])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "P_mu0SAbt40Q",
"colab": {}
},
"source": [
"loss = model_history.history['loss']\n",
"val_loss = model_history.history['val_loss']\n",
"\n",
"epochs = range(EPOCHS)\n",
"\n",
"plt.figure()\n",
"plt.plot(epochs, loss, 'r', label='Training loss')\n",
"plt.plot(epochs, val_loss, 'bo', label='Validation loss')\n",
"plt.title('Training and Validation Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss Value')\n",
"plt.ylim([0, 1])\n",
"plt.legend()\n",
"plt.show()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "unP3cnxo_N72"
},
"source": [
"## Make predictions"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7BVXldSo-0mW"
},
"source": [
"Let's make some predictions. In the interest of saving time, the number of epochs was kept small, but you may set this higher to achieve more accurate results."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "ikrzoG24qwf5",
"colab": {}
},
"source": [
"show_predictions(test_dataset, 3)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "R24tahEqmSCk"
},
"source": [
"## Next steps\n",
"Now that you have an understanding of what image segmentation is and how it works, you can try this tutorial out with different intermediate layer outputs, or even different pretrained model. You may also challenge yourself by trying out the [Carvana](https://www.kaggle.com/c/carvana-image-masking-challenge/overview) image masking challenge hosted on Kaggle.\n",
"\n",
"You may also want to see the [Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection) for another model you can retrain on your own data."
]
}
]
}