Skip to content

Instantly share code, notes, and snippets.

@jchaykow
Created February 2, 2019 16:56
Show Gist options
  • Save jchaykow/05b066935120ae71aa242c654b4a8367 to your computer and use it in GitHub Desktop.
Save jchaykow/05b066935120ae71aa242c654b4a8367 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test `ClassificationInterpretation` on a tabular learner"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pytest"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai.train import ClassificationInterpretation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from fastai.tabular import *\n",
"\n",
"pytestmark = pytest.mark.integration  # only has an effect if this code is collected by pytest\n",
"path = untar_data(URLs.ADULT_SAMPLE)\n",
"\n",
"def learn():\n",
"    \"\"\"Train a small tabular classifier on the ADULT sample and return the Learner.\n",
"\n",
"    This is a plain function rather than a ``@pytest.fixture``: the notebook calls\n",
"    ``learn()`` directly, and pytest >= 4 raises an error when a fixture function is\n",
"    called directly instead of being requested by a test.\n",
"\n",
"    Rows 800-999 of ``adult.csv`` are used both as the validation split and as the\n",
"    test set passed to ``add_test``.\n",
"    \"\"\"\n",
"    df = pd.read_csv(path/'adult.csv')\n",
"    procs = [FillMissing, Categorify, Normalize]\n",
"    dep_var = 'salary'\n",
"    cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']\n",
"    cont_names = ['age', 'fnlwgt', 'education-num']\n",
"    valid_idx = list(range(800, 1000))  # single definition of the held-out rows (validation + test)\n",
"    test = TabularList.from_df(df.iloc[valid_idx].copy(), path=path, cat_names=cat_names, cont_names=cont_names)\n",
"    data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n",
"            .split_by_idx(valid_idx)\n",
"            .label_from_df(cols=dep_var)\n",
"            .add_test(test)\n",
"            .databunch(num_workers=1))\n",
"    learner = tabular_learner(data, layers=[200,100], emb_szs={'native-country': 10}, metrics=accuracy)\n",
"    learner.fit_one_cycle(2, 1e-2)\n",
"    return learner"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"interp = ClassificationInterpretation.from_learner(learn())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[150, 5],\n",
" [ 30, 15]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"interp.confusion_matrix()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"interp.plot_confusion_matrix()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('>=50k', '<50k', 30), ('<50k', '>=50k', 5)]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"interp.most_confused()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"losses, idxs = interp.top_losses()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([4.1996, 3.3137, 3.0365, 2.7762, 2.5842, 2.1620, 1.7882, 1.7127, 1.4891,\n",
" 1.3865]), tensor([166, 24, 26, 195, 45, 42, 139, 146, 125, 9]))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"losses[:10], idxs[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test `ClassificationInterpretation` on a vision learner\n",
"- additionally exercises `plot_top_losses`"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision import *"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def learn():\n",
"    \"\"\"Train a small CNN on the MNIST_TINY sample and return the Learner.\n",
"\n",
"    This is a plain function rather than a ``@pytest.fixture``: the notebook calls\n",
"    ``learn()`` directly, and pytest >= 4 raises an error when a fixture function is\n",
"    called directly. Note that this redefines the tabular ``learn`` from the\n",
"    previous section.\n",
"    \"\"\"\n",
"    path = untar_data(URLs.MNIST_TINY)\n",
"    data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), num_workers=2)\n",
"    data.normalize()  # used for its effect on `data`; the return value is discarded\n",
"    learner = Learner(data, simple_cnn((3,16,16,16,2), bn=True), metrics=[accuracy, error_rate])\n",
"    learner.fit_one_cycle(3)\n",
"    return learner"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:04 <p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"interp = ClassificationInterpretation.from_learner(learn(), tta = True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[329, 17],\n",
" [ 2, 351]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"interp.confusion_matrix()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEmCAYAAAC9C19sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFEFJREFUeJzt3Xu4VXWd+PH353C4KQgoqKijlgqITnJTGqcaNXW0vOZTSuadHJvBMmv6qfmUpk2O1u/RmS6jTNPPS5mZQxcv43VIJTFBEU3BSwGOhAmTyEW5fn5/7HXsyMOBw+V79ubwfj3Pfjx7r3XW+mwPvN1r7bWPkZlIUklN9R5AUudnaCQVZ2gkFWdoJBVnaCQVZ2gkFWdotlIR0TMifhkRCyPi9k3YzqkRcd/mnK1eIuKDETGz3nN0RuF1NI0tIj4JXAgMARYB04CvZ+ajm7jd04DzgYMzc+UmD9rgIiKBfTLzpXrPsjXyFU0Di4gLgWuBfwJ2AnYHvgscvxk2vwfwwtYQmfaIiOZ6z9CpZaa3BrwBfYDFwMfXsU53aiGaW92uBbpXyw4B/gf4AvBH4A/AWdWyy4HlwIpqH+cAlwG3tNr2nkACzdX9M4HfUXtV9Xvg1FaPP9rq+w4GngAWVv88uNWyicAVwKRqO/cB/dt4bi3zf6nV/CcAHwFeAP4XuKTV+gcBjwFvVOt+G+hWLXu4ei5Lqud7cqvt/x9gHnBzy2PV9+xV7WNEdX8XYD5wSL3/bGyJt7oP4K2NHwwcBaxs+YvexjpfAyYDOwIDgF8DV1TLDqm+/2tA1+ov6FKgX7V8zbC0GRpgW+BNYHC1bCCwX/X1O6EBtgf+BJxWfd+Y6v4O1fKJwMvAIKBndf+qNp5by/xfqeb/NPA68COgN7Af8Dbw3mr9kcD7q/3uCTwPXNBqewnsvZbt/zO1YPdsHZpqnU9X29kGuBf4Zr3/XGypNw+dGtcOwPxc96HNqcDXMvOPmfk6tVcqp7VavqJaviIz76b2X/PBGznPamD/iOiZmX/IzN+uZZ2PAi9m5s2ZuTIzbwVmAMe2WucHmflCZr4F/AQYto59rqB2PmoF8GOgP3BdZi6q9v9b4H0AmTk1MydX+50FXA/8TTue01czc1k1z7tk5njgReBxanH98nq2pzYYmsa1AOi/nnMHuwCzW92fXT32zjbWCNVSoNeGDpKZS6gdbpwH/CEi7oqIIe2Yp2WmXVvdn7cB8yzIzFXV1y0heK3V8rdavj8iBkXEnRExLyLepHZeq/86tg3wema+vZ51xgP7A/+amcvWs67aYGga12PUDg1OWMc6c6md1G2xe/XYxlhC7RChxc6tF2bmvZl5BLX/ss+g9hdwffO0zPTqRs60Ib5Hba59MnM74BIg1vM963zLNSJ6UTvv9X3gsojYfnMMujUyNA0qMxdSOz/xnYg4ISK2iYiuEXF0RFxdrXYrcGlEDIiI/tX6t2zkLqcBH4qI3SOiD3Bxy4KI2CkijouIbYFl1A7BVq1lG3cDgyLikxHRHBEnA0OBOzdypg3Rm9p5pMXVq63PrLH8NeC9G7jN64CpmTkWuAv4t02ecitlaBpYZv5fatfQXErtROgrwDjgZ9UqVwJTgOnAM8CT1WMbs6/7gduqbU3l3XFoovbu1Vxq78T8DfD3a9nGAuCYat0F1N4xOiYz52/MTBvoi8Anqb2bNZ7ac2ntMuDGiHgjIj6xvo1FxPHUTsifVz10ITAiIk7dbBNvRbxgT1JxvqKRVJyhkVScoZFUnKGRVFxDfZAsum6b0aNvvcdQAQfsM7DeI6iAOXNmsWD+/PVdr9RgoenRl+6j1rz8QZ3BQ3ddvP6VtMU57AOj27Weh06SijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0ko
ozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koozNJKKMzSSijM0koprrvcAnVH3rl144LrT6datmeYuTUz41fNc+f8e5gdfPoERgwayYtUqpsyYy7hv3c3KVavp26sH13/pGN6zSz+WLV/J3119J8/Ner3eT0PrMe68sdx3z130H7Ajv57yNABnnz6Gl154AYCFC9+gT5++PDx5aj3HbAi+oilg2YpVHHXhLYweO57RY8dz5EF7cdC+u/LjB57hgDO+x6izb6Bnt66c9dFhAHzp1L/m6Zde46Cx4znnG7/gm+cfWednoPb45KdO5/af3fWux/7jplt5ePJUHp48lWOPP5Fjjj+hTtM1FkNTyJK3VwDQtbmJ5i5NJMm9j7/8zvIpM15l1wHbATBkz/5MfHIWAC+8soA9durLjv227fCZtWEO/sCH6Lf99mtdlpn87D9/ykkfP6WDp2pMhqaQpqZg8vixzJlwIQ9N/T1PPD/3nWXNXZoYc8Rfcv9vauF55uU/cvyHBgMwasgu7L5zH3Yd0Lsuc2vzeGzSI+y4407stfc+9R6lIRQLTUT0iIjfRMTTEfHbiLi81L4a0erVyfs//e/s/fHrGDVkF4buOeCdZdddcDSTps9h0jOvAPDNH02ib6+eTB4/ls+ceCBPvziPlatW12t0bQZ33H4bH/v4yfUeo2GUPBm8DDgsMxdHRFfg0Yi4JzMnF9xnw1m4ZBkPT5vNkQftxXOzXueS0z/IgL7bcPJX/nxsv2jpcv7u6l++c3/GreOY9Yc36jGuNoOVK1dy588n8NCk39R7lIZR7BVN1iyu7natbllqf42kf59t6LNtdwB6dGvmsJHvYeac+Zz5kWEcceB7Of2KCWSrfxN9tu1O1+baj+Ksjw7n0elzWLR0eT1G12Yw8aEH2GfwYHbddbd6j9Iwir69HRFdgKnA3sB3MvPxkvtrFDvv0IvxFx1Hl6agqSm4Y+Lz3DP5JRY9cAlz5i1k4nfOBODnj8zkGzc9wpA9+vPvFx/PqtWrmTFrPuddc2d9n4DaZewZpzLpkV+xYMF89ttnDy669KucdsbZTPjpTzwJvIbILP8iIyL6AhOA8zPz2TWWnQucC0D3PiN7/NUXi8+jjjf3rovrPYIKOOwDo3nqySmxvvU65F2nzHwDmAgctZZlN2TmqMwcFV19S1fqjEq+6zSgeiVDRPQEDgdmlNqfpMZV8hzNQODG6jxNE/CTzPTkg7QVKhaazJwODC+1fUlbDq8MllScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVJyhkVScoZFUnKGRVFxzWwsi4pdAtrU8M48rMpGkTqfN0ADf7LApJHVqbYYmM3/VkYNI6rzW9YoGgIjYB/gGMBTo0fJ4Zr634FySOpH2nAz+AfA9YCVwKHATcHPJoSR1Lu0JTc/MfBCIzJydmZcBh5UdS1Jnst5DJ+DtiGgCXoyIccCrwI5lx5LUmbTnFc0FwDbAZ4GRwGnAGSWHktS5rPcVTWY+UX25GDir7DiSOqP2vOv036zlwr3M9DyNpHZpzzmaL7b6ugdwErV3oCSpXdpz6DR1jYcmRUSRi/mGDxrIpPsvLbFp1Vm/A8fVewQVsGzmnHat155Dp+1b3W2idkJ4540bS9LWqD2HTlOpnaMJaodMvwfOKTmUpM6lPaHZNzPfbv1ARHQvNI+kTqg919H8ei2PPba5B5HUea3r99HsDO
wK9IyI4dQOnQC2o3YBnyS1y7oOnf4WOBPYDfgWfw7Nm8AlZceS1Jms6/fR3AjcGBEnZeYdHTiTpE6mPedoRkZE35Y7EdEvIq4sOJOkTqY9oTk6M99ouZOZfwI+Um4kSZ1Ne0LTpfXb2RHRE/DtbUnt1p7raG4BHoyIH1T3zwJuLDeSpM6mPZ91ujoipgOHU3vn6b+APUoPJqnzaO//QG4esJraJ7c/DDxfbCJJnc66LtgbBJwCjAEWALdR+73Bh3bQbJI6iXUdOs0AHgGOzcyXACLi8x0ylaROZV2HTidRO2T674gYHxEf5s9XB0tSu7UZmsyckJknA0OAicDngZ0i4nsRcWQHzSepE1jvyeDMXJKZP8zMY6h97mkacFHxySR1Gu191wmAzPzfzLzeX0wuaUNsUGgkaWMYGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxTXXe4CtySuvvMLYs07ntdfm0dTUxNnnnMu4z36u3mNpA3Tv1swD37+Abt2aae7ShQkPPMWV/3Y3N1z+KT44cm8WLn4bgHO/cjPTX3iVQXvuxA2Xf4phQ3bjsm/fybU3P1jnZ1AfhqYDNTc3c9XV32L4iBEsWrSIg0eP5MOHH8G+Q4fWezS107LlKznq3H9hyVvLaW5u4qH/uJD7Jj0HwCXX/owJD0x71/p/WriEL/zz7Rx76AH1GLdheOjUgQYOHMjwESMA6N27N0OG7Mvcua/WeSptqCVvLQega3MXmpu7kJltrvv6nxYz9bk5rFi5qqPGa0iGpk5mz5rFtGlPceBBo+s9ijZQU1Mw+ccXMefBq3ho8gyeeHY2AJf9w7H85raLufoLH6NbVw8WWisWmogYHBHTWt3ejIgLSu1vS7J48WLGfOIkrvnWtWy33Xb1HkcbaPXq5P2nXMXef3spo/bfg6F7DeQr//oLDjjxCj7wqWvo12dbvnDW4fUes6EUC01mzszMYZk5DBgJLAUmlNrflmLFihWM+cRJnDzmVE448WP1HkebYOHit3h4yoscefBQ5s1/E4DlK1Zy088nM2q/Pes7XIPpqEOnDwMvZ+bsDtpfQ8pMzvv0OQwesi+f+/yF9R5HG6F/v1706dUTgB7du3LY6MHMnPUaO/f/8yvT4w59H8+9PLdeIzakjjqQPAW4dW0LIuJc4FyAv9h99w4apz5+PWkSP/rhzey//18yeuQwAC6/8p846uiP1HkytdfO/bdj/NdOo0tTE01NwR33P8k9jzzLPdefT/9+vYmA6TP/h/O//mMAdtqhN5N++CV6b9uD1ZmMO/UQhp/0dRYtebvOz6RjxbrOmG+WHUR0A+YC+2Xma+tad+TIUTnp8SlF51F99DtwXL1HUAHLZv6E1Uv/GOtbryMOnY4GnlxfZCR1Xh0RmjG0cdgkaetQNDQRsQ1wBPCfJfcjqbEVPRmcmUuBHUruQ1Lj88pgScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFGRpJxRkaScUZGknFRWbWe4Z3RMTrwOx6z9FB+gPz6z2ENrut7ee6R2YOWN9KDRWarUlETMnMUfWeQ5uXP9e189BJUnGGRlJxhqZ+bqj3ACrCn+taeI5GUnG+opFUnKGRVJyhkVRcc70H2FpExEFAZu
YTETEUOAqYkZl313k0qThPBneAiPgqcDS1sN8PjAYmAocD92bm1+s3nTZWRHwWmJCZr9R7lkZnaDpARDwDDAO6A/OA3TLzzYjoCTyeme+r64DaKBGxEFgCvAzcCtyema/Xd6rG5DmajrEyM1dl5lLg5cx8EyAz3wJW13c0bYLfAbsBVwAjgeci4r8i4oyI6F3f0RqLoekYyyNim+rrkS0PRkQfDM2WLDNzdWbel5nnALsA36V2/u139R2tsXjo1AEiontmLlvL4/2BgZn5TB3G0iaKiKcyc3gby3pWr1iFoZE2WkQMyswX6j3HlsDQSCrOczSSijM0koozNAIgIlZFxLSIeDYibm/1LtnGbOuQiLiz+vq4iLhoHev2jYi/34h9XBYRX9zYGdWxDI1avJWZwzJzf2A5cF7rhVGzwX9eMvMXmXnVOlbpC2xwaLRlMTRam0eAvSNiz4h4PiK+CzwJ/EVEHBkRj0XEk9Urn14AEXFURMyIiEeBj7VsKCLOjIhvV1/vFBETIuLp6nYwcBWwV/Vq6ppqvX+MiCciYnpEXN5qW1+OiJkR8QAwuMP+bWiTGRq9S0Q0U/tcVsu1PYOBm6rrRZYAlwKHZ+YIYApwYUT0AMYDxwIfBHZuY/P/AvwqMw8ARgC/BS6idrX0sMz8x4g4EtgHOIjaxzZGRsSHImIkcAownFrIDtzMT10F+elttegZEdOqrx8Bvk/tStfZmTm5evz9wFBgUkQAdAMeA4YAv8/MFwEi4hbg3LXs4zDgdIDMXAUsjIh+a6xzZHV7qrrfi1p4elP7AOPSah+/2KRnqw5laNTircwc1vqBKiZLWj8E3J+ZY9ZYbxiwuS7ICuAbmXn9Gvu4YDPuQx3MQydtiMnAX0fE3gARsU1EDAJmAO+JiL2q9ca08f0PAp+pvrdLRGwHLKL2aqXFvcDZrc797BoROwIPAydGRM/qA4vHbubnpoIMjdqt+hUIZwK3RsR0auEZkplvUztUuqs6GdzW/230c8Ch1a/NmArsl5kLqB2KPRsR12TmfcCPgMeq9X4K9M7MJ4HbgGnAHdQO77SF8CMIkorzFY2k4gyNpOIMjaTiDI2k4gyNpOIMjaTiDI2k4v4/t0sEs6YYCI8AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"interp.plot_confusion_matrix()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('3', '7', 17), ('7', '3', 2)]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"interp.most_confused()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"losses, idxs = interp.top_losses()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([0.8811, 0.7913, 0.7798, 0.7450, 0.7419, 0.7401, 0.7392, 0.7298, 0.7278,\n",
" 0.7274]), tensor([335, 13, 341, 11, 292, 17, 1, 575, 94, 206]))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"losses[:10], idxs[:10]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x864 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"interp.plot_top_losses(4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment