Last active
November 13, 2020 08:28
-
-
Save ababino/2a2c67ac264e2ed8c95144377b9be2b4 to your computer and use it in GitHub Desktop.
tabular-export-learner-testing.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"name": "tabular-export-learner-testing.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"jupytext": { | |
"split_at_heading": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ababino/2a2c67ac264e2ed8c95144377b9be2b4/tabular-export-learner-testing.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XrJzkfJ-1bKp" | |
}, | |
"source": [ | |
"# Tabular models" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Lu4U4bzjv0oH", | |
"outputId": "5e4e9767-2525-447b-803e-7877cce896eb", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"!pip install fastai --upgrade -qqq\n", | |
"!pip install pympler" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: pympler in /usr/local/lib/python3.6/dist-packages (0.9)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_kw8-Dju_6J5", | |
"outputId": "3b2ec4af-7fa7-4a72-f542-366f39ea0a33", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
} | |
}, | |
"source": [ | |
"import fastai\n", | |
"fastai.__version__" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"'2.1.5'" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 2 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "K-NLPMGM1bKr" | |
}, | |
"source": [ | |
"from fastai.tabular.all import *\n", | |
"from fastai.callback.all import *\n", | |
"from pympler import muppy\n", | |
"import warnings\n", | |
"warnings.filterwarnings('ignore')" | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7fsTZYn5a18b" | |
}, | |
"source": [ | |
"def print_dfs_lens():\n", | |
" all_objects = muppy.get_objects()\n", | |
" my_types = muppy.filter(all_objects, Type=pd.DataFrame)\n", | |
" my_types_lens = [len(x) for x in my_types]\n", | |
" print(my_types_lens[:-3], my_types_lens[-3:])" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "dzHkM_t41bKx" | |
}, | |
"source": [ | |
"Tabular data should be in a Pandas `DataFrame`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xLLnzxw_1bKy" | |
}, | |
"source": [ | |
"path = untar_data(URLs.ADULT_SAMPLE)\n", | |
"df = pd.read_csv(path/'adult.csv')\n", | |
"for i in range(4):\n", | |
" df = df.append(df.copy())" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TBoo6c761bK2" | |
}, | |
"source": [ | |
"dep_var = 'salary'\n", | |
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", | |
"cont_names = ['age', 'fnlwgt', 'education-num']\n", | |
"procs = [Categorify, FillMissing, Normalize]" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZNlr9V4H1bK_" | |
}, | |
"source": [ | |
"splits = RandomSplitter()(range_of(df))" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2ERqyHkxE7sq", | |
"outputId": "7e5f0ce5-b21c-4990-ff15-6cee2a8b0f94", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"splits" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"((#416781) [500824,451294,253829,307438,386053,451314,171109,189741,68905,477899...],\n", | |
" (#104195) [7873,188230,219329,201724,405331,315616,162262,14314,369423,412459...])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "cTd_Ckt61bLH" | |
}, | |
"source": [ | |
"to = TabularPandas(df, procs, cat_names, cont_names, y_names=\"salary\", splits=splits)" | |
], | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7UvARddj1bLN" | |
}, | |
"source": [ | |
"dls = to.dataloaders(bs=128)" | |
], | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Ucg9ao3dEpbF" | |
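}, | |
"source": [ | |
"First, we check what happens when we export before fitting. `print_dfs_lens` prints the row count of every `DataFrame` currently alive in memory, so any new entries after `load_learner` show what the loaded learner brought with it:" | |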
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bRYiL1gZ1bLZ", | |
"outputId": "429d9059-504e-4e7b-d870-eb80f0ca046b", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"print_dfs_lens()\n", | |
"learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)\n", | |
"learn.export('t1')\n", | |
"l1 = load_learner('t1')\n", | |
"print_dfs_lens()" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[520976, 0] [520976, 104195, 416781]\n", | |
"[520976, 0, 520976, 104195, 416781] [0, 0, 0]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "dexvA6_4q8i6" | |
}, | |
"source": [ | |
"Now we repeat the check, this time exporting after fitting:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "n3VQpLnjjV2D", | |
"outputId": "6b79d14c-0972-4c44-c9e4-ec24bd305b31", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 114 | |
} | |
}, | |
"source": [ | |
"print_dfs_lens()\n", | |
"learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)\n", | |
"learn.fit(1, cbs=ShortEpochCallback())\n", | |
"learn.export('t2')\n", | |
"l2 = load_learner('t2')\n", | |
"print_dfs_lens()" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[520976, 0, 520976, 104195, 416781] [0, 0, 0]\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>00:00</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[520976, 0, 520976, 104195, 416781, 0, 0, 0, 0] [0, 104195, 0]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "bSpqmQc5rF21" | |
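}, | |
"source": [ | |
"We see that the learner exported after fitting, once loaded, still holds the validation dataset somewhere: a new `DataFrame` with 104195 rows (the size of the validation split) appears in the output above." | |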
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XtvccFvdpwk1" | |
}, | |
"source": [ | |
"## Fix" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Fz0-WjWVozli" | |
}, | |
"source": [ | |
"The problem is that `ProgressCallback` keeps a reference to the last DataLoader it iterated over (the validation one) in its `pbar` attribute, as you can see [here](https://github.com/fastai/fastai/blob/ef5c3609b8da03fb25953a7f4f49aaa1e69197d2/fastai/callback/progress.py#L34).\n", | |
"If we patch `ProgressCallback` so that `after_fit` deletes the `pbar` attribute as its last step, the exported learner no longer carries the validation data." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0Klq5ig7oqWh" | |
}, | |
"source": [ | |
"@patch\n", | |
"def after_fit(self:ProgressCallback):\n", | |
" if getattr(self, 'mbar', False):\n", | |
" self.mbar.on_iter_end()\n", | |
" delattr(self, 'mbar')\n", | |
" if hasattr(self, 'old_logger'): self.learn.logger = self.old_logger\n", | |
" ###### this is the new line #############\n", | |
" if getattr(self, 'pbar', False): delattr(self, 'pbar')" | |
], | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "b5RQn0X7p78J", | |
"outputId": "6c986bce-ee4b-4100-f56d-593a9388581b", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 114 | |
} | |
}, | |
"source": [ | |
"print_dfs_lens()\n", | |
"learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)\n", | |
"learn.fit(1, cbs=ShortEpochCallback())\n", | |
"learn.export('t3')\n", | |
"l3 = load_learner('t3')\n", | |
"print_dfs_lens()" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[520976, 0, 520976, 104195, 416781, 0, 0, 0, 0] [0, 104195, 0]\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>00:00</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[520976, 0, 520976, 104195, 416781, 0, 0, 0, 0, 0, 104195, 0] [0, 0, 0]\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment