Created
November 12, 2020 15:30
-
-
Save muellerzr/df3fc4a12b021be85639afddab3c5d32 to your computer and use it in GitHub Desktop.
tabular-export-learner-testing.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"name": "tabular-export-learner-testing.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"jupytext": { | |
"split_at_heading": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/muellerzr/df3fc4a12b021be85639afddab3c5d32/tabular-export-learner-testing.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
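{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "XrJzkfJ-1bKp" | |
}, | |
"source": [ | |
"# Tabular models: does `Learner.export` keep a copy of the data?\n", | |
"\n", | |
"This notebook checks whether exporting a fastai tabular `Learner` also serializes its `DataFrame`s. It exports a learner before (`t1`) and after (`t2`) training, reloads each in a fresh runtime, and counts the `DataFrame` objects left in memory with `pympler`." | |
] | |
}, | |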
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Lu4U4bzjv0oH", | |
"outputId": "4ad0e0bf-7db1-4f47-b8fc-90bcec46dea7", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"!pip install fastai --upgrade -qqq\n", | |
"!pip install pympler" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[K |████████████████████████████████| 194kB 5.7MB/s \n", | |
"\u001b[K |████████████████████████████████| 51kB 5.5MB/s \n", | |
"\u001b[?25h" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_kw8-Dju_6J5", | |
"outputId": "52b5073f-77be-4778-aed2-6eb1cfeb04fd", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 35 | |
} | |
}, | |
"source": [ | |
"import fastai\n", | |
"fastai.__version__" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
}, | |
"text/plain": [ | |
"'2.1.5'" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 1 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "K-NLPMGM1bKr" | |
}, | |
"source": [ | |
"from fastai.tabular.all import *" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "dzHkM_t41bKx" | |
}, | |
"source": [ | |
"Tabular data should be in a Pandas `DataFrame`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xLLnzxw_1bKy" | |
}, | |
"source": [ | |
"path = untar_data(URLs.ADULT_SAMPLE)\n", | |
"df = pd.read_csv(path/'adult.csv')\n", | |
"# Enlarge the data 16x (520,976 rows in total) by stacking copies of it.\n", | |
"# `DataFrame.append` was removed in pandas 2.0; `pd.concat` gives the same result (original index labels repeat).\n", | |
"df = pd.concat([df] * 16)" | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TBoo6c761bK2" | |
}, | |
"source": [ | |
"dep_var = 'salary'\n", | |
"cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n", | |
"cont_names = ['age', 'fnlwgt', 'education-num']\n", | |
"procs = [Categorify, FillMissing, Normalize]" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ZNlr9V4H1bK_" | |
}, | |
"source": [ | |
"splits = RandomSplitter()(range_of(df))" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2ERqyHkxE7sq", | |
"outputId": "77c557b4-b2d5-4d5b-a0b2-1e12f71b60cc", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"splits" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"((#416781) [307989,75517,414193,363282,349727,241733,40196,389451,329687,91195...],\n", | |
" (#104195) [406723,268985,333518,493941,351485,213666,430108,31256,68008,294457...])" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "cTd_Ckt61bLH" | |
}, | |
"source": [ | |
"to = TabularPandas(df, procs, cat_names, cont_names, y_names=dep_var, splits=splits)" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7UvARddj1bLN" | |
}, | |
"source": [ | |
"dls = to.dataloaders(bs=128)" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Ucg9ao3dEpbF" | |
}, | |
"source": [ | |
"Before exporting anything, count the `DataFrame`s currently in memory:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KiNZuAIrC4m6", | |
"outputId": "85446930-4508-4935-d43c-3e23491b3081", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"from pympler import muppy\n", | |
"all_objects = muppy.get_objects()\n", | |
"my_types = muppy.filter(all_objects, Type=pd.DataFrame)\n", | |
"len(my_types)" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/torch/distributed/distributed_c10d.py:126: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead\n", | |
" warnings.warn(\"torch.distributed.reduce_op is deprecated, please use \"\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"5" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "N8vdsac4DPdM", | |
"outputId": "647eb163-596f-4531-ce89-d7970cae91c4", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"for t in my_types:\n", | |
" print(len(t))" | |
], | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"520976\n", | |
"0\n", | |
"104195\n", | |
"520976\n", | |
"416781\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bRYiL1gZ1bLZ" | |
}, | |
"source": [ | |
"learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)" | |
], | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "O_x8daKB1DqB" | |
}, | |
"source": [ | |
"learn.export('t1')" | |
], | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "i7En8TFR1FPN" | |
}, | |
"source": [ | |
"l1 = load_learner('t1')" | |
], | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "06Uyq54OwdOj", | |
"outputId": "1fd7e30f-0748-4c5e-9520-227d8de369be", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 77 | |
} | |
}, | |
"source": [ | |
"learn.fit(1)" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>0.331344</td>\n", | |
" <td>0.325066</td>\n", | |
" <td>0.847334</td>\n", | |
" <td>01:00</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DUPAf9D1z-Pu" | |
}, | |
"source": [ | |
"learn.export('t2')" | |
], | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "64IiOmQuEuxV" | |
}, | |
"source": [ | |
"Next, restart the runtime (so no `DataFrame`s from the cells above are still in memory), then load `t2`, the learner exported after training:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "UCEOGECvz_jA", | |
"outputId": "8636eb57-3d9e-43fe-f14a-8af748449b42", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"from fastai.tabular.all import *\n", | |
"from pympler import muppy\n", | |
"l2 = load_learner('t2')\n", | |
"\n", | |
"all_objects = muppy.get_objects()\n", | |
"my_types = muppy.filter(all_objects, Type=pd.DataFrame)\n", | |
"for t in my_types:\n", | |
" print(len(t))" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/torch/distributed/distributed_c10d.py:126: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead\n", | |
" warnings.warn(\"torch.distributed.reduce_op is deprecated, please use \"\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0\n", | |
"0\n", | |
"104195\n", | |
"0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "zdX_t_XqExMD" | |
}, | |
"source": [ | |
"Restart the runtime again and load `t1`, the learner exported before training:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kBwPMhK8DeTl", | |
"outputId": "d77b575f-c9e3-4798-fe3c-8492bb106a0c", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"source": [ | |
"# Run in a freshly restarted runtime so DataFrames left over from earlier cells do not pollute the count.\n", | |
"from fastai.tabular.all import *\n", | |
"from pympler import muppy\n", | |
"l1 = load_learner('t1')\n", | |
"\n", | |
"all_objects = muppy.get_objects()\n", | |
"my_types = muppy.filter(all_objects, Type=pd.DataFrame)\n", | |
"for t in my_types:\n", | |
" print(len(t))" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/torch/distributed/distributed_c10d.py:126: UserWarning: torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead\n", | |
" warnings.warn(\"torch.distributed.reduce_op is deprecated, please use \"\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"0\n", | |
"0\n", | |
"0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
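{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "U5qaQvxXE1ax" | |
}, | |
"source": [ | |
"Loading `t2` (exported after `learn.fit`) leaves a 104,195-row `DataFrame` in memory, the same length as the validation split, while loading `t1` (exported before training) leaves none. So the learner exported after training appears to carry a copy of the validation set somewhere." | |
] | |
}, | |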
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ptg9TfCBE4U7" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment