@roycoding
Created March 14, 2018 04:05
Transfer learning example (fast.ai Dogs vs Cats image classifier) on Google Colab
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Transfer_Learning.ipynb",
"version": "0.3.2",
"views": {},
"default_view": {},
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"metadata": {
"id": "Y_N2zBXhGVs6",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Transfer Learning: Dogs vs Cats image classification\n",
"Roy Keyes and Gerardo Garcia\n",
"\n",
"[Houston Data Science Meetup](https://meetup.com/Houston-Data-Science)\n",
"\n",
"12 March 2018\n",
"\n",
"\n",
"In this notebook, we demonstrate transfer learning: starting from a convolutional neural network (CNN) pretrained on ImageNet, we use the [fast.ai](https://fast.ai) library with PyTorch to build a state-of-the-art classifier that labels each image as containing either a dog or a cat.\n",
"\n",
"This notebook is adapted from [Lesson 1 (v2)](https://www.youtube.com/watch?v=IPBSB1HLNLo) ([Lesson 1 v2 notebook](https://github.com/fastai/fastai/blob/master/courses/dl1/lesson1.ipynb)) of the [fast.ai deep learning course](http://course.fast.ai/index.html).\n",
"\n",
"This notebook runs on Google Colaboratory with an NVIDIA K80 GPU enabled, which Google currently offers for free.\n",
"\n",
"If you have not already enabled GPU support for this notebook, select `Runtime` -> `Change runtime type` from the notebook menu and then choose GPU."
]
},
{
"metadata": {
"id": "qqzkfTkkJNyC",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Install PyTorch and fast.ai libraries"
]
},
{
"metadata": {
"id": "9nyucEzCjEED",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"!pip3 install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl\n",
"!pip3 install torchvision"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "VsrkadNpJ4NG",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# Extra dependencies needed by OpenCV\n",
"!apt update && apt install -y libsm6 libxext6"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "tLE0ji--J-8J",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"!pip3 install fastai"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "yvVVokmlRccU",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"---->>>> **You may need to restart the runtime at this point (`Runtime` -> `Restart runtime`) for the upgraded version of Pillow to take effect** <<<<----"
]
},
{
"metadata": {
"id": "yI8r0878KNFt",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Load image data\n",
"This data was originally released as part of a [Kaggle](https://www.kaggle.com/c/dogs-vs-cats) competition."
]
},
{
"metadata": {
"id": "qtOFdyOBKT1q",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"!wget http://files.fast.ai/data/dogscats.zip"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "7e25zazMKV_i",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"!unzip dogscats.zip"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "2voUx8U-KZzM",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"!mkdir data"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "n9w_PoojKgV9",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"!mv dogscats data/"
],
"execution_count": 0,
"outputs": []
},
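{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"As a quick sanity check, the next cell counts the image files in each split and class folder. It is a minimal sketch that assumes the archive unpacked into `data/dogscats/` with `train/` and `valid/` folders, each holding one subfolder per class (`cats/`, `dogs/`). The helper `count_images` is hypothetical and not part of fast.ai."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"import os\n",
"\n",
"def count_images(root, splits=('train', 'valid')):\n",
"    # Hypothetical helper: {split: {class_name: number_of_files}} for the assumed layout\n",
"    counts = {}\n",
"    for split in splits:\n",
"        split_dir = os.path.join(root, split)\n",
"        if not os.path.isdir(split_dir):\n",
"            continue  # skip splits that are missing\n",
"        counts[split] = {}\n",
"        for cls in sorted(os.listdir(split_dir)):\n",
"            cls_dir = os.path.join(split_dir, cls)\n",
"            if os.path.isdir(cls_dir):\n",
"                # Count regular files only (True sums as 1)\n",
"                counts[split][cls] = sum(os.path.isfile(os.path.join(cls_dir, f))\n",
"                                         for f in os.listdir(cls_dir))\n",
"    return counts\n",
"\n",
"print(count_images('data/dogscats/'))"
],
"execution_count": 0,
"outputs": []
},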
{
"metadata": {
"id": "BsrmkMwOKpII",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Import fast.ai libraries"
]
},
{
"metadata": {
"id": "oJJKPgikkeGh",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"#### Import ResNet34 architecture\n",
"ResNet, first described in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385), is a CNN architecture that has demonstrated very strong image classification performance. ResNet34 is its 34-layer variant."
]
},
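{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The key idea of ResNet is the shortcut (identity) connection: rather than learning a mapping directly, each block learns a residual F(x) and outputs relu(F(x) + x), so a block whose weights are all zero simply passes its (rectified) input through. The next cell is a framework-free sketch of a basic block acting on a plain Python vector, assuming two small fully connected layers stand in for the block's two convolutions and omitting batch normalization. It illustrates the idea only and is not how torchvision or fast.ai implement ResNet34."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"def relu(v):\n",
"    return [max(0.0, a) for a in v]\n",
"\n",
"def linear(W, v):\n",
"    # Matrix-vector product; W is a list of rows\n",
"    return [sum(w * a for w, a in zip(row, v)) for row in W]\n",
"\n",
"def basic_block(x, W1, W2):\n",
"    # Residual branch F(x) = W2 . relu(W1 . x)\n",
"    f = linear(W2, relu(linear(W1, x)))\n",
"    # Identity shortcut: add the input back, then apply the final ReLU\n",
"    return relu([fa + xa for fa, xa in zip(f, x)])\n",
"\n",
"x = [1.0, -2.0, 3.0]\n",
"zeros = [[0.0] * 3 for _ in range(3)]\n",
"print(basic_block(x, zeros, zeros))  # zero weights -> relu(x) = [1.0, 0.0, 3.0]"
],
"execution_count": 0,
"outputs": []
},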
{
"metadata": {
"id": "2yiy-FVflDRs",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"from fastai.model import resnet34"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "JcS0-NWTlNHD",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"#### Import data loader"
]
},
{
"metadata": {
"id": "kGjhx7EGlUTd",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"from fastai.dataset import ImageClassifierData"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Xew6cMHKkvFA",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"#### Import data transformer"
]
},
{
"metadata": {
"id": "RxDhFPQEkqir",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"from fastai.transforms import tfms_from_model"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "jLh-o_xglbsZ",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"#### Import the network learner (wraps the model, data and optimizer)"
]
},
{
"metadata": {
"id": "G4uuNjjKK--5",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"from fastai.conv_learner import ConvLearner"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "1j7Qt0cUmJMe",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Take a look at some of the data"
]
},
{
"metadata": {
"id": "6Ak2K4KtnrK1",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ur6lJAbaKt6W",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"%matplotlib inline"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "FrKets_1XqXv",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"def img_plots(ims, class_, figsize=(12,6), rows=1, titles=None):\n",
"    # Show validation images of one class (files `ims` in PATH/valid/class_) in a grid\n",
"    f = plt.figure(figsize=figsize)\n",
"    cols = -(-len(ims) // rows)  # ceiling division, so every image gets a subplot\n",
"    for i in range(len(ims)):\n",
"        sp = f.add_subplot(rows, cols, i+1)\n",
"        sp.axis('off')\n",
"        if titles is not None: sp.set_title(titles[i], fontsize=16)\n",
"        sp.imshow(plt.imread(f'{PATH}valid/{class_}/{ims[i]}'))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "uwV50psEm_18",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"PATH = \"data/dogscats/\"\n",
"catfiles = !ls {PATH}valid/cats | head\n",
"dogfiles = !ls {PATH}valid/dogs | head"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "JebGgVeCmcr8",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# Some cats\n",
"img_plots(catfiles, class_='cats', figsize=(18,12))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "vrAcu4QvnyPM",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# Some dogs\n",
"\n",
"img_plots(dogfiles, class_='dogs', figsize=(18,12))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "-fvH9kVOi4Jl",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Load pretrained network and re-train final layer for new task"
]
},
{
"metadata": {
"id": "hNl5lzGvOHEZ",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"#### Choose the ResNet34 architecture"
]
},
{
"metadata": {
"id": "tMnBuXTyOQ3B",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"arch = resnet34"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "aIaj7mErORTT",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"#### Transform data for use with chosen architecture"
]
},
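{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"`tfms_from_model` builds the image transforms that match the chosen architecture, including resizing/cropping to `img_size` and normalizing each color channel with the statistics of the data the pretrained weights were trained on. For ImageNet-pretrained models such as ResNet34 these are usually the per-channel means (0.485, 0.456, 0.406) and standard deviations (0.229, 0.224, 0.225). The next cell sketches only the normalization step for a single RGB pixel with values in [0, 1], assuming those ImageNet statistics; it is not fast.ai's own code."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"# Assumed ImageNet channel statistics (R, G, B)\n",
"IMAGENET_MEAN = (0.485, 0.456, 0.406)\n",
"IMAGENET_STD = (0.229, 0.224, 0.225)\n",
"\n",
"def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):\n",
"    # Per-channel standardization: (value - mean) / std\n",
"    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))\n",
"\n",
"def denormalize_pixel(z, mean=IMAGENET_MEAN, std=IMAGENET_STD):\n",
"    # Inverse transform, e.g. to display a normalized image\n",
"    return tuple(v * s + m for v, m, s in zip(z, mean, std))\n",
"\n",
"print(normalize_pixel((0.485, 0.456, 0.406)))  # the mean pixel maps to (0.0, 0.0, 0.0)"
],
"execution_count": 0,
"outputs": []
},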
{
"metadata": {
"id": "31xITKTnORxg",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"img_size = 224\n",
"data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, img_size))"
],
"execution_count": 0,
"outputs": []
},
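{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"`ImageClassifierData.from_paths` reads labels from the folder structure: inside the training and validation folders (by default `train/` and `valid/`, matching `data/dogscats/`), each subfolder name is a class label. The next cell sketches that convention in plain Python by collecting `(path, label)` pairs and mapping sorted class names to integer indices. The helper `labeled_files` is hypothetical and only illustrates the convention; it is not fast.ai's implementation."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"import os\n",
"\n",
"def labeled_files(split_dir):\n",
"    # Hypothetical helper: one subfolder per class, label = folder name\n",
"    classes = sorted(d for d in os.listdir(split_dir)\n",
"                     if os.path.isdir(os.path.join(split_dir, d)))\n",
"    class_to_idx = {c: i for i, c in enumerate(classes)}\n",
"    # (file path, integer label) for every file in every class folder\n",
"    pairs = [(os.path.join(split_dir, c, f), class_to_idx[c])\n",
"             for c in classes\n",
"             for f in sorted(os.listdir(os.path.join(split_dir, c)))]\n",
"    return pairs, classes\n",
"\n",
"if os.path.isdir('data/dogscats/valid'):\n",
"    pairs, classes = labeled_files('data/dogscats/valid')\n",
"    print(classes, len(pairs))"
],
"execution_count": 0,
"outputs": []
},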
{
"metadata": {
"id": "FqcEAQwIOdsA",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"#### Instantiate the learner with pretrained weights\n",
"This step downloads the pretrained ResNet34 weights (trained on ImageNet) and attaches new final layers sized for our two classes."
]
},
{
"metadata": {
"id": "Ml_70e-dOd-w",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"learn = ConvLearner.pretrained(arch, data, precompute=True)"
],
"execution_count": 0,
"outputs": []
},
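{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"With `precompute=True`, the pretrained layers are frozen, so their output for a given image never changes. fast.ai therefore computes those activations once and reuses them in every epoch, and only the newly added final layers are trained. The next cell mimics this with stand-in pieces: a hypothetical `frozen_body` function plays the role of the pretrained network, its outputs are computed once, and a tiny logistic-regression head is trained on the cached features with plain SGD. It is an illustration under those assumptions, not fast.ai's caching code."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"import math\n",
"\n",
"calls = {'body': 0}  # counts how often the frozen body runs\n",
"\n",
"def frozen_body(x):\n",
"    # Hypothetical stand-in for the pretrained network: a fixed feature extractor\n",
"    calls['body'] += 1\n",
"    return [x, x * x]\n",
"\n",
"def sigmoid(z):\n",
"    # Numerically stable logistic function\n",
"    if z >= 0:\n",
"        return 1 / (1 + math.exp(-z))\n",
"    e = math.exp(z)\n",
"    return e / (1 + e)\n",
"\n",
"def train_head(X, y, lr=0.01, epochs=3):\n",
"    feats = [frozen_body(x) for x in X]  # precompute: body runs once per example\n",
"    w, b = [0.0, 0.0], 0.0\n",
"    for _ in range(epochs):\n",
"        for f, t in zip(feats, y):  # every epoch reuses the cached features\n",
"            p = sigmoid(w[0] * f[0] + w[1] * f[1] + b)\n",
"            g = p - t  # gradient of the log loss with respect to the logit\n",
"            w = [wi - lr * g * fi for wi, fi in zip(w, f)]\n",
"            b -= lr * g\n",
"    return w, b\n",
"\n",
"w, b = train_head([-2.0, -1.0, 1.0, 2.0], [0, 0, 1, 1], lr=0.5, epochs=20)\n",
"print(calls['body'], w, b)  # the body ran 4 times, not 4 * 20"
],
"execution_count": 0,
"outputs": []
},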
{
"metadata": {
"id": "LlJIRLyQOePC",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"#### Train the final network layer on the new data\n",
"The two arguments to `fit` are the learning rate (0.01) and the number of epochs (3)."
]
},
{
"metadata": {
"id": "SJueGfpFK7kc",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"learn.fit(0.01, 3)"
],
"execution_count": 0,
"outputs": []
},
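{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Each epoch is one full pass over the training images, processed in mini-batches, so the number of weight updates is the number of batches per epoch times the number of epochs. The next cell does that arithmetic. It assumes a batch size of 64 (the `from_paths` default in the fast.ai version used here, as far as we know) and that a final, smaller batch is kept; the 20,000-image training set in the example is hypothetical."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"import math\n",
"\n",
"def steps_per_epoch(n_images, batch_size=64):\n",
"    # A final, smaller batch still counts as one step\n",
"    return math.ceil(n_images / batch_size)\n",
"\n",
"def total_steps(n_images, epochs, batch_size=64):\n",
"    return steps_per_epoch(n_images, batch_size) * epochs\n",
"\n",
"# Hypothetical example: 20,000 training images, batch size 64, 3 epochs\n",
"print(steps_per_epoch(20000), total_steps(20000, 3))  # 313 939"
],
"execution_count": 0,
"outputs": []
},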
{
"metadata": {
"id": "pEldjUAMju67",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
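"Each row of the output reports one epoch: the epoch number, the training loss, the validation loss and the validation accuracy. The accuracy in the last row is that of the final classifier.\n",
"\n",
"This data set was originally released as part of a [Kaggle](https://www.kaggle.com/c/dogs-vs-cats) competition in 2013. When the competition launched, the reported state of the art for this task was about 80% accuracy."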
]
}
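,
{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Accuracy is the fraction of validation images whose predicted class (the class with the highest predicted probability) matches the true label. The next cell sketches that computation on made-up probabilities for four images, assuming class index 0 is `cats` and 1 is `dogs`; the numbers are illustrative, not results from this notebook."
]
},
{
"metadata": {
"colab_type": "code"
},
"cell_type": "code",
"source": [
"def accuracy(probs, labels):\n",
"    # probs: one list of class probabilities per image; prediction = argmax\n",
"    preds = [max(range(len(p)), key=p.__getitem__) for p in probs]\n",
"    correct = sum(int(pred == t) for pred, t in zip(preds, labels))\n",
"    return correct / len(labels)\n",
"\n",
"# Made-up predictions for four validation images (0 = cats, 1 = dogs)\n",
"probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]\n",
"labels = [0, 1, 1, 1]\n",
"print(accuracy(probs, labels))  # 3 of 4 correct -> 0.75"
],
"execution_count": 0,
"outputs": []
}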
]
}