Skip to content

Instantly share code, notes, and snippets.

@DonSheddow
Last active March 21, 2021 00:02
Show Gist options
  • Save DonSheddow/7c83f6fe6e59cc89149de753ed41c407 to your computer and use it in GitHub Desktop.
Save DonSheddow/7c83f6fe6e59cc89149de753ed41c407 to your computer and use it in GitHub Desktop.
digit_recognition.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "digit_recognition.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyP+AZqh3sex688Z9DcADO+H",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/DonSheddow/7c83f6fe6e59cc89149de753ed41c407/digit_recognition.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nVG8xWxxe4Rm"
},
"source": [
"#hide\n",
"# Pin both packages so a fresh Colab runtime reproduces this environment;\n",
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"%pip install -q fastai==2.1.10 torchcontrib==0.0.2"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sYl81X89XJXl",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "f8bfeff2-cb10-4d76-e148-b7b1cd295a4a"
},
"source": [
"import random\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"from PIL import Image, ImageDraw, ImageFont, ImageFilter\n",
"\n",
"from fastai.vision.all import *\n",
"\n",
"# Seed the stdlib RNG (font choice and blur decisions below) so the\n",
"# synthetic digit images are the same on every run.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"\n",
"\n",
"def create_digit_img(font_path, digit, size=28):\n",
"    \"\"\"Render `digit` as white text on a black, square grayscale image.\n",
"\n",
"    font_path: path to a TrueType/OpenType font file.\n",
"    digit: value whose str() is drawn (normally an int 0-9).\n",
"    size: side length of the image in pixels, also used as the font size.\n",
"    Returns a PIL image in mode L.\n",
"\n",
"    NOTE: ImageFont.getoffset and ImageDraw.textsize were removed in\n",
"    Pillow 10, so this needs an older Pillow (as used with fastai 2.1.10).\n",
"    \"\"\"\n",
"    text = str(digit)\n",
"    W, H = size, size\n",
"\n",
"    font = ImageFont.truetype(font_path, size, encoding='utf-8')\n",
"    image = Image.new(\"L\", (W, H), \"black\")\n",
"    draw = ImageDraw.Draw(image)\n",
"\n",
"    # The offset is the gap between the draw origin and the first marking;\n",
"    # subtracting it aims to centre the visible glyph rather than the text origin.\n",
"    offset_w, offset_h = font.getoffset(text)\n",
"    w, h = draw.textsize(text, font=font)\n",
"    pos = ((W - w - offset_w) / 2, (H - h - offset_h) / 2)\n",
"\n",
"    draw.text(pos, text, \"white\", font=font)\n",
"    return image\n",
"\n",
"\n",
"def augment(img):\n",
"    \"\"\"Return a new image: `img` softened with a radius-1 box blur.\"\"\"\n",
"    return img.filter(ImageFilter.BoxBlur(1))\n",
"\n",
"\n",
"# NOTE(review): force_download=True re-fetches the font archive on every run.\n",
"font_dir = untar_data(\"https://sheddow.xyz/fonts.tgz\", force_download=True)\n",
"\n",
"# Sort the listing: os.listdir order is arbitrary, which would defeat the seed above.\n",
"fonts = [str(font_dir / name) for name in sorted(os.listdir(font_dir))]\n",
"\n",
"print(f\"using {len(fonts)} fonts\")\n",
"\n",
"# Number of computer-generated images per digit for each MNIST split.\n",
"data_sizes = {\"testing\": 1000, \"training\": 6000}\n",
"\n",
"# Combine handwritten (MNIST) and computer-generated digits: the generated\n",
"# images go into the same split/digit folders as the MNIST ones.\n",
"data_dir = untar_data(URLs.MNIST)\n",
"\n",
"for digit in range(10):\n",
"    for kind, n_images in data_sizes.items():\n",
"        # Create each (split, digit) folder once rather than once per image.\n",
"        dir_ = data_dir / kind / str(digit)\n",
"        dir_.mkdir(parents=True, exist_ok=True)\n",
"        for k in range(n_images):\n",
"            font = random.choice(fonts)\n",
"            im = create_digit_img(font, digit)\n",
"            # Blur about 30% of the generated images.\n",
"            if random.random() < 0.3:\n",
"                im = augment(im)\n",
"            im.save(dir_ / f\"computergen-{k}.png\")\n"
],
"execution_count": 2,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
""
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"using 111 fonts\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "u05uynJlfvum"
},
"source": [
"class RandomNoise(RandTransform):\n",
"    \"\"\"Randomly add Gaussian noise to image batches.\n",
"\n",
"    RandTransform applies it with probability p. When it fires, an integer\n",
"    standard deviation is drawn uniformly from 0..std (inclusive) and the\n",
"    result is x + N(0, 1) * that_std + mean.\n",
"\n",
"    order = 9 is presumably chosen to run before fastai's IntToFloatTensor\n",
"    (order 10), i.e. on the 0-255 pixel scale -- TODO confirm. The result is\n",
"    not clipped back into the valid pixel range.\n",
"    \"\"\"\n",
"    order = 9\n",
"\n",
"    def __init__(self, p=0.5, std=20, mean=0):\n",
"        super().__init__(p=p)\n",
"        self.std = std\n",
"        self.mean = mean\n",
"\n",
"    def encodes(self, x:TensorImage):\n",
"        std = random.randint(0, self.std)\n",
"        # Sample on x's device directly instead of building on the CPU and copying.\n",
"        # (randn_like is avoided: x may still be an integer tensor at this order.)\n",
"        noise = torch.randn(x.shape, device=x.device)\n",
"        return x + noise * std + self.mean"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tBUXwiW0flaQ",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 536
},
"outputId": "8556fbb4-f44b-4e85-ca8f-66ebdc5e4ea9"
},
"source": [
"# Augmentation: fastai's standard affine/lighting transforms without flips\n",
"# (a mirrored digit is not the same digit), plus Gaussian pixel noise from\n",
"# RandomNoise with p=0.3. Alternative noise: xtra_tfms=[RandomErasing(sh=0.1, max_count=8)].\n",
"digit_tfms = aug_transforms(max_warp=0.25, min_zoom=0.8, pad_mode='zeros', do_flip=False,\n",
"                            xtra_tfms=[RandomNoise(p=0.3, std=15)])\n",
"\n",
"mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),\n",
"                  get_items=get_image_files,\n",
"                  splitter=GrandparentSplitter(train_name='training', valid_name='testing'),\n",
"                  get_y=parent_label,\n",
"                  batch_tfms=digit_tfms)\n",
"\n",
"dls = mnist.dataloaders(data_dir)\n",
"\n",
"dls.show_batch()\n",
"\n",
"learn = cnn_learner(dls, resnet34, metrics=error_rate)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 648x648 with 9 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "a3T72QBY1U9P",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 111
},
"outputId": "8da3b67c-fdc4-40e0-a1a7-f0cdf918961a"
},
"source": [
"# First one-cycle schedule: 2 epochs.\n",
"learn.fit_one_cycle(n_epoch=2)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.138884</td>\n",
" <td>0.051361</td>\n",
" <td>0.016650</td>\n",
" <td>02:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.094171</td>\n",
" <td>0.029342</td>\n",
" <td>0.009300</td>\n",
" <td>02:48</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "oXJjAlcB4VBg",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
},
"outputId": "2c4f8fae-11cf-403f-e40b-d5d6de3e0b20"
},
"source": [
"# Second one-cycle schedule (each call runs its own full schedule): 4 epochs.\n",
"learn.fit_one_cycle(n_epoch=4)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.099324</td>\n",
" <td>0.032194</td>\n",
" <td>0.010100</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.068860</td>\n",
" <td>0.024399</td>\n",
" <td>0.007750</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.056584</td>\n",
" <td>0.015748</td>\n",
" <td>0.005150</td>\n",
" <td>02:48</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.052486</td>\n",
" <td>0.015382</td>\n",
" <td>0.004650</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "fXWGGcM67S27",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 235
},
"outputId": "c0ac1255-cbba-4f46-f0c1-b76bc6ffdee8"
},
"source": [
"# Third one-cycle schedule: 6 epochs.\n",
"learn.fit_one_cycle(n_epoch=6)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.057007</td>\n",
" <td>0.016531</td>\n",
" <td>0.005550</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.058062</td>\n",
" <td>0.019173</td>\n",
" <td>0.005800</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.056476</td>\n",
" <td>0.012999</td>\n",
" <td>0.003900</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.046327</td>\n",
" <td>0.013247</td>\n",
" <td>0.003950</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.037934</td>\n",
" <td>0.012765</td>\n",
" <td>0.003750</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.032311</td>\n",
" <td>0.011637</td>\n",
" <td>0.003600</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "W5zgype5_efa"
},
"source": [
"# TODO: train longer with fit_one_cycle before SWA\n",
"# tweak swa_lr and swa_freq\n",
"# should aim for an error_rate of 0.002 or less\n",
"import copy\n",
"\n",
"import torch.optim\n",
"import torchcontrib.optim\n",
"from torchcontrib.optim import SWA\n",
"\n",
"SWA_EPOCHS = 12  # full passes over dls.train\n",
"\n",
"# Train a separate copy so learn.model itself is left untouched.\n",
"swa_model = copy.deepcopy(learn.model)\n",
"# The copy keeps learn.model's current mode (presumably eval after fastai's final\n",
"# validation pass); SGD steps need train mode so BatchNorm/Dropout behave as in training.\n",
"swa_model.train()\n",
"\n",
"base_opt = torch.optim.SGD(swa_model.parameters(), lr=0.1)\n",
"\n",
"# In torchcontrib's automatic mode swa_start and swa_freq count optimizer steps\n",
"# (batches), not epochs: averaging begins after 8 steps and samples every 2nd step.\n",
"opt = SWA(base_opt, swa_start=8, swa_freq=2, swa_lr=0.02)\n",
"for _ in range(SWA_EPOCHS):\n",
"    for xb, yb in dls.train:\n",
"        opt.zero_grad()\n",
"        learn.loss_func(swa_model(xb), yb).backward()\n",
"        opt.step()\n",
"# Load the running SWA average into swa_model's weights...\n",
"opt.swap_swa_sgd()\n",
"\n",
"# ...and recompute BatchNorm running statistics for those averaged weights.\n",
"opt.bn_update(dls.train, swa_model)\n"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RboQfSmvoEH2"
},
"source": [
"# Second SWA stage, continuing from the averaged weights now in swa_model:\n",
"# average after every step (swa_start=0, swa_freq=1) with swa_lr=0.01.\n",
"opt = torchcontrib.optim.SWA(base_opt, swa_start=0, swa_freq=1, swa_lr=0.01)\n",
"# Ensure BatchNorm/Dropout run in training mode for the SGD steps.\n",
"swa_model.train()\n",
"for _ in range(2):\n",
"    for xb, yb in dls.train:\n",
"        opt.zero_grad()\n",
"        learn.loss_func(swa_model(xb), yb).backward()\n",
"        opt.step()\n",
"opt.swap_swa_sgd()\n",
"\n",
"# Recompute BatchNorm running statistics for the new averaged weights.\n",
"opt.bn_update(dls.train, swa_model)\n"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"id": "oFU4U8JsB1V3",
"outputId": "4565ab25-497a-49d9-f1e2-4c67be85758c"
},
"source": [
"# Wrap the averaged model in a Learner purely to evaluate it on the validation set.\n",
"swa_learner = Learner(\n",
"    dls,\n",
"    swa_model,\n",
"    loss_func=learn.loss_func,\n",
"    metrics=error_rate,\n",
")\n",
"swa_learner.validate()"
],
"execution_count": 22,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
""
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#2) [0.011040972545742989,0.0035500000230968]"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TRtQ7VUPaQwI"
},
"source": [
"# Export learn.model (not swa_model) to ONNX by tracing it on a random batch.\n",
"# No dynamic_axes are given, so the exported graph is fixed to the input shape\n",
"# (81, 3, 28, 28). NOTE(review): fastai's input normalization is a batch transform,\n",
"# not part of learn.model, so consumers presumably must apply it themselves -- confirm.\n",
"# Build the dummy batch on the model's own device instead of assuming CUDA.\n",
"model_device = next(learn.model.parameters()).device\n",
"dummy_input = torch.randn(81, 3, 28, 28, device=model_device)\n",
"\n",
"torch.onnx.export(learn.model, dummy_input, \"/tmp/digit_recognition.onnx\")"
],
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "cPba1TxqsYbt",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
},
"outputId": "5ce8ab45-34d1-4803-bad6-f77554088b4f"
},
"source": [
"from google.colab import files\n",
"\n",
"# Send the exported ONNX file to the browser as a download (Colab only).\n",
"files.download(filename=\"/tmp/digit_recognition.onnx\")"
],
"execution_count": 24,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/javascript": [
"\n",
" async function download(id, filename, size) {\n",
" if (!google.colab.kernel.accessAllowed) {\n",
" return;\n",
" }\n",
" const div = document.createElement('div');\n",
" const label = document.createElement('label');\n",
" label.textContent = `Downloading \"${filename}\": `;\n",
" div.appendChild(label);\n",
" const progress = document.createElement('progress');\n",
" progress.max = size;\n",
" div.appendChild(progress);\n",
" document.body.appendChild(div);\n",
"\n",
" const buffers = [];\n",
" let downloaded = 0;\n",
"\n",
" const channel = await google.colab.kernel.comms.open(id);\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
"\n",
" for await (const message of channel.messages) {\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
" if (message.buffers) {\n",
" for (const buffer of message.buffers) {\n",
" buffers.push(buffer);\n",
" downloaded += buffer.byteLength;\n",
" progress.value = downloaded;\n",
" }\n",
" }\n",
" }\n",
" const blob = new Blob(buffers, {type: 'application/binary'});\n",
" const a = document.createElement('a');\n",
" a.href = window.URL.createObjectURL(blob);\n",
" a.download = filename;\n",
" div.appendChild(a);\n",
" a.click();\n",
" div.remove();\n",
" }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"application/javascript": [
"download(\"download_237773b3-8662-447e-baa8-9abd3c49f060\", \"digit_recognition.onnx\", 87255781)"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment