Skip to content

Instantly share code, notes, and snippets.

@DonSheddow
Last active March 21, 2021 00:02
Show Gist options
  • Save DonSheddow/7c83f6fe6e59cc89149de753ed41c407 to your computer and use it in GitHub Desktop.
Save DonSheddow/7c83f6fe6e59cc89149de753ed41c407 to your computer and use it in GitHub Desktop.
digit_recognition.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "digit_recognition.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyP+AZqh3sex688Z9DcADO+H",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/DonSheddow/7c83f6fe6e59cc89149de753ed41c407/digit_recognition.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nVG8xWxxe4Rm"
},
"source": [
"#hide\n",
"!pip install -q fastai==2.1.10 torchcontrib"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sYl81X89XJXl",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "f8bfeff2-cb10-4d76-e148-b7b1cd295a4a"
},
"source": [
"import random\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"from PIL import Image, ImageDraw, ImageFont, ImageFilter\n",
"\n",
"from fastai.vision.all import *\n",
"\n",
"\n",
"def create_digit_img(font_path, digit, size=28):\n",
" digit = str(digit)\n",
" W, H = (size, size)\n",
"\n",
" font = ImageFont.truetype(font_path, size, encoding='utf-8')\n",
"\n",
" image = Image.new(\"L\", (W, H), \"black\")\n",
" \n",
" draw = ImageDraw.Draw(image)\n",
"\n",
" offset_w, offset_h = font.getoffset(digit)\n",
" \n",
" w, h = draw.textsize(digit, font=font)\n",
" pos = ((W-w-offset_w)/2, (H-h-offset_h)/2)\n",
"\n",
" draw.text(pos, digit, \"white\", font=font)\n",
"\n",
" return image\n",
"\n",
"def augment(img):\n",
" return img.filter(ImageFilter.BoxBlur(1))\n",
"\n",
"font_dir = untar_data(\"https://sheddow.xyz/fonts.tgz\", force_download=True)\n",
"\n",
"fonts = [str(font_dir / s) for s in os.listdir(font_dir) ]\n",
"\n",
"print(f\"using {len(fonts)} fonts\")\n",
"\n",
"\n",
"data_sizes = {\"testing\": 1000, \"training\": 6000}\n",
"\n",
"# combine handwritten and computer generated digits\n",
"data_dir = untar_data(URLs.MNIST)\n",
"\n",
"for digit in range(10):\n",
" for (kind, size) in data_sizes.items():\n",
" for k in range(size):\n",
" dir_ = data_dir / kind / str(digit)\n",
" dir_.mkdir(parents=True, exist_ok=True)\n",
" font = random.choice(fonts)\n",
" im = create_digit_img(font, digit)\n",
" if random.random() < 0.3:\n",
" im = augment(im)\n",
" im.save(dir_ / f\"computergen-{k}.png\")\n",
"\n"
],
"execution_count": 2,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
""
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"using 111 fonts\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "u05uynJlfvum"
},
"source": [
"class RandomNoise(RandTransform):\n",
" order = 9\n",
" def __init__(self, p=0.5, std=20, mean=0):\n",
" super().__init__(p=p)\n",
" self.std = std\n",
" self.mean = mean\n",
"\n",
" def encodes(self, x:TensorImage):\n",
" std = random.randint(0, self.std)\n",
" return x + torch.randn(x.size()).to(x.device) * std + self.mean"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tBUXwiW0flaQ",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 536
},
"outputId": "8556fbb4-f44b-4e85-ca8f-66ebdc5e4ea9"
},
"source": [
"mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), \n",
" get_items=get_image_files, \n",
" splitter=GrandparentSplitter(train_name='training', valid_name='testing'),\n",
" get_y=parent_label,\n",
" batch_tfms=aug_transforms(max_warp=0.25, min_zoom=0.8, pad_mode='zeros', do_flip=False, xtra_tfms=[RandomNoise(p=0.3, std=15)])) # use xtra_tfms=[RandomErasing(sh=0.1, max_count=8)] to add noise\n",
"\n",
"dls = mnist.dataloaders(data_dir)\n",
"\n",
"dls.show_batch()\n",
"\n",
"learn = cnn_learner(dls, resnet34, metrics=error_rate)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIHCAYAAADpfeRCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deZCV1bX+8bVpxgaaeR4UlPEa1CgiQVBJEJQYh7oONyblWKaMWJJEjUGNpVfrVowGNEQtkdxEgxXEAaJiIsaARgTBIaggCCJz0yAtNPRIs+8fjfnll7VaTnfTvOc96/upslI+nmET3u5+eFln7xBjFAAA4EuTpBcAAACOPAoAAAAOUQAAAHCIAgAAgEMUAAAAHKIAAADgEAUAAACHKAD1EEI4OoQwP4RQHEIoDCFMDyE0TXpdQKZCCENCCK+FEHaHENaGEC5Iek1AXYUQ/hBC2BZC2BNCWBNCuCbpNaUJBaB+HhaRIhHpISIniMjpIvLDRFcEZOhgWZ0nIi+KSEcRuVZE/hBCGJjowoC6+x8ROTrGWCAi3xGRe0IIJyW8ptSgANRPPxF5OsZYHmMsFJE/i8h/JLwmIFODRaSniEyNMVbHGF8TkTdF5PvJLguomxjjRzHGii//9eA/xyS4pFShANTPNBG5NISQH0LoJSJnS00JANIqiMhxSS8CqKsQwsMhhFIR+VhEtonI/ISXlBoUgPp5XWr+xL9HRDaLyHIRmZvoioDMrZaav8K6OYTQLIRwltT8NVZ+sssC6i7G+EMRaSsio0XkORGp+Opn4EsUgDoKITSRmj/tPycirUWks4h0EJFfJLkuIFMxxioROV9EJopIoYj8RESelpoyC6TOwb/K+ruI9BaR65JeT1pQAOquo4j0FZHpMcaKGOPnIvK/InJOsssCMhdjXBFjPD3G2CnGOF5E+ovI20mvC2igpsIMQMYoAHUUY9wpIutF5LoQQtMQQnsRuVxEViS7MiBzIYRhIYSWB+dYbpKaT7T8LuFlARkLIXQNIVwaQmgTQsgLIYwXkf8Skb8mvba0oADUz4UiMkFEdojIWhGpEpEfJboioG6+LzUDU0Ui8k0RGfcv09RAGkSpud2/WUSKReR+EZkcY/xToqtKkRBjTHoNAADgCOMOAAAADlEAAABwiAIAAIBDFAAAAByiAAAA4NChjrDlIwJoiJD0AoRrGA2TDdewCNcxGsa8jrkDAACAQxQAAAAcogAAAOAQBQAAAIcoAAAAOEQBAADAIQoAAAAOUQAAAHDoUBsBuVRZWamyjRs3qmzfvn0q279/v8o2bNigsurqavO9R48erbLu3bubjwUAoL64AwAAgEMUAAAAHKIAAADgEAUAAACHcnIIcM+ePSpbsGCB+dg//vGPKtu7d6/K/vGPf6gsRn1Al5UVFxerzBoWFBGZNGmSyh544AGVNW2ak791AHKI9f2wqqrKfKz1Pa1JE/6M2pj4fxcAAIcoAAAAOEQBAADAIQoAAAAOUQAAAHAoVaPk5eXlKnvwwQdVNnPmTJWtX7/efM1OnTqpbNSoUSobM2aMyioqKlRmTbiWlZWpzPpUgYhIYWGhyqytifkUAIBsUlRUpLKpU6eqrLbJ/ltuuUVlBQUFKgsh1GN1sHAHAAAAhygAAAA4RAEAAMAhCgAAAA6lapLMGjK57777VLZ7926VtWjRwnzN73znOyq79957VdaxY0eV5eXlma+ZiRkzZpj5kiVLVGYNP1qDNBs3blRZr169VNa6detMlggAGWvVqpXKrO+7v/71r83nb968WWXTp09XWdu2beuxOli4AwAAgEMUAAAAHKIAAADgEAUAAACHgnVe87/4yv/YmKx1lZSUqGzu3Lkqe/HFF1W2aNEi831atmypsksvvVRlF154ocoGDRqkMmtAxRoWtAYVRUReffVVlVnDNS+99JLKli9frrJ77rlHZWeeeabKGmlnwWzYsiuxaxg5IRuuYZEUXMcHDhxQ2QcffKCyCy64wHz+jh07VHbeeeep7PHHH1eZ9X0c/x/zOuYOAAAADlEAAABwiAIAAIBDFAAAABzK2iFAi7VW60jebdu2qeyNN94wX3PevHkq+/vf/64y69jgs88+W2VXXHGFyqzd+FasWGGu57HHHlPZm2++qbKuXbuq7Oc//7nKvvWtb6nMGipsJNkwQJVV13CmqqurVVZaWmo+tri4WGV79+5VmXWstLWjZH5+vso6dOigstp2ZGvevLmZp1Q2XMMiKb2OraPQH330UfOx999/v8qsYWnr2OAbb7xRZe3atctkiV4wBAgAAGpQAAAAcIgCAACAQxQAAAAcStUQYKasHamsI3VFRAoLC1VmDQHOnDlTZcuWLVPZGWecobL+/ftn9B4i9pG+Y8aMUdlVV12lsrFjx6qsTZs25vscIdkwQJX117A13Ldu3TqVvf322+bzV61apTJrV7VMhwDbt2+vsqOPPlplI0aMMNdz3HHHZfSa1ntnoWy4hkVScB1brJ8v+/btMx/79NNPq8za8XTNmjUqswYLTz31VJU15Aj3lGMIEAAA1KAAAADgEAUAAACHKAAAADiUk0OAdWH9+q2BwcWLF6vM2nlvyZIl9X5fEfuo3jvuuENl1gDWEdzhL1PZMECVVdfwnj17VLZ06VKVWTtUrl271nzNZs2aqczaBc3aoS/TY7et3Qa7detmrsfaffKss85SWY8ePVSWhUNa2XANi2TZddwYrGvx008/Vdl3v/tdlQ0cOFBld999t8r69etXz9WlHkOAAACgBgUAAACHKAAAADhEAQAAwKGmSS8gaSHo2QhrmK53794qs4aYGvK+IvbOfd27d1dZFg784d/s379fZR9//LHKZs+erbJ33nlHZSeffLL5PtaA3VFHHaUy6/jeqqoqlW3fvl1l1nHar732mrmeWbNmqax169YqmzBhgso4whX/yrpurGvE2jHwlFNOUdkNN9xweBaWI7gDAACAQxQAAAAcogAAAOAQBQAAAIcoAAAAOOT+UwCW6upqlS1fvjyjzNrK1Nqq9cCBA+Z7f/TRRypbtmyZyqzz2Vu2bGm+JpJRWVmpsnfffVdl77//vsp69eqlsksuucR8n9GjR6vM2vY3U0OGDFGZdb1VVFSYz3/iiSdU9vrrr6ts2LBhKrM+BZOF2wOjERQWFqrspptuUtnChQszej3r68r6mmzI10racQcAAACHKAAAADhEAQAAwCEKAAAADrkfArQG/tasWaOyOXPmqGzTpk0qGzp0qMpGjhypsm3btpnr+etf/6qyp556SmX9+/dX2fDhw1XmecAladb55tb1ZmXWkGh5ebn5PtZgU5Mmuts3ZJiuffv2KqvtbPX8/HyVffLJJyrbsmVLRq/JEKAPu3fvVtlnn32mMmv7ausaKSoqUpk1uOr5eyR3AAAAcIgCAACAQxQAAAAcogAAAOCQ+yFAa6Dk0UcfVZl19nmXLl1Udtlll6ns4osvVtm6devM9XzwwQcqs3ZRmzJlisqsHdisc+FxZFg7QA4cODCjbOXKlSqbN2+e+T7FxcUqs4ZErV32LPv27VPZhg0bVGbthCki8vnnn6vMOsN9586dKtu/f38mS0QOsr6fWkPVixcvVlmHDh1UNnHiRJW1bt26nqvLTdwBAADAIQoAAAAOUQAAAHCIAgAAgEPuhwCtQaSlS5eqzNpB6sILL1TZRRddpDJrIMvaWU1EZNKkSSqbMWOGyqxjg62sd+/eKmNntSPDGgIcPHiwysaPH68y6/fIGhAVsQf0rKEo67hoa8dAa2dCaxdCa3c/EXtHN2tnQyuzdk+ED9YOkscff7zKWrVqpTJr8PSuu+5S2bnnnquynj17musJIZh5LuEOAAAADlEAAABwiAIAAIBDFAAAABwKhxi6yZmJnNp2GPvFL36hsoceekhl1tCeNZx3yimnqMwavqpNWVmZyqwjgi+44AKVHXvssSpbsmSJyqxd2RpJNkzRZNU1bF2HO3bsUNmHH36oMms4VcQeAvziiy9UZg3yWUOAbdu2VZl1/W/fvt1cz6uvvqqyQYMGqeymm25S2dlnn62yhHdvy4ZrWCTLruPGYP0ssr4fWsewv/nmmyq79tprVTZq1CiVzZ8/31xPju0aaF7H3AEAAMAhCgAAAA5RAAAAcIgCAACAQzm5E6A1TGINQInYu5mVlpaq7NJLL1WZNXTXokWLTJZYK2sHuF69eqnMGiwsKSlR2Z49e1R2BIcA8W+aNtVfct27d1eZ9Xs0ZMgQ8zWt3Syta8E6+tra7czakc3aCfPll18212MNZHXr1k1lHTt2VJn1/w98yPRaHDBggMr69u2rssmTJ6vsjTfeUFlRUZG5HusodWtoNs1y61cDAAAyQgEAAMAhCgAAAA5RAAAAcCgnJ24qKytVVtvRpdauZdbAoLU7mjXwl+kRkrXtTGjt4LZixQqVWUOAEyZMUFmuDa2knTWgah2/ax0l3KNHD/M1rSOfG3LUrjUEu3z5cpVt3LjRfL71PtZAVdeuXVVm/bqBf2V9T1u0aJHKMv0asAZmveCnAwAADlEAAABwiAIAAIBDFAAAAByiAAAA4FBOfgrAmrC3zpAWEdm2bVtGr3nmmWeqrFWrViqzPoFgbaNaXFxsvs8rr7yisj//+c8qs7Z1HThwoMoaujUx6i/T882ta9DawrmgoMB8ny5duqjM2lLX+rSB9T5r165V2bx581S2bNkycz39+/dX2fHHH68ya3tgPrWCQ7G+hm6//XaVWZ9msbZat17PC77aAABwiAIAAIBDFAAAAByiAAAA4FBODgFaW0Du3r074+dbA1S1ncX+77Zu3aqy9evXq2zJkiXm86dMmaIya9tfa7jPGkC0hgVxZFj/31vDqH/5y19U9vHHH6vMGvYTEenXr5/KWrdurTJri+vCwkKVWVtPW1sBW2e1i4iMHTtWZSNGjFBZ+/btzecDX7K+p1nbt2/evFll1hDu0KFDVda9e3fzvT0MpOb+rxAAACgUAAAAHKIAAADgEAUAAACHcnIIMNMdz0REQggqs84knzVrlsr27dunslWrVqls9erVKtu1a5e5HmsAcfTo0Sq76KKLVGbtVtihQwfzfdD4rCEkaxBv586dKnv33XdVZu1wKSLStm1blVkDTNbzra8La/DKGjQcOXKkuZ7x48erzNod0Po6g1/W8La1A+VPfvITlRUVFWX0HhMnTlRZ586dM3puLuIOAAAADlEAAABwiAIAAIBDFAAAABzKySFAa5e83r17m4896qijVGbt5vfss8+qrKSkRGXWcJ81fNWxY0dzPd/+9rdVdu2116rsa1/7mso6deqksubNm5vvg8ZnDZhav0djxoxRWbt27VRW29HV1i6X1iCfNRho7eZnHdNr7YRpXYMiIn369FGZdXQ2cos1xGcdhV7bwN5jjz2WUWYNzVqs75s/+9nPVGbtmukFdwAAAHCIAgAAgEMUAAAAHKIAAADgUE4OAVo7jPXq1ct87EknnaQy68jWzz77TGXWMZIDBgxQWc+ePVU2aNAgcz3Dhg1TmbUToDW4Yg2dITnWAGaPHj1UZu3WaA3dFRcXm+9TWlqqMmvw1BoCtAZmrQFEa7e02o4DtnazRG6xjji/4oorVLZx40aV1XYdW0PV1mChxdqp8sorr1RZQUFBRq/nBXcAAABwiAIAAIBDFAAAAByiAAAA4FBOTutYQ0jWIJ6IPXS3YMEClVmDeGeddZbK+vbtqzJrt0HrcSL2oJY11MjAXzrl5eWpzBrotLLaBlmBI62wsFBl69atU9n27dtVVtux1tb3NGsHSWsHyilTpqhs+PDh5vvg/+EOAAAADlEAAABwiAIAAIBDFAAAABwKMcav+u9f+R/TpKyszMytoymt44CtHaSsI32tndWszMkRlNkwqZgz1zASkQ3XsEgKrmNruG/16tUqu/jii83nW0N71o6pl19+ucpqG/LGP5nXMXcAAABwiAIAAIBDFAAAAByiAAAA4BAFAAAAh9x8CgCJyIYJaq5hNEQ2XMMiXMdoGD4FAAAAalAAAABwiAIAAIBDFAAAAByiAAAA4BAFAAAAhygAAAA4RAEAAMAhCgAAAA5RAAAAcIgCAACAQxQAAAAcogAAAOAQBQAAAIcOdRwwAADIQdwBAADAIQoAAAAOUQAAAHCIAgAAgEMUAAAAHKIAAADgEAUAAACHKAAAADhEAQAAwCEKAAAADlEAAABwiAIAAIBDFAAAAByiAAAA4BAFoB5CCH8IIWwLIewJIawJIVyT9JqAugohLAwhlIcQ9h78Z3XSawLqKoRwaQhhVQhhXwhhXQhhdNJrSosQY0x6DakTQvgPEVkbY6wIIQwWkYUiMjHG+E6yKwMyF0JYKCJ/iDE+nvRagPoIIYwTkcdF5BIReVtEeoiIxBi3JLmutOAOQD3EGD+KMVZ8+a8H/zkmwSUBgEd3icjdMcYlMcYDMcYt/PDPHAWgnkIID4cQSkXkYxHZJiLzE14SUB//E0LYGUJ4M4RwRtKLATIVQsgTkZNFpEsIYW0IYXMIYXoIoVXSa0sLCkA9xRh/KCJtRWS0iDwnIhVf/Qwg6/xURPqLSC8ReUxEXgghcCcLadFNRJqJyH9KzffhE0TkRBG5PclFpQkFoAFijNUxxr+LSG8RuS7p9QB1EWNcGmMsiTFWxBh/LyJvisg5Sa8LyFDZwf/9dYxxW4xxp4j8SriGM0YBODyaCjMASL8oIiHpRQCZiDEWi8hmqblu/xkntJxUogDUUQih68GPnbQJIeSFEMaLyH+JyF+TXhuQqRBC+xDC+BBCyxBC0xDCZSIyRkT+nPTagDr4XxG54eD35Q4i8iMReTHhNaVG06QXkEJRam73Pyo1BWqDiEyOMf4p0VUBddNMRO4RkcEiUi01w6znxxjXJLoqoG7+W0Q6i8gaESkXkadF5N5EV5Qi7AMAAIBD/BUAAAAOUQAAAHCIAgAAgEMUAAAAHDrUpwCYEERDZMNnyrmG0RDZcA2LcB2jYczrmDsAAAA4RAEAAMAhCgAAAA5RAAAAcIgCAACAQxQAAAAcogAAAOAQBQAAAIcoAAAAOEQBAADAIQoAAAAOUQAAAHCIAgAAgEMUAAAAHDrUccAAsly/fv1UtnHjRpXl5+errHnz5ip78sknVXbOOefUc3UAshV3AAAAcIgCAACAQxQAAAAcogAAAOAQBQAAAIf4FECGSktLVXbgwAGVtWnT5kgsBw7FGM189+7dKpszZ47KJk6cqLImTfSfAZo25dsC4AF3AAAAcIgCAACAQxQAAAAcogAAAOAQ0z6GkpISlc2fP19lixYtUll1dbXKRo0apbKLL77YfO+WLVtmskQ4ZF1bIiIVFRUq69+/v8patGhx2NcEIL24AwAAgEMUAAAAHKIAAADgEAUAAACHGAI0lJWVqczaHa1r164qKy4uVtm0adNUVlBQYL73+eefn8kS4dD+/fvNvKqqSmVbt25V2aZNm1R20kknqaxnz571WB2AtOEOAAAADlEAAABwiAIAAIBDFAAAABwKtR0xetBX/sdcZQ1VVVZWZvTcZs2aqeyaa65RWatWrcznP/LIIyqzjmxNiZD0AiSHrmFrh0oRkfbt26vMGlC1dgdcsmSJymbOnKmyyy+/XGUhZMNvb6PLll9kzlzHHlhHxSf8fdy8jlP7kwUAANQfBQAAAIcoAAAAOEQBAADAIXYCNFiDfFZmsYY/rGNYrd0GRUTKy8tVlp+fn9F7I7fl5eWZ+fe//32VTZ06VWXW7pO///3vVXbzzTer7LzzzlNZhw4dzPUAnlgD4r/61a9UZu3O+eCDD6rsSA7XcgcAAACHKAAAADhEAQAAwCEKAAAADjEE2ADWjoErV65U2bZt21TWo0cP8zVrO/IVqG33yMcff1xl1vHVFmuA8Oqrr1bZ8uXLVTZu3LiM3gPIZS+88ILK7rnnHpVZX2vW0Hhtw76NgTsAAAA4RAEAAMAhCgAAAA5RAAAAcIghwAxZg3zTp09X2UMPPaSybt26qez8888336d58+b1WB08qG2HsEwH/izWDpfWENI777yjMoYA4c2mTZtUdv3116usTZs2Kps8ebLKjuTAn4U7AAAAOEQBAADAIQoAAAAOUQAAAHCIAgAAgEN8CiBD1paN1dXVKtu7d6/KrEnr3r17m+/TokWLeqwOOHzatWunsqKiogRWAmSXOXPmqOzzzz9X2S9/+UuVDRgwoFHW1BDcAQAAwCEKAAAADlEAAABwiAIAAIBDDAFmqGvXriqztoDs27evyp5//nmVvf/+++b7jBgxQmUdOnTIZIlwqrS0VGX5+fkZPbeiokJlLVu2VFn79u3rvjAgxdasWaOy2267TWWnnXaayq666iqVNWmSfX/ezr4VAQCARkcBAADAIQoAAAAOUQAAAHCIIcAMWbv59enTR2VXX321yvbs2aOy9evXm++zdu1alQ0fPjyTJSLH1XbN3H///SqbNm2ayqxr2HpNazBwwoQJmSwRSJ0Yo5k//vjjKgshqOyWW25RWUFBQcMXdgRwBwAAAIcoAAAAOEQBAADAIQoAAAAOMQR4mFnH+Q4dOlRldRkC/PrXv66yvLy8eqwOaWYNk4qIzJgxQ2XNmzdXmTVMOnv2bJV16tRJZQMHDsxkiUDqvPjii2b+wAMPqOyGG25Q2bhx4w77mo4U7gAAAOAQBQAAAIcoAAAAOEQBAADAIfdDgFVVVSo7cOCAyqyjHK2d1SzW61mZiEhJSYnKqqurVcYQoD9Dhgwx8+nTp6vswQcfVJk18Ne5c2eV/eY3v1FZ27ZtM1kikNWs7/fz5s0zH2vt9Grt+te0aXp/jHIHAAAAhygAAAA4RAEAAMAhCgAAAA6ld3rhMNm7d6/KtmzZojJrh79evXpl9LhMj1wVsYcNrSMo4Y+1u5+IyOWXX66yb37zmyorKytTWbt27VTWrVs3lTF0ilywaNEilc2aNct87NNPP62yHj16HPY1JYk7AAAAOEQBAADAIQoAAAAOUQAAAHCIIUBjCHDx4sUqKyoqUtmAAQNUZu2s9t5772W8nt69e6sszTtNofFZg6fHHHNMAisBssfu3btVduedd6rspJNOMp8/YcIEleXaQDZ3AAAAcIgCAACAQxQAAAAcogAAAOAQBQAAAIfcjJfHGM28WbNmGT3/008/VdmaNWtUZm3la02j1jZ5ap35nmuTpwBwOFVXV6tszpw5Klu6dKnKNmzYYL5mpj8brPdOy9bZ3AEAAMAhCgAAAA5RAAAAcIgCAACAQ26GAGsbpLPOQx8xYoTKysvLVbZ+/XqV7d+/X2WDBw9WmXVeu4h9FjsAoHZlZWUqe/bZZ1V26623qqxnz57ma1rD23/7299UNnfuXJU98sgjKmvVqpX5PkniDgAAAA5RAAAAcIgCAACAQxQAAAAcCrXtkHfQV/7HXGD9+q2Bkp07d6ps165dKquqqlJZp06dVNa9e3dzPfn5+WaeUtmwhWHOX8NoVNlwDYtwHf/TgQMHVDZ16lSV3XzzzSrbsmWLyu68807zfZ555hmVFRcXq8waMLceZw2cH0HmdcwdAAAAHKIAAADgEAUAAACHKAAAADjkfggQjSobBqi4htEQ2XANi3Ad/1NlZaXKTjzxRJWtWrXqSCxHLrnkEpU99dRTKkv4WHeGAAEAQA0KAAAADlEAAABwiAIAAIBDbo4DBgCk37Rp01SW6cCfNfRe23BeXl6eyn70ox+p7L777svovbMRdwAAAHCIAgAAgEMUAAAAHKIAAADgEEOAAIDUaNq0/j+2OnfurLLrr7/efOzkyZNV1r59+3q/dzbiDgAAAA5RAAAAcIgCAACAQxQAAAAc4jhgNKZsOEqVaxgNkQ3XsAjX8T+VlJSobP369Spr166dyvr27auyhI/pPVI4DhgAANSgAAAA4BAFAAAAhygAAAA4RAEAAMAhPgWAxpQN47Vcw2iIbLiGRbiO0TB8CgAAANSgAAAA4BAFAAAAhygAAAA4RAEAAMAhCgAAAA5RAAAAcIgCAACAQxQAAAAcogAAAOAQBQAAAIcoAAAAOEQBAADAIQoAAAAOHeo4YAAAkIO4AwAAgEMUAAAAHKIAAADgEAUAAACHKAAAADhEAQAAwCEKAAAADlEAAABwiAIAAIBDFAAAAByiAAAA4BAFAAAAhygAAAA4RAEAAMAhCkA9hRAuDSGsCiHsCyGsCyGMTnpNQF2FEAaEEMpDCH9Iei1AXYQQWoQQZoYQNoQQSkII74cQzk56XWnSNOkFpFEIYZyI/EJELhGRt0WkR7IrAurtNyKyLOlFAPXQVEQ2icjpIrJRRM4RkadDCF+LMX6W5MLSIsQYk15D6oQQFovIzBjjzKTXAtRXCOFSEblQRFaKyLExxu8lvCSgQUIIK0Tkrhjjs0mvJQ34K4A6CiHkicjJItIlhLA2hLA5hDA9hNAq6bUBmQohFIjI3SLy46TXAhwOIYRuIjJQRD5Kei1pQQGou24i0kxE/lNERovICSJyoojcnuSigDr6b6m5i7U56YUADRVCaCYis0Tk9zHGj5NeT1pQAOqu7OD//jrGuC3GuFNEfiU1f/8EZL0Qwgki8i0RmZr0WoCGCiE0EZEnRaRSRCYlvJxUYQiwjmKMxSGEzSLyr8MTDFIgTc4QkaNFZGMIQUSkjYjkhRCGxhi/nuC6gDoJNRfwTKm5M3tOjLEq4SWlCkOA9RBCuFtEzhaRiSJSJSJ/EpGFMcY7El0YkIEQQr6IFPxLdJPUFILrYow7ElkUUA8hhEel5q9hvxVj3Jv0etKGOwD1898i0llE1ohIuYg8LSL3JroiIEMxxlIRKf3y30MIe0WknB/+SJMQwlEi8gMRqRCRwoN3s0REfhBjnJXYwlKEOwAAADjEECAAAA5RAAAAcIgCAACAQxQAAAAcOtSnAJgQREOEQz+k0XENoyGy4RoW4TpGw5jXMXcAAABwiAIAAIBDFAAAAByiAAAA4BAFAAAAhygAAAA4RAEAAMAhCgAAAA5RAAAAcIgCAACAQxQAAAAcogAAAOAQBQAAAIcoAAAAOHSo44ABADhsDhw4YOZ79+5V2e7du1VWXl6usry8PC3NHSEAAArNSURBVJW1adNGZe3atVNZixYtzPV4wB0AAAAcogAAAOAQBQAAAIcoAAAAOEQBAADAIT4FAKBeYoxmHkI4witBttq1a5fKHnroIfOxL774osrWrVunspKSEpVZnwLo1q2byk4//XSV3XrrreZ6hgwZorImTXLrz8y59asBAAAZoQAAAOAQBQAAAIcoAAAAOBRqG+Q56Cv/Yy6wtqUsLS1VWVlZmcqaNtUzlPn5+SpzvNVkNkyD5fw1nBTra0JEZPbs2SqztnmdNGnSYV9TI8iGa1gkBdfxhg0bVGYN3W3atMl8fm1bBB9O1oBqly5dzMe+/PLLKjvxxBMzes0sZC6SOwAAADhEAQAAwCEKAAAADlEAAABwKCd3ArSGSQoLC83HLly4UGUrV67M6H169uypspNPPlllgwcPVlnbtm3N10zJQAmcsYaF9+zZYz72iSeeUFm/fv0O+5qQnMrKSpU9+eSTKtuxY4fKCgoKzNe86aabVHbdddeprEOHDiqrqKhQ2W9/+1uV3XjjjSorKioy13PvvfeqbObMmSpr3769+fw04A4AAAAOUQAAAHCIAgAAgEMUAAAAHEr9EKA18GftNPXCCy+Yz3/nnXdUZg11WEMm1vtY67GOpWzTpo25HoYAkY2qqqpUtnr1avOx1mDtcccdd7iXhAQ1b95cZT/+8Y9V1rlzZ5UNHDjQfM0xY8aozNpt1dKyZUuVXXHFFSqbNWuWyhYvXmy+5rvvvquyLVu2qIwhQAAAkCoUAAAAHKIAAADgEAUAAACHUj8EaO1G9tZbb6lswYIF5vP79OmjshEjRqisd+/eKtu8ebPKOnXqpDJrMLC2oy+bNKGTIftYX2fPP/+8+Vhr10DrmGzkFuv3+Oqrr1ZZbd/j8vLyDut6rIHq6urqjJ9vPba23S/Tip82AAA4RAEAAMAhCgAAAA5RAAAAcChVQ4D79+9X2datW1X2+uuvq6xXr17ma95www0qO+aYY1Rm7UhVVlamMmu4z9o163APvACHS6a7a77yyivm81u1aqWy1q1bN3xhSJ1mzZol9t7WMb8ffvhhxs+3dvizhrzTjDsAAAA4RAEAAMAhCgAAAA5RAAAAcChVQ4DWDmMrVqxQmbVb0znnnGO+5lFHHaWyTI+gtIadgLTbt2+fypYuXaqy2nazPPXUU1XGToBoTJWVlSq75ZZbVGZd27UNZA8ePFhlffv2rcfqshd3AAAAcIgCAACAQxQAAAAcogAAAOAQBQAAAIdS9SkAa+rY2tqxoKBAZbVNb7Zo0aLhCwNyyJYtW1T2wQcfqOwb3/iG+Xzr67Rt27YNXxggItXV1SqbMmWKyubOnauyEILK2rVrZ76P9SmCli1bZrLE1OAOAAAADlEAAABwiAIAAIBDFAAAABxK1RCgtY2jtQWkpWPHjmZeUVGhsh07dqhs27ZtKrOGDXv27KmyNm3aqKxJE7oXkrd//36VvffeeyorKytT2ciRI83XXLRokcqsrwHgUKqqqlR2++23q+yRRx5RmfWzwdrm/c477zTf+4QTTshkianGTyEAAByiAAAA4BAFAAAAhygAAAA4lPohwNatW6vsiy++UJk17Ccisnz5cpW9/fbbKtu1a5fKrCETa8fBsWPHquzYY48119O8eXMzBxqDdV0vW7ZMZb169VLZsGHDzNd89dVXVcYQIA6lvLxcZTfffLPKfve736nMGlK1dv278cYbVXbNNdeY67EGBnMNdwAAAHCIAgAAgEMUAAAAHKIAAADgUKqmHKyje63hj08++URlr7zyivma1tGl1mtaO6aVlpaqzBqg+vzzz1V2ySWXmOsZOHCgytg1EI3FOuZ3+/btKvve976nsrrsrskQIL60Z88eM588ebLKZs+erTLr+3OMUWWTJk1SmbWLYH5+vrkeD/jJAgCAQxQAAAAcogAAAOAQBQAAAIdSNQTYtm1blVkDJdZOfrX5wQ9+oLIzzjhDZd26dVPZ3r17VbZgwQKVzZ8/X2WvvfaauZ4+ffqozNrtEKgra1Bq4cKFKtu6davKNm3apDJrWFBEZPPmzSr77LPPVGbtpMlOmOllXV/WMeo//elPzec/88wzKqttB9d/d91116nsjjvuUFm7du0yej0vuAMAAIBDFAAAAByiAAAA4BAFAAAAh1I1BNiqVSuVWbvkWTv0VVdXm6952mmnqcwa+LOOlrSGEseNG6eyJUuWqMwatBIRKSkpURlDgDgcrF0vrUE8a7hv6tSpKqttR7cdO3aobNq0aSo799xzVdalSxfzNZFdrJ1R161bp7K77rpLZXPmzDFf0/oebX1/v/LKK1Vm7fBnXUvW9/G6qKqqUpl1bHBD3+dI4Q4AAAAOUQAAAHCIAgAAgEMUAAAAHErVEKClf//+Khs8eLDKrGN2RUQ6deqksoYMcFiDgXXZ3Wz37t0q6969e73XA3zJuq4vu+wylY0fPz6j51rDrSIiM2bMUJl1NCu7sqWDNZxn7ex42223qez5559XmTWMKmJfY9YQYMuWLVW2fPnyjJ7bUNbw49ixY1VWUFBw2N+7MXAHAAAAhygAAAA4RAEAAMAhCgAAAA5RAAAAcCj1nwIYNWqUyqyzy1u0aGE+v7YtguvLOhO7rKxMZc2aNTsi6wG+ZE1FDxs2rN6vV9sW1c8++6zKrC2y6/LpGCSnvLxcZc8884zKnnvuOZVZ3w9rYz3Wmrp/+OGHM8oaylqP9amxuXPnqsz6uZSNuAMAAIBDFAAAAByiAAAA4BAFAAAAh1I/BDho0CCVDRkyRGVr1641n79y5UqVDR8+XGWZbg/8xRdfqMwaZKltCLBjx44ZvQ+QNOu6FhGpqKhQWW1DuMh+1u/n5s2bVVaXgb9MWa/ZGO+TKWtIe/v27Qms5PDgDgAAAA5RAAAAcIgCAACAQxQAAAAcSv0QoHXu8siRI1X26aefms+fNWuWyrp06aKyPn36qMzaIevll19W2Z49e1R2+umnm+uxdpoCslFdhgCtM9yRDlVVVSorLi5OYCXJO3DggMoKCwsTWMnhwR0AAAAcogAAAOAQBQAAAIcoAAAAOBQOsatSclsuNYA1hPTWW2+Zj33qqadUtnXrVpUdc8wxKmvaVM9QlpSUqOyEE05Q2UUXXWSuxxpATLHMtk9sXKm8htNg9erVZv7SSy+p7Prrr1dZSnYHzIZrWCTB67iyslJl1u531i6o2SbTHV1rY33P79q1q8qycEdX8xfOHQAAAByiAAAA4BAFAAAAhygAAAA4lJNDgBZr1z4RkVWrVqnsjTfeUFlRUZHKmjTR/Wno0KEqO+uss1TWoUMHcz0NHVLJMtnwi8mZaxiJyIZrWITrGA3DECAAAKhBAQAAwCEKAAAADlEAAABwyM0QYG2s4x3Lysoyeq61k5m1U5Rj2TBAlfPXMBpVNlzDIlzHaBiGAAEAQA0KAAAADlEAAABwiAIAAIBD7ocA0aiyYYCKaxgNkQ3XsAjXMRqGIUAAAFCDAgAAgEMUAAAAHKIAAADgEAUAAACHKAAAADhEAQAAwCEKAAAADlEAAABwiAIAAIBDFAAAAByiAAAA4BAFAAAAhygAAAA4RAEAAMAhCgAAAA5RAAAAcIgCAACAQxQAAAAcogAAAOBQiDEmvQYAAHCEcQcAAACHKAAAADhEAQAAwCEKAAAADlEAAABwiAIAAIBD/wd7sWMkgmq3dwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 648x648 with 9 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "a3T72QBY1U9P",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 111
},
"outputId": "8da3b67c-fdc4-40e0-a1a7-f0cdf918961a"
},
"source": [
"learn.fit_one_cycle(2)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.138884</td>\n",
" <td>0.051361</td>\n",
" <td>0.016650</td>\n",
" <td>02:46</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.094171</td>\n",
" <td>0.029342</td>\n",
" <td>0.009300</td>\n",
" <td>02:48</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "oXJjAlcB4VBg",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
},
"outputId": "2c4f8fae-11cf-403f-e40b-d5d6de3e0b20"
},
"source": [
"learn.fit_one_cycle(4)"
],
"execution_count": 6,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.099324</td>\n",
" <td>0.032194</td>\n",
" <td>0.010100</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.068860</td>\n",
" <td>0.024399</td>\n",
" <td>0.007750</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.056584</td>\n",
" <td>0.015748</td>\n",
" <td>0.005150</td>\n",
" <td>02:48</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.052486</td>\n",
" <td>0.015382</td>\n",
" <td>0.004650</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "fXWGGcM67S27",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 235
},
"outputId": "c0ac1255-cbba-4f46-f0c1-b76bc6ffdee8"
},
"source": [
"learn.fit_one_cycle(6)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>error_rate</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.057007</td>\n",
" <td>0.016531</td>\n",
" <td>0.005550</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.058062</td>\n",
" <td>0.019173</td>\n",
" <td>0.005800</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.056476</td>\n",
" <td>0.012999</td>\n",
" <td>0.003900</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.046327</td>\n",
" <td>0.013247</td>\n",
" <td>0.003950</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.037934</td>\n",
" <td>0.012765</td>\n",
" <td>0.003750</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.032311</td>\n",
" <td>0.011637</td>\n",
" <td>0.003600</td>\n",
" <td>02:47</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "W5zgype5_efa"
},
"source": [
"# TODO: train longer with fit_one_cycle before SWA\n",
"# tweak swa_lr and swa_freq\n",
"# should aim for an error_rate of 0.002 or less\n",
"import torch.optim\n",
"import torchcontrib.optim\n",
"from torchcontrib.optim import SWA\n",
"\n",
"swa_model = copy.deepcopy(learn.model)\n",
"\n",
"base_opt = torch.optim.SGD(swa_model.parameters(), lr=0.1)\n",
"\n",
"opt = torchcontrib.optim.SWA(base_opt, swa_start=8, swa_freq=2, swa_lr=0.02)\n",
"for _ in range(12):\n",
" for input, target in dls.train:\n",
" opt.zero_grad()\n",
" learn.loss_func(swa_model(input), target).backward()\n",
" opt.step()\n",
"opt.swap_swa_sgd()\n",
"\n",
"opt.bn_update(dls.train, swa_model)\n"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RboQfSmvoEH2"
},
"source": [
"opt = torchcontrib.optim.SWA(base_opt, swa_start=0, swa_freq=1, swa_lr=0.01)\n",
"for _ in range(2):\n",
" for input, target in dls.train:\n",
" opt.zero_grad()\n",
" learn.loss_func(swa_model(input), target).backward()\n",
" opt.step()\n",
"opt.swap_swa_sgd()\n",
"\n",
"opt.bn_update(dls.train, swa_model)\n"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"id": "oFU4U8JsB1V3",
"outputId": "4565ab25-497a-49d9-f1e2-4c67be85758c"
},
"source": [
"swa_learner = Learner(dls, swa_model, loss_func=learn.loss_func, metrics=error_rate)\n",
"swa_learner.validate()"
],
"execution_count": 22,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
""
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#2) [0.011040972545742989,0.0035500000230968]"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TRtQ7VUPaQwI"
},
"source": [
"dummy_input = torch.randn(81, 3, 28, 28).cuda()\n",
"\n",
"torch.onnx.export(learn.model, dummy_input, \"/tmp/digit_recognition.onnx\")"
],
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "cPba1TxqsYbt",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
},
"outputId": "5ce8ab45-34d1-4803-bad6-f77554088b4f"
},
"source": [
"from google.colab import files\n",
"\n",
"files.download(\"/tmp/digit_recognition.onnx\")"
],
"execution_count": 24,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/javascript": [
"\n",
" async function download(id, filename, size) {\n",
" if (!google.colab.kernel.accessAllowed) {\n",
" return;\n",
" }\n",
" const div = document.createElement('div');\n",
" const label = document.createElement('label');\n",
" label.textContent = `Downloading \"${filename}\": `;\n",
" div.appendChild(label);\n",
" const progress = document.createElement('progress');\n",
" progress.max = size;\n",
" div.appendChild(progress);\n",
" document.body.appendChild(div);\n",
"\n",
" const buffers = [];\n",
" let downloaded = 0;\n",
"\n",
" const channel = await google.colab.kernel.comms.open(id);\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
"\n",
" for await (const message of channel.messages) {\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
" if (message.buffers) {\n",
" for (const buffer of message.buffers) {\n",
" buffers.push(buffer);\n",
" downloaded += buffer.byteLength;\n",
" progress.value = downloaded;\n",
" }\n",
" }\n",
" }\n",
" const blob = new Blob(buffers, {type: 'application/binary'});\n",
" const a = document.createElement('a');\n",
" a.href = window.URL.createObjectURL(blob);\n",
" a.download = filename;\n",
" div.appendChild(a);\n",
" a.click();\n",
" div.remove();\n",
" }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"application/javascript": [
"download(\"download_237773b3-8662-447e-baa8-9abd3c49f060\", \"digit_recognition.onnx\", 87255781)"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment