Created
May 30, 2019 20:33
Example of visualizing and extracting classifier cutoff values to meet certain precision-recall objectives.
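To make the objective concrete, here is a minimal sketch of what precision and recall mean at a given cutoff. It is not part of the original notebook: the helper name `precision_recall_at` and the toy labels and scores are made up for illustration, and it uses plain Python rather than scikit-learn.

```python
def precision_recall_at(cutoff, y_true, scores):
    """Return (precision, recall) when predicting 1 wherever score >= cutoff."""
    predicted = [int(s >= cutoff) for s in scores]
    true_pos = sum(1 for p, t in zip(predicted, y_true) if p == 1 and t == 1)
    n_predicted_pos = sum(predicted)
    n_actual_pos = sum(y_true)
    # Assumption of this sketch: precision counts as 1.0 when nothing
    # is predicted positive, and recall as 0.0 when there are no positives.
    precision = true_pos / n_predicted_pos if n_predicted_pos else 1.0
    recall = true_pos / n_actual_pos if n_actual_pos else 0.0
    return precision, recall


# Toy data, invented for illustration only.
y_true = [0, 0, 1, 0, 1, 1]
scores = [0.1, 0.3, 0.4, 0.6, 0.7, 0.9]

low = precision_recall_at(0.35, y_true, scores)   # inclusive cutoff
high = precision_recall_at(0.65, y_true, scores)  # strict cutoff
print("cutoff 0.35:", low)
print("cutoff 0.65:", high)
```

On this toy data, raising the cutoff trades recall for precision, which is the trade-off the notebook tunes.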
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Example: Finding a Cutoff for 95% Precision\n", | |
"\n", | |
"The purpose of this notebook is to demonstrate programmatically\n", | |
"choosing a cutoff value for a binary classifier such that:\n", | |
"(a) we achieve 95% or higher precision if possible, and (b)\n", | |
"within that constraint, we maximize the recall." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Imports" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"\n", | |
"from sklearn.datasets import load_iris\n", | |
"from sklearn.linear_model import LogisticRegression\n", | |
"from sklearn.metrics import average_precision_score\n", | |
"from sklearn.metrics import confusion_matrix\n", | |
"from sklearn.metrics import precision_recall_curve\n", | |
"from sklearn.model_selection import train_test_split" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Set the Random Seed\n", | |
"\n", | |
"This is so that anyone can re-run this notebook\n", | |
"and arrive at the same results." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(2)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load in Sample Data\n", | |
"\n", | |
"We'll work with the classic iris data set.\n", | |
"The problem here is to use some measurements of\n", | |
"different parts of a flower to identify which\n", | |
"species of iris it is.\n", | |
"\n", | |
"I'll filter the data down to focus on only two\n", | |
"classes, so that we can look at this in the context\n", | |
"of binary classifiers." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((150, 4), (150,), array([0, 1, 2]))" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"iris_data = load_iris()\n", | |
"X = iris_data['data']\n", | |
"y = iris_data['target']\n", | |
"\n", | |
"# See some basic info about the raw dataset.\n", | |
"X.shape, y.shape, np.unique(y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# We have 150 data points with 4 features and 3 classes (0, 1, and 2)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((100, 4), (100,))" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Filter down to two classes.\n", | |
"X = X[y > 0]\n", | |
"y = y[y > 0] - 1 # Map [1, 2] to [0, 1].\n", | |
"X.shape, y.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([0, 1])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Sanity check that we have the 2 standard class labels 0 and 1.\n", | |
"np.unique(y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Split into Train / Test Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.95)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Train a Basic Classifier\n", | |
"\n", | |
"I'll use logistic regression to keep this similar to our own case." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = LogisticRegression(solver='lbfgs').fit(X_train, y_train)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Check the Accuracy, Precision, and Recall\n", | |
"\n", | |
"I'm not yet choosing a custom cutoff; the default cutoff of 0.5 shows how well the classifier is doing out of the box." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"y_predict = np.array([int(proba[1] > 0.5) for proba in model.predict_proba(X_test)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[46, 0],\n", | |
" [44, 5]])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Each row here is a true class, and each column is a predicted class.\n", | |
"# The best possible outcome is a matrix with zeros everywhere except\n", | |
"# along the diagonal, which corresponds to perfect classification.\n", | |
"confusion_matrix(y_test, y_predict)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.5368421052631579" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Check the accuracy.\n", | |
"sum(y_predict == y_test) / len(y_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEWCAYAAAB42tAoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHP5JREFUeJzt3XucHGWd7/HPN5lcgIQgBKKGkAByMQKiZLkcV2FXROBI4ssrUVRclnhjV4+34571aMR1XXXVZQ+wkhUWRS4CL9cdNYg3BFHRhOUiCcYN1wRwSSAEciHX3/njeYZpOtNP9wxT0z2T7/v16tdUV1VX//qZ7vpWPdVVrYjAzMyskVHtLsDMzDqbg8LMzIocFGZmVuSgMDOzIgeFmZkVOSjMzKzIQTGMSTpT0s3trmOwSVoi6YQm8+wnaZ2k0UNUVuUk3S/pxDw8X9K32l2TGTgohpykcZIulvSApKck3S7plHbX1Yq8ItuYV9D/LelSSRMG+3ki4iUR8fMm8zwYERMiYttgP39eSW/Jr/MJSb+SdNxgP8/OIr9Ptkp6Qd34QWlnSW/Ln6f1kr4rac/CvKdJuis/568kzaybfoCk7+fP5mpJX+xvPSORg2LodQErgOOBScAngaslzWhjTf1xWkRMAF4OzCLV/yxKhvt769v5dU4GbgCuaXM9g05S1xA8x27AG4G1wBl9zNLTznsDNwPfkaR+LP8lwEXAO4ApwAbgwgbzHgRcDrwX2AP4HtDd0w6SxgI/Bn4GPB/YF/BeHQ6KIRcR6yNifkTcHxHbI+L7wH3AUY0eI2mapO9IWiXpMUnnN5jvPEkrJD0p6VZJr6yZdrSkxXnaf0v6Sh4/XtK38nKfkLRI0pQWXsdDwHXAYXk5P5f0OUm/JH1YD5A0Ke89PSLpIUl/V9tVJOlsSXfnrbelkl6ex9d2wTSqe4akqPmQv1BSt6THJS2XdHbN88yXdLWkb+bnWiJpVrPXmF/nVtLKZaqkvWuW+bq8N9izJXxEzbQ+/1+SDpT0szxutaTLJe3RSh31JM3Jz/+kpHsknVzfdjWv/Vt1bXaWpAeBn0m6TtI5dcu+Q9Ib8vChkn6c23WZpLf0s9Q3Ak8A5wLvajRTRGwBvkFaQe/Vj+W/HfheRNwUEeuA/wu8QdLEPuZ9LfCLiLg5/1+/AEwlbbQBnAk8HBFfyZ/TpyPizn7UMmI5KNosr5QPBpY0mD4a+D7wADCD9Ma+qsHiFgFHAnsCVwDXSBqfp50HnBcRuwMHAlfn8e8i7dlMI31A3wtsbKHuacCpwG01o98BzAMm5novBbYCLwJeBpwE/GV+/JuB+cA7gd2B2cBjfTxVo7rrXQWsBF4IvAn4e0l/XjN9dp5nD6Ab6DNs+3idY3ONjwFr8riXAZcA7yG12UWkLdNxTf5fAj6fa3wxqc3nt1JHXU1HA98EPpZfz6uA+/uxiOPz878WuBKYW7PsmcB04Ad5b+DHpPfSPsDpwIV5np4un2Yr0nfl57gKOFRSnxtEksaRVtQrImK1pD/NIdzo9qf5oS8B7uhZTkTcA2wmfab6fKq6YZE3doBjgftzeK7OGz+HN3l9O4eI8K1NN2AM8BPgosI8xwGrgK4+pp0J3Fx47BrgpXn4JuAzwOS6ef4C+BVwRAv13g+sI20hPkDaxd8lT/s5cG7NvFOATT3T87i5wA15+Hrgg4XnObFJ3TOAIHXlTQO2ARNrpn8euDQPzwd+UjNtJrCx8Drnk1Y2T+TlPgacUDP9X4DP1j1mGWkF3PD/1cfzvB64rcHrng98q8HjLgK+2qzt6pdT02YH1EyfCKwHpuf7nwMuycNvJW2B1z/3p1t8f+8HbAeOrPmfn9egnR8ldfkc1c/P0E+B99aNe6j2/1Uz/tD8Wk8AxpL2PrYDf5On/wjYApySp38MuBcY25+aRuLNexRtotSHfxnpg3JOzfjrlA60rZP0dtJK8I
FIu8rNlvnR3JWzVtITpD2FyXnyWaStrN/n7qXX5fGXkT7AV0l6WNIXJY0pPM3rI2KPiJgeEe+PiNq9jxU1w9NJQfhIz1YgaSWzT54+Dbin2Wsq1F3rhcDjEfFUzbgHSFvzPf5YM7wBGC+pS9Lba9r7upp5ro6IPUiBdxfP7hqcDnykdgs3v54XUvh/SZoi6arcDfckqf97cv18LWi17Rp55v+U2+wHpL0FSGF+eR6eDhxT9zrfTuoeasU7gLsj4vZ8/3LgbXXvr6vz+2mfiPjziLi1n69lHWmPtNbuwFP1M0bE70l7OOcDj5DafilpTxTSnvTNEXFdRGwG/pG0x/jiftY04lR+MMt2JEnAxaSV0KmR+mcBiIhT6uY9DthPUlcpLJSOR3wceDWwJCK2S1pD3tWOiP8C5uaAegNwraS9ImI9aYv9M0oH1BeSto4vHsBLq70U8QrSHsXkBnWvIHUllRfYoO662R4G9pQ0sSYs9iNtWTZb/uX0rhj7mr5a0jxgsaQrIuKRXPvnIuJz9fM3+X/9PamNDo+IxyW9nha7wOqU2m49sGvN/b5W6vWXjL4S+LSkm4DxpIP3Pc9zY0S8ZgA1Quqy209ST0h3kVa8pwL/UXpgfj9fV5jllIj4BanL9qU1jzsAGAf8oa8HRcS1wLV53j1IGyKL8uQ7gVeUX9LOyXsU7fEvpK2U0+q2yPvyW9LWzz9I2k3p4HNfb+aJpOMBq4AuSZ+iZktL0hmS9o6I7aRdfYDtkv5M0uG5b/1J0q739uf06oC8Qv0R8GVJu0sapXQwt+fA4deBj0o6SsmLJE2vX06juuueawWp++zzuX2OIK0ABuUbKxGxjLTX9fE86l+B90o6Jte+m6T/mQ+glv5fE0lbwGslTSV1bQzExcC7Jb06t+tUSYfmabcDp0sao3TA/k0tLG8hae/hXNK3kHra9/vAwZLekZc3RtKfSGq6hZ0D80DgaNJxsyNJxwKuIAVIUUT8ItLXnxvdfpFnvRw4TdIr8zGVc4Hv1O1d1tZ1lKTRSl9MWAB05z0NSO+XYyWdmD8PHwJWA3c3q3ekc1AMsbwyfA/pg/PHum6mHUQ6T+A00gHhB0m7yW/tY9brgR+StqQeAJ7m2V1BJwNLJK0jHSA+PYfU80lbWE+SPhA3krqjBsM7SX29S0nHS64FXpBf1zWk/vArSN0E3yUdhK/XqO56c0l98A8D/07qR//JIL0OgC8B8yTtExGLgbNJewNrgOWk40XN/l+fIX2teC2pu+c7AykkIn4LvBv4al7WjaQVPaR+9wNzXZ8htW+z5W3KtZxYO39e2Z5E6pZ6mNR99wXSFju5267PL2GQunj+IyJ+FxF/7LmR/oevU+Fch/6IiCWkL2BcTjrOMRF4f8/03JX7f2oech5pg2MZqY3OrlnWMtJXeL+Wp80BZuduqJ2aIvzDRWZm1pj3KMzMrMhBYWZmRQ4KMzMrclCYmVnRsDuPYvLkyTFjxox2l2FmNqzceuutqyNi7+Zz7mjYBcWMGTNYvHhxu8swMxtWJD0w0Me668nMzIocFGZmVuSgMDOzIgeFmZkVOSjMzKzIQWFmZkWVBYWkSyQ9KumuBtMl6Z+Vft/4TuXfSzYzs85S5R7FpaRLRDdyCnBQvs0j/UaDmZl1mMpOuIuIm/IvpjUyB/hmpOuc3yJpD0kvyD9409D27bB27SAWamZDQoKJE9NfG17aeWb2VJ79wzor87gdgiL/FOU8gClT9uNXv/KbzWw4euUrYbfd2l2F9dewuIRHRCwg/WwhhxwyK/bcE7qGReVm1mPVqnZXYAPVzm89PQRMq7m/bx5nZmYdpJ1B0Q28M3/76VhgbbPjE2ZmNvQq68CRdCVwAjBZ0krg08AYgIj4GrAQOJX0w/QbSD8Wb2ZmHabKbz3NbTI9gA9U9fxmZjY4fGa2mZkVOSjMzKzIQWFmZkUOCjMzK3JQmJlZkYPCzMyKHBRmZlbkoDAzsyIHhZmZFTkozM
ysyEFhZmZFDgozMytyUJiZWZF/J87MOlYEbN0K27alv/XDu+wCe+7Z7ipHPgeFmQ2ZLVtgw4beFX3Piv/pp2HTph1v27alsOjL5s0wfbqDYig4KMxsSETALbf0PW306Gffurpg3Lg03Mi6ddXUaTtyUJjZkNhnn/Y9d08XVs9tt91glI/QtsxBYWbD1rZt8MQTvQGwaVPf3VhbtqT5pTTfscfCXnu1t/bhxEFhZsPSmDHwxz/Co4/2jhs1qrfrqqcba+LEZ3dhrVrV+LjHc9HXgffa29NPp+MqPeG1eTMccghMmTL4tQw2B4WZDUvjxlW7ku1rZd9or2Xz5rTXUhtAUu99KYVXT5CNHg3r16fHDgcOCjPb6axf37sHsHHjs1f+PX9rV/K1AdDXXsuECeUD741qGC4cFGa2Uxk7Fu6+OwUA7Lji32WXtOLvmW4OCjPbyUya1O4Khh9/QczMzIocFGZmVuSgMDOzIgeFmZkV+WC2mVmb9XxVd8uW3r9btqTzMzZsSF/X7fkab0Q6s3z8+KGrz0FhZtYm990H99yTAqGvs8Vrv7rb1ZXCYc2adPb3UHJQmJm1wfOelwKiqytdZqTV8zbacX6Hg8LMrA1GjRra7qPnwgezzcysqNKgkHSypGWSlkv6RB/T95N0g6TbJN0p6dQq6zEzs/6rLCgkjQYuAE4BZgJzJc2sm+2TwNUR8TLgdODCquoxM7OBqXKP4mhgeUTcGxGbgauAOXXzBLB7Hp4EPFxhPWZmNgBVHsyeCqyoub8SOKZunvnAjyT9FbAbcGJfC5I0D5gHMGXKfoNeqJnZcNJzbsWWLen8ip6bBEceOfjfjGr3t57mApdGxJclHQdcJumwiNheO1NELAAWABxyyKwKfpvKzGz4uPXW3uFRo9Kv/XV1wbp16XyM4RQUDwHTau7vm8fVOgs4GSAifi1pPDAZeBQzM9vB5MmNp1X1Y0hVHqNYBBwkaX9JY0kHq7vr5nkQeDWApBcD44FVFdZkZmb9VFlQRMRW4BzgeuBu0reblkg6V9LsPNtHgLMl3QFcCZwZUcXPnpuZ2UBVeowiIhYCC+vGfapmeCnwiiprMDOz58ZnZpuZWZGDwszMihwUZmZW5KAwM7MiB4WZmRU5KMzMrMhBYWZmRQ4KMzMrclCYmVmRg8LMzIocFGZmVuSgMDOzIgeFmZkVOSjMzKzIQWFmZkUOCjMzK3JQmJlZkYPCzMyKHBRmZlbkoDAzsyIHhZmZFTkozMysyEFhZmZFDgozMytyUJiZWZGDwszMihwUZmZW5KAwM7MiB4WZmRU5KMzMrMhBYWZmRV2tzihpKjC99jERcVMVRZmZWedoKSgkfQF4K7AU2JZHB1AMCkknA+cBo4GvR8Q/9DHPW4D5eXl3RMTbWi3ezMyq1+oexeuBQyJiU6sLljQauAB4DbASWCSpOyKW1sxzEPA3wCsiYo2kfVov3czMhkKrxyjuBcb0c9lHA8sj4t6I2AxcBcypm+ds4IKIWAMQEY/28znMzKyBbdtgwwZ4/HGArtEDXU6rexQbgNsl/RR4Zq8iIv668JipwIqa+yuBY+rmORhA0i9J3VPzI+KHLdZkZmZ17rkHnnoq3TZuTOO2bwfYdfxAl9lqUHTn22DrAg4CTgD2BW6SdHhEPFE7k6R5wDyAKVP2q6AMM7Phb/RoeOABGDsWxo2DCRPS+C1bANBAl9tSUETENySNJe8BAMsiYkuThz0ETKu5v28eV2sl8Ju8rPsk/YEUHIvqnn8BsADgkENmRSs1m5ntbPbcs5rltnSMQtIJwH+RDk5fCPxB0quaPGwRcJCk/XPInM6OeyXfJe1NIGkyKYjubbV4MzOrXqtdT18GToqIZQCSDgauBI5q9ICI2CrpHOB60vGHSyJiiaRzgcUR0Z2nnSSp52u3H4uIxwb+cszMbLC1GhRjekICICL+IKnpt6AiYiGwsG7cp2qGA/hwvpmZWQdqNSgWS/o68K18/+3A4m
pKMjOzTtJqULwP+ADQ83XYX5COVZiZ2QjX6reeNgFfyTczM9uJFINC0tUR8RZJvyNdi+lZIuKIyiozM7OO0GyP4oP57+uqLsTMzDpT8TyKiHgkD64GVkTEA8A44KXAwxXXZmZmHaDViwLeBIzPv0nxI+AdwKVVFWVmZp2j1aBQRGwA3gBcGBFvBl5SXVlmZtYpWg4KSceRzp/4QR434EvWmpnZ8NFqUHyI9AND/54vw3EAcEN1ZZmZWado9TyKG4Eba+7fS+/Jd2ZmNoI1O4/inyLiQ5K+R9/nUcyurDIzM+sIzfYoLst//7HqQszMrDMVgyIibs2Di4GNEbEdQNJo0vkUZmY2wrV6MPunwK4193cBfjL45ZiZWadpNSjGR8S6njt5eNfC/GZmNkK0GhTrJb28546ko4CN1ZRkZmadpNXfo/gQcI2khwEBzwfeWllVZmbWMVo9j2KRpEOBQ/KoZRGxpbqyzMysU7TU9SRpV+B/Ax+MiLuAGZJ86XEzs51Aq8co/g3YDByX7z8E/F0lFZmZWUdpNSgOjIgvAlsA8pVkVVlVZmbWMVoNis2SdiFfxkPSgcCmyqoyM7OO0eq3nj4N/BCYJuly4BXAmVUVZWZmnaNpUEgS8HvSjxYdS+py+mBErK64NjMz6wBNgyIiQtLCiDic3h8tMjOznUSrxyj+U9KfVFqJmZl1pFaPURwDnCHpfmA9qfspIuKIqgozM7PO0GpQvLbSKszMrGM1+4W78cB7gRcBvwMujoitQ1GYmZl1hmbHKL4BzCKFxCnAlyuvyMzMOkqzrqeZ+dtOSLoY+G31JZmZWSdptkfxzBVi3eVkZrZzahYUL5X0ZL49BRzRMyzpyWYLl3SypGWSlkv6RGG+N0oKSbP6+wLMzKxaxa6niBg90AVLGg1cALwGWAksktQdEUvr5psIfBD4zUCfy8zMqtPqCXcDcTSwPCLujYjNwFXAnD7m+yzwBeDpCmsxM7MBqjIopgIrau6vzOOekX+He1pEFC8NImmepMWSFq9du2rwKzUzs4aqDIoiSaOArwAfaTZvRCyIiFkRMWvSpL2rL87MzJ5RZVA8BEyrub9vHtdjInAY8PN8aZBjgW4f0DYz6yxVBsUi4CBJ+0saC5wOdPdMjIi1ETE5ImZExAzgFmB2RCyusCYzM+unyoIin3dxDnA9cDdwdUQskXSupNlVPa+ZmQ2uVi8KOCARsRBYWDfuUw3mPaHKWszMbGDadjDbzMyGBweFmZkVOSjMzKzIQWFmZkUOCjMzK3JQmJlZkYPCzMyKHBRmZlbkoDAzsyIHhZmZFTkozMysyEFhZmZFDgozMytyUJiZWZGDwszMihwUZmZW5KAwM7MiB4WZmRU5KMzMrMhBYWZmRQ4KMzMrclCYmVmRg8LMzIocFGZmVuSgMDOzIgeFmZkVOSjMzKzIQWFmZkUOCjMzK3JQmJlZkYPCzMyKHBRmZlZUaVBIOlnSMknLJX2ij+kflrRU0p2SfippepX1mJlZ/1UWFJJGAxcApwAzgbmSZtbNdhswKyKOAK4FvlhVPWZmNjBV7lEcDSyPiHsjYjNwFTCndoaIuCEiNuS7twD7VliPmZkNQJVBMRVYUXN/ZR7XyFnAdX1NkDRP0mJJi9euXTWIJZqZWTMdcTBb0hnALOBLfU2PiAURMSsiZk2atPfQFmdmtpPrqnDZDwHTau7vm8c9i6QTgb8Fjo+ITRXWY2ZmA1DlHsUi4CBJ+0saC5wOdNfOIOllwEXA7Ih4tMJazMxsgCoLiojYCpwDXA/cDVwdEUsknStpdp7tS8AE4BpJt0vqbrA4MzNrkyq7noiIhcDCunGfqhk+scrnNzOz564jDmabmVnnclCYmVmRg8LMzIocFGZmVuSgMDOzIgeFmZkVOSjMzKzIQWFmZkUOCjMzK3JQmJlZkYPCzMyKHBRmZlbkoDAzsyIHhZmZFTkozMysyEFhZmZFDgozMytyUJiZWZGDwszMihwUZm
ZW5KAwM7MiB4WZmRU5KMzMrMhBYWZmRQ4KMzMrclCYmVmRg8LMzIocFGZmVuSgMDOzIgeFmZkVOSjMzKzIQWFmZkUOCjMzK6o0KCSdLGmZpOWSPtHH9HGSvp2n/0bSjCrrMTOz/qssKCSNBi4ATgFmAnMlzayb7SxgTUS8CPgq8IWq6jEzs4HpqnDZRwPLI+JeAElXAXOApTXzzAHm5+FrgfMlKSKitOBNm2Dr1sEv2MxsJHqu68sqg2IqsKLm/krgmEbzRMRWSWuBvYDVtTNJmgfMy/c2H3/8xHuqKXm42fI8GLOm3VV0BrdFL7dFL7dFr/XTB/rIKoNi0ETEAmABgKTFEU/NanNJHSG1xdNuC9wWtdwWvdwWvSQtHuhjqzyY/RAwreb+vnlcn/NI6gImAY9VWJOZmfVTlUGxCDhI0v6SxgKnA91183QD78rDbwJ+1uz4hJmZDa3Kup7yMYdzgOuB0cAlEbFE0rnA4ojoBi4GLpO0HHicFCbNLKiq5mHIbdHLbdHLbdHLbdFrwG0hb8CbmVmJz8w2M7MiB4WZmRV1bFD48h+9WmiLD0taKulOST+VNODvS3e6Zm1RM98bJYWkEfvVyFbaQtJb8ntjiaQrhrrGodLCZ2Q/STdIui1/Tk5tR51Vk3SJpEcl3dVguiT9c26nOyW9vKUFR0TH3UgHv+8BDgDGAncAM+vmeT/wtTx8OvDtdtfdxrb4M2DXPPy+nbkt8nwTgZuAW4BZ7a67je+Lg4DbgOfl+/u0u+42tsUC4H15eCZwf7vrrqgtXgW8HLirwfRTgesAAccCv2lluZ26R/HM5T8iYjPQc/mPWnOAb+Tha4FXS9IQ1jhUmrZFRNwQERvy3VtI56yMRK28LwA+S7pu2NNDWdwQa6UtzgYuiIg1ABHx6BDXOFRaaYsAds/Dk4CHh7C+IRMRN5G+QdrIHOCbkdwC7CHpBc2W26lB0dflP6Y2micitgI9l/8YaVppi1pnkbYYRqKmbZF3padFxA+GsrA2aOV9cTBwsKRfSrpF0slDVt3QaqUt5gNnSFoJLAT+amhK6zj9XZ8Aw+QSHtYaSWcAs4Dj211LO0gaBXwFOLPNpXSKLlL30wmkvcybJB0eEU+0tar2mAtcGhFflnQc6fytwyJie7sLGw46dY/Cl//o1UpbIOlE4G+B2RGxaYhqG2rN2mIicBjwc0n3k/pgu0foAe1W3hcrge6I2BIR9wF/IAXHSNNKW5wFXA0QEb8GxgOTh6S6ztLS+qRepwaFL//Rq2lbSHoZcBEpJEZqPzQ0aYuIWBsRkyNiRkTMIB2vmR0RA74YWgdr5TPyXdLeBJImk7qi7h3KIodIK23xIPBqAEkvJgXFqiGtsjN0A+/M3346FlgbEY80e1BHdj1FdZf/GHZabIsvAROAa/Lx/AcjYnbbiq5Ii22xU2ixLa4HTpK0FNgGfCwiRtxed4tt8RHgXyX9L9KB7TNH4oalpCtJGweT8/GYTwNjACLia6TjM6cCy4ENwLtbWu4IbCszMxtEndr1ZGZmHcJBYWZmRQ4KMzMrclCYmVmRg8LMzIocFGZ1JG2TdLukuyR9T9Ieg7z8MyWdn4fnS/roYC7fbLA5KMx2tDEijoyIw0jn6Hyg3QWZtZODwqzs19RcNE3SxyQtytfy/0zN+HfmcXdIuiyPOy3/Vsptkn4iaUob6jd7zjryzGyzTiBpNOmyDxfn+yeRrpV0NOl6/t2SXkW6xtgngf8REasl7ZkXcTNwbESEpL8EPk46Q9hsWHFQmO1oF0m3k/Yk7gZ+nMeflG+35fsTSMHxUuCaiFgNEBE9vwewL/DtfL3/scB9Q1O+2eBy15PZjjZGxJHAdNKeQ88xCgGfz8cvjoyIF0XExYXl/D/g/Ig4HHgP6UJ0ZsOOg8KsgfyrgX8NfCRfyv564C8kTQCQNFXSPsDPgDdL2iuP7+l6mkTvJZzfhdkw5a4ns4KIuE3SncDciLgsX6L61/kqveuAM/
KVSj8H3ChpG6lr6kzSr6pdI2kNKUz2b8drMHuufPVYMzMrcteTmZkVOSjMzKzIQWFmZkUOCjMzK3JQmJlZkYPCzMyKHBRmZlb0/wEDTF4j+aa3xgAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Plot the precision / recall curve.\n", | |
"\n", | |
"y_conf = np.array([proba[1] for proba in model.predict_proba(X_test)])\n", | |
"\n", | |
"prec, recall, _ = precision_recall_curve(y_test, y_conf)\n", | |
"avg_prec = average_precision_score(y_test, y_conf)\n", | |
"\n", | |
"plt.fill_between(recall, prec, alpha=0.2, color='b')\n", | |
"\n", | |
"plt.xlabel('Recall')\n", | |
"plt.ylabel('Precision')\n", | |
"plt.ylim([0.0, 1.05])\n", | |
"plt.xlim([0.0, 1.0])\n", | |
"plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(avg_prec))\n", | |
"\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAswAAACPCAYAAADugo3xAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XuUVNWZ9/HvQwMCIiJgkItImyEalJtyiRKRoIKGTGNmNBBjBNTgilGZOCHRmChjMoYRM3mNGkejgLxqJPHNUoxEMaO9lKjIRZSrgIrapBVo5NIgze15/6iirabrnK4uTtf191mrV9fZ57br1Kl9ntpn733M3RERERERkeSaZTsDIiIiIiK5TAGziIiIiEgIBcwiIiIiIiEUMIuIiIiIhFDALCIiIiISQgGziIiIiEgIBcwiIiIiIiEUMIuIiIiIhFDALCIiIiISonm2M3C4Tp06ec+ePbOdDREREREpcEuWLNni7sc3tFzOBcw9e/Zk8eLF2c6GiIiIiBQ4M/sgleUabJJhZjPMbJOZrQiYb2b2WzNbb2Zvm9kZCfPGm9m6+N/41LMvIiIiIpIbUmnDPAu4MGT+RUCv+N8k4H4AM+sA3AYMAQYDt5nZcUeSWRERERGRTGuwSYa7v2xmPUMWGQPMdncHXjez9mbWBRgOvODuWwHM7AVigfcfjjTTTeKvN8X+XzStbvrimbD8yeD1+lwCAyc2Xb6i1tD7SUc2j0G676exeQ7bT9C20lknnf1HLROfZzrfq1z4rKMW1bGO+thEfb43Ng9Rnx9R57lY5MJxy0RZnYltpSOXy8kw6ebhhD71Y7AcEkUb5m7ARwnTFfG0oPR6zGwSsdppevToEUGW0vDx8uTpy5+MzTuhT/A6+RQwh72fdGT7GKTzftLJc9B+wraVzjqN3X/UMvV5pvO9yvZnHbUoj3XUxybK8z2dPER5fkSd52KRC8etCcvqffv2UVFRwZ49e4K35SfDl/8NSlrWTT+wF7wlrF6der7SEbT/sDyErRMk6veTTh4ASlo06TFt1aoV3bt3p0WLFmmtnxOd/tz9QeBBgIEDB3qWs1PfCX1g4rP102eOznxeohD0ftKRC8egse8n3Twn209D20pnncZsK2qZ/DzT+V5l87OOWtTHOupjE9X5nm4eojo/os5zsciF49aEZXVFRQXHHHMMPXv2xMySb2dLPETq1Ouw9HXJ06MWtP+wPIStE7ifiN9POnloYu5OVVUVFRUVlJaWprWNKMZh3gicmDDdPZ4WlC4iIiKSNXv27KFjx47BwbIUFDOjY8eO4XcUGhBFwDwXuCI+WsZXgO3uXgk8D4w0s+Pinf1GxtNEREREskrBcnE50s+7wSYZZvYHYh34OplZBbGRL1oAuPv/APOArwPrgd3AxPi8rWb2C2BRfFO3H+oAKCIiIiKSL1IZJePbDcx34AcB82YAM9LLmoiIiIiEmTJlCvPmzePrX/86X+x8DG1at+KKH/y4zjIbNmzgG9/4BitWJH2khqQgJzr9iYiIiEjjPfjgg2zdupWSkpLPO/AVqP3799O8eXZC1yjaMIuIiIhII8yePZu+ffvSr18/vvvd7wKxmuARI0bQt29fzvuXK/iw4h8ATJgwgRtuuIGzzz6bk08+mSfnPgdAWVkZ1dXVnHnmmcyZM4epd/6Wu+57GIAlS5bQr18/+vXrx3333Ve73wMHDjBlyhQGDRpE3759eeCRJwAoLy9n+PDhXHLJJZx66ql85zvfIdaIABYtWsTZZ59Nv379GDx4MDt37qy/nQceqPced+3axejRo+nXrx+nn346c+bMCdzenj17mDhxIn369GHAgAG89NJLAMyaNYuysjJGjBjBeeedB8D06dNr93vbbbeF7isqqmEWERGRovUfz6xk1T921J+x77PY/xZbUktP0LtrO27759MC569cuZJf/vKXvPrqq3Tq1ImtW2NdvK
6//nrGjx/P+PHjmXH3r7jhp7/gqXnnAlBZWcmCBQtYs2YNZaMv4pKyC5k7dy5t27Zl2bJlAExdsqB2HxMnTuTee+9l2LBhTJkypTb94Ycf5thjj2XRokXU1NQwdMhARg4fCjTjzTffZOXKlXTt2pWhQ4fy97//ncGDBzN27FjmzJnDoEGD2LFjB61bt66/naFDGTlyJKXHfP4+n3vuObp27cqzz8aG+du+fTt79+5Nur27774bM2P58uWsWbOGkSNHsnbtWgCWLl3K22+/TYcOHZg/fz7r1q3jjTfewN0pKyvj5ZdfZvPmzfX2FSXVMIuIiIhk0Isvvsill15Kp06dAOjQoQMAr732GpdddhkA3/3WGBYsXFK7zsUXX0yzZs3o3bs3n2yuCt3+tm3b2LZtG8OGDYttK16DDTB//nxmz55N//79GTJkCFWfbmPdex8AMHjwYLp3706zZs3o378/GzZs4J133qFLly4MGjQIgHbt2tG8efP626mqYt26uk1C+vTpwwsvvMBPfvITXnnlFY499tjA7S1YsIDLL78cgFNPPZWTTjqpNmC+4IILao/R/PnzmT9/PgMGDOCMM85gzZo1rFu3Lum+oqQaZhERESlagTXBgQ8HydCDSw5z1FFH1b4+1FQiHe7OPffcw6hRo2IJ8fdTvmJjnX2UlJSwf//+1LdzSEI76i996UssXbqUefPm8bOf/YzzzjuPb37zm43O89FHH11nvzfffDPXXHNNveUO39ett97a6H0FUQ2ziIiISAaNGDGCP/3pT1RVxWqKDzXJOPvss3niiVib4seefIZzvjIwre23b9+e9u3bs2BBrInGY489Vjtv1KhR3H///ezbtw+Ate++z65duwO3dcopp1BZWcmiRbFRgnfu3Mn+/fvrb2ftWnbt2lVn3X/84x+0adOGyy+/nClTprB06dLA7Z1zzjm1+Vy7di0ffvghp5xySr38jBo1ihkzZlBdXQ3Axo0b2bRpU9J9RUk1zCIiIiIZdNppp3HLLbdw7rnnUlJSwoABA5g1axb33HMPEydOZPr06Rzf/mhm/nZa2vuYOXMmV155JWbGyJEja9OvvvpqNmzYwBlnnIG7c3z7tjw1+3eB22nZsiVz5szh+uuv57PPPqN169b87W9/q7+d44/nqaeeqrPu8uXLmTJlCs2aNaNFixbcf//9gdu79tpr+f73v0+fPn1o3rw5s2bNqlPjfcjIkSNZvXo1Z511FgBt27bl0UcfZf369fX2FSUFzCIiIiIZdqhzX6KTTjqJF198MTaR0LRh1qxZdZar/mDZ56/jNa0AU398Q+3rM888k7feeqt2+s477wSgWbNm3HHHHdxxxx119jN8+BkMHz68dvl777239vWgQYN4/fXX672HOts5ZMum2pejRo2q32QjZHszZ86slzZhwgQmTJhQJ23y5MlMnjy5TtoXv/jFpPuKippkiIiIiIiEUMAsIiIiIhJCAbOIiIiISAgFzCIiIiIiIRQwi4iIiIiEUMAsIiIiIhJCAbOIiIhIBm3bto3f/S547OMjNWvWLK677rrQZaZOncpdd93VqO22bdv2SLKV11IKmM3sQjN7x8zWm9lNSeb/xsyWxf/Wmtm2hHkHEubNjTLzIiIiIvkmLGAOexy1ZE+DAbOZlQD3ARcBvYFvm1nvxGXc/Yfu3t/d+wP3AH9OmP3ZoXnuXhZh3kVERETyzk033cS7775L//79mTJlCuXl5ZxzzjmUlZXRu3dvNmzYwOnnjK5d/q677mLq1KkAvPvuu1z4ras487xvcs4557BmzZrQfT3zzDMMGTKEAQMGcP755/PJJ5/Uznvrrbc466Jv0WvwBfz+97+vTZ8+fTqDBg2ib9++3HbbbfW2WVlZybBhw+jfvz+nn346r7zyyhEekdyXypP+BgPr3f09ADN7AhgDrApY/ttA/aMrIiIikmv+ehN8vLx++r7dsf8t2qSWnuiEPnBR8GOtp02bxooVK1
i2LPbEvvLycpYuXcqKFSsoLS1lw4YNgetOmjSJ//nVz+n1xZ4sfHcr11577edPB0ziq1/9Kq+//jpmxkMPPcSdd97Jr3/9awDefvttXv/Lo+zavZsB51/C6NGjWbFiBevWreONN97A3SkrK+Pll19m2LBhtdt8/PHHGTVqFLfccgsHDhxg9+7dwceiQKQSMHcDPkqYrgCGJFvQzE4CSoHET66VmS0G9gPT3P2pZOuKiIiIFKvBgwdTWloaukx1dTWvvvoql14VfwR286OoqakJXaeiooKxY8dSWVnJ3r176+xjzJgxtG7ditatW/G1r32NN954gwULFjB//nwGDBhQu89169bVCZgHDRrElVdeyb59+7j44ovp379/mu86f6QSMDfGOOBJdz+QkHaSu280s5OBF81subu/m7iSmU0CJgH06NEj4iyJiIiIBAiqCd6yLva/U6/U0o/Q0UcfXfu6efPmHDx4sHZ6z549ABw8eJD27duzrHxuynm4/vrrufHGGykrK6O8vLy2aQeAmdVZ1sxwd26++WauueaawG0OGzaMl19+mWeffZYJEyZw4403csUVV6TyNvNWKp3+NgInJkx3j6clMw74Q2KCu2+M/38PKAcGHL6Suz/o7gPdfeDxxx+fQpZERERE8tMxxxzDzp07A+d37tyZTVuqqNr6KTU1NfzlL38BoF27dpSWlvKnp/8KgLvz1ltvhe5r+/btdOvWDYBHHnmkzrynn36aPXtqqNr6KeXl5QwaNIhRo0YxY8YMqqurAdi4cSObNm2qs94HH3xA586d+d73vsfVV1/N0qVLG3cA8lAqNcyLgF5mVkosUB4HXHb4QmZ2KnAc8FpC2nHAbnevMbNOwFDgzigyLiIiIpKPOnbsyNChQzn99NO56KKLGD16dJ35LVq04NZ/v47Boy6h24k9OfXUU2vnPfbYY3z/qvH88je/Y9/BZowbN45+/foF7mvq1KlceumlHHfccYwYMYL333+/dl7fvn352je/y5aqT/n5z39O165d6dq1K6tXr+ass84CYkPJPfroo3zhC1+oXa+8vJzp06fTokUL2rZty+zZs6M6NDmrwYDZ3feb2XXA80AJMMPdV5rZ7cBidz80VNw44Al394TVvww8YGYHidVmT3P3oM6CIiIiIkXh8ccfrzM9fPjwOtM3TLqCGyZdUa/ZRWlpKc/98eHYRECTjAkTJjBhwgQg1k55zJgx9ZapbZqRpInJ5MmTmTx5cr11DtU6jx8/nvHjxyfdd6FKqQ2zu88D5h2Wduth01OTrPcq0OcI8iciIiIiklVRd/oTERGRAvbJzj1sqa7h9gdeqzdvTP9uXDZEnfel8OjR2CIiIpKyLdU17N57oF76qsodPL0saEyA3FO3BakUuiP9vFXDLCIiIo3SpmUJc645q07a2CQ1zrmqVatWVFVV0bFjx3pDq0nhcXeqqqpo1apV2ttQwCwiIiJFpXv37lRUVLB58+bAZQ7s+ISD7mx7b1ed9PYHP6WZGSWb9zdtJqvjQ7kl20/QvLB10tlPOqLeXkRatWpF9+7d015fAbOIiIgUlRYtWjT4VL2Vd3yPmr0HuK/Lf9dJ/1HlT2nTsoQv/3RBU2YRZv4o9n/is6nPC1snnf2kI+rt5QgFzCIiIiJJJGt6svKOkizlRrJJnf5EREREREKohllEpIA9vvDDpCMX3Fq1nU5tj6JzFvIkxSfoPAQNRSf5QTXMIiIF7OllG1lVuaNe+u69B9hSXZOFHEkxCjoP820oOileqmEWESlwvbu0UztMybpk52E+DUUnxU01zCIiIiIiIRQwi4iIiIiEUMAsIiIiIhJCAbOIiIiISAgFzCIiIiIiITRKhoiIiEgB+GTnHrZU13B7wOgjGvM6fSnVMJvZhWb2jpmtN7ObksyfYGabzWxZ/O/qhHnjzWxd/G98lJkXERERkZgt1TXs3nsg6TyNeX1kGqxhNrMS4D7gAqACWGRmc9191WGLzn
H36w5btwNwGzAQcGBJfN1PI8m9iIiIiNRq07Kk3njXoDGvj1QqNcyDgfXu/p677wWeAMakuP1RwAvuvjUeJL8AXJheVkVEREREMi+VNszdgI8SpiuAIUmW+1czGwasBX7o7h8FrNvt8BXNbBIwCaBHD7WtEREpFI8v/DDwNvCtVdvp1PYoOmc4TyIijRVVp79ngD+4e42ZXQM8AoxIdWV3fxB4EGDgwIEeUZ5ERCRDgjobLXx/KwBDSjvUW2f33gNsqa5RwCwiOS+VgHkjcGLCdPd4Wi13r0qYfAi4M2Hd4YetW97YTIqISG4L6mw0pLRDYM/8lXeUZCJrOSus9h00ooFILkklYF4E9DKzUmIB8DjgssQFzKyLu1fGJ8uA1fHXzwN3mNlx8emRwM1HnGsREck5QZ2NJLmnl21kVeUOendpV2/eqsodAAqYRXJEgwGzu+83s+uIBb8lwAx3X2lmtwOL3X0ucIOZlQH7ga3AhPi6W83sF8SCboDb3X1rE7wPERGRvNO7SzuNaCCSB1Jqw+zu84B5h6XdmvD6ZgJqjt19BjDjCPIoIiIiIpI1ejS2iIiIiEgIPRpbRERERCIR9njufO7IqhpmEREREYlE0Ig5+f5obtUwi4iI5AkNRSf5INmIOfnekVU1zCIiInni0FB0yeR7DZ5ILlMNs4iISB7RUHSSrlWVO+qdJ3pEfWoUMEtBCrttqcJBRESKzZj+3ZKm6xH1qVHALAUp7AlaKhxERKTYXDakhx5RfwQUMEvBCrptqcJBRPJBstvnQRUB6Qq6G6c7cSJ1KWAWERFJkAsjUQTdPu/dpV3gvHQE3Y3TnTiRuhQwi4iIJAhr0nVohIqmDpiDbp83hWR343QnTqQuBcwieS6sNkxjsmaGPoPCEzYSRbKmEqDPOl1Bx1PNQiSXaBxmkTwXNC6rxmTNHH0GxWNM/26BNc/6rBsv6HjC581CRHKBaphFCkCy2rCwmjBQbVjUgj4DKSxBTSX0WacnrOmJmoVILlHALFKgwjoGZaodZpQ+2bmHLdU13K5b4SIikmEKmEUKVFjNTT7Whm2prmH33gP10vMx+JfiEtTGPeoh4kSk6aQUMJvZhcDdQAnwkLtPO2z+jcDVwH5gM3Clu38Qn3cAWB5f9EN3L4so7yKSp8Jqi8M6+rRpWaJmD5J3gkbdiHqIOJFcF9ZMsHfXdtz2z6dlOEepazBgNrMS4D7gAqACWGRmc919VcJibwID3X23mX0fuBMYG5/3mbv3jzjfIpLHgmqLIb3xX9XLXnJd0KgbIsUi338cplLDPBhY7+7vAZjZE8AYoDZgdveXEpZ/Hbg8ykyKSOFJVlsMje/oE1YIpxN8R91WWsG8ZJKGvJNclcmxxZtCKgFzN+CjhOkKYEjI8lcBf02YbmVmi4k115jm7k8dvoKZTQImAfTokb8HUySfJLuw5mMQF3Uv+6Da74Xvb2Xh+1sb1RY16mBeJEzQ+aZ2/iJHLtJOf2Z2OTAQODch+SR332hmJwMvmtlyd383cT13fxB4EGDgwIEeZZ5EpL6gC6uCuJhktd9hDycJaouqIbOKS1DtbqY692nIO5Gmk0rAvBE4MWG6ezytDjM7H7gFONfda0cad/eN8f/vmVk5MAB49/D1RdIRdPtcvc/DBV1YFcQFy/fbiYUuqCzI1F2TsLsJ6twnkv9SCZgXAb3MrJRYoDwOuCxxATMbADwAXOjumxLSjwN2u3uNmXUChhLrECgSiaDb51FfoLJ9MRYJE9buulh+PAaVBZm6a6IfVCKFrcGA2d33m9l1wPPEhpWb4e4rzex2YLG7zwWmA22BP5kZfD583JeBB8zsILHHcE87bHQNiQu73RtGHTmCO49FKdsXY5EwYaOOFFPtZrKyQHdNRCQKKbVhdvd5wLzD0m5NeH1+wHqvAn2OJIPFImiczjDqyJFZuhhLLsvED0cRkWKlJ/3lkMaO05nLHTkaqjFXzbiIiIjkCwXMcRuqdgHQM7
vZKBhhNeaqGRdJT9AP0R/tPUCbltHd7VCbfRGRuhQwx+3auz/bWSg4QTXmuVwzHiUFHZkL8IpF0A/RNi1L6NT2qMj2k+02++k+Ol1EpKkoYC4yYReiMGpC0XhRBh35OgpCpgK8YpL0h+jMYyPfTzbb7Ef96HTJX7v3HmjUkwtVUSFNRQFzkQm7EAUJe8KZCqFwUQUd+TwKQjoBXrKLpGqli0tUj06H4KBL5Vf0onx4S6e2R7GluqZeelizvmzfHZHCpYC5CDW2N31YBz4VQplTLKMgBF0kVSst6Qg6n6C4yq9MPIUw6oe3dD6mFZ2PacWciXXLvYaa9TWmoiLs+pbtH+lqmpRbFDAXqCjbjurxvpmTC21+g2rjIDNNc4Iukk3R7EAKX+D5RKz8Sna+53Izp3Rk6imE+fjwlrAO6lH+SE9n5Cg1TcotCpgzLOhLE3UBrbaj+Snbn1tYbVxY0xy1cZd0ZbP5TdD5nuvNnBorHwPZTAoc0jXCH+npjhwVZdMkOTIKmI9A2O2SoAAi6EuTbgHd0G22THQOykeN7UiSSdn83MJq48J+7IGGCZTGy3bzm7DzXSQdDVWKFfPIUflOAfMRCLpd0lAA0dgHlATJ1G22QpNORxIJrqVSYS/pUvMbKTRRV4qFUefozFLAfISS3S4Z+8BrGelgodts6QnrSBL0uUVdCAXVcBda20kRkWITVaVYmGzfnSlGCpibgGp+GxYUmGazSUTY5xJlIRTWTljnh0jDVLOWGUE/7HWssy9Td2fCOivmQhPGTFLA3ICwdspBhYZqfsMFBYTZbhIR+rmFFEKNvaio3WT25fJQUrku20GUatYyI+yHfS4c63TuBhbDiCjpCvpeL3x/KwBDSjvUSc/29TobFDA3IGxYl1woNPJRWFvYXKx5DpPrFxVJLlNDSTUkbAi/ZLJ9cc+F813tnjMj9Id9lo91OncDox4RpZCa1YV9r4eUdkh6/S3GvisKmBOE3eZr6iFn8llUv9qDCq2w4cyyXRuYyxcVCZeJoaTChF2kgkTdZCfSuyM63yVD0rkbGOWdvaZoVpessihT1zfd9UxNSgGzmV0I3A2UAA+5+7TD5h8FzAbOBKqAse6+IT7vZuAq4ABwg7s/H1nuI6TbfOmJ8ld7UCEYdvtcn480JFdvw2b7IpULtcUi+Sjq727QtVLfw9zSYMBsZiXAfcAFQAWwyMzmuvuqhMWuAj51938ys3HAfwFjzaw3MA44DegK/M3MvuTuyds4ZJFu86UnExf9dNsWS+ZkYlSYdBTLgynSodpikdwQeI3L8e9hpkaVyhWp1DAPBta7+3sAZvYEMAZIDJjHAFPjr58E7jUzi6c/4e41wPtmtj6+veJr/CJSoHJ5VJhs1+KKiBSiTI0qlUtSCZi7AR8lTFcAQ4KWcff9ZrYd6BhPf/2wdXO3Wufj5TBzdP20E/o0bp1c1tD7SXeb2ToG6b6fxuY5bD9B20pnnXT2H7VG5u0y4LKWIQusou7P60P7aOz3Khc+66hF9d2J+thEfb43Ng9Rnx9R57lY5MJxy0RZnYltpSOHy8nQct8+hGMK73uVE53+zGwSMAmgR48sjYTQ55Lk6Sf0CZ4XlJ7Lwt5POrJ9DNJ5P+nkOWg/YdtKZ53Gbitqmfo80/leZfuzjlqU+4j62ER5vqeThyjPj6jzXCxy4bhlqqzOxLbSkcvlZJgC/V6Zu4cvYHYWMNXdR8WnbwZw918lLPN8fJnXzKw58DFwPHBT4rKJywXtb+DAgb548eIjelMiIiIiIg0xsyXuPrCh5ZqlsK1FQC8zKzWzlsQ68c09bJm5wPj460uAFz0Wic8FxpnZUWZWCvQC3kj1TYiIiIiIZFuDTTLibZKvA54nNqzcDHdfaWa3A4vdfS7wMPB/4536thILqokv90diLRj3Az/IxREyRERERESCNNgkI9PUJENEREREMiHVJhk5FzCb2WbggyztvhOwJUv7ltygc0BA54HoHJAYnQ
eF7yR3P76hhXIuYM4mM1ucyq8MKVw6BwR0HojOAYnReSCHpNLpT0RERESkaClgFhEREREJoYC5rgeznQHJOp0DAjoPROeAxOg8EEBtmEVEREREQqmGWUREREQkhAJmwMwuNLN3zGy9md2U7fxIZpjZiWb2kpmtMrOVZjY5nt7BzF4ws3Xx/8dlO6/StMysxMzeNLO/xKdLzWxhvEyYE3/KqRQwM2tvZk+a2RozW21mZ6ksKC5m9sP4tWCFmf3BzFqpLJBDij5gNrMS4D7gIqA38G0z653dXEmG7Af+3d17A18BfhD/7G8C/tfdewH/G5+WwjYZWJ0w/V/Ab9z9n4BPgauykivJpLuB59z9VKAfsfNBZUGRMLNuwA3AQHc/ndiTjcehskDiij5gBgYD6939PXffCzwBjMlyniQD3L3S3ZfGX+8kdoHsRuzzfyS+2CPAxdnJoWSCmXUHRgMPxacNGAE8GV9E50CBM7NjgWHAwwDuvtfdt6GyoNg0B1qbWXOgDVCJygKJU8AcC5A+SpiuiKdJETGznsAAYCHQ2d0r47M+BjpnKVuSGf8H+DFwMD7dEdjm7vvj0yoTCl8psBmYGW+a85CZHY3KgqLh7huBu4APiQXK24ElqCyQOAXMUvTMrC3w/4B/c/cdifM8NoyMhpIpUGb2DWCTuy/Jdl4kq5oDZwD3u/sAYBeHNb9QWVDY4u3TxxD78dQVOBq4MKuZkpyigBk2AicmTHePp0kRMLMWxILlx9z9z/HkT8ysS3x+F2BTtvInTW4oUGZmG4g1xxpBrC1r+/htWVCZUAwqgAp3XxiffpJYAK2yoHicD7zv7pvdfR/wZ2Llg8oCARQwAywCesV7wrYk1sh/bpbzJBkQb6v6MLDa3f87YdZcYHz89Xjg6UznTTLD3W929+7u3pPYd/9Fd/8O8BJwSXwxnQMFzt0/Bj4ys1PiSecBq1BZUEw+BL5iZm3i14ZD54DKAgH04BIAzOzrxNoxlgAz3P0/s5wlyQAz+yrwCrCcz9uv/pRYO+Y/Aj2AD4BvufvWrGRSMsbMhgM/cvdvmNnJxGqcOwBvApe7e0028ydNy8z6E+v42RJ4D5hIrFJJZUGRMLP/AMYSG0HpTeBqYm2WVRYpmKccAAAAUUlEQVSIAmYRERERkTBqkiEiIiIiEkIBs4iIiIhICAXMIiIiIiIhFDCLiIiIiIRQwCwiIiIiEkIBs4iIiIhICAXMIiIiIiIhFDCLiIiIiIT4/yztKXF7+AkqAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 864x144 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# This visualization is less useful for understanding the performance\n", | |
"# quality, but can help show how well the classifier tends to\n", | |
"# differentiate across classes.\n", | |
"\n", | |
"n = len(y_test)\n", | |
"\n", | |
"plt.figure(figsize=(12, 2))\n", | |
"plt.step(np.arange(n), y_conf, label='confidence scores')\n", | |
"plt.step(np.arange(n), y_test, label='true labels')\n", | |
"plt.legend()\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Choose a Cutoff Value\n", | |
"\n", | |
"This is the example described at the top of the notebook." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Here is the output of precision_recall_curve():\n", | |
"#\n", | |
"# * prec - An array (ndarray) of precision values.\n", | |
"# * recall - An array of recall values.\n", | |
"# * cutoffs - An array of cutoff values.\n", | |
"#\n", | |
"# They are related in that, at any given index i < len(cutoffs),\n", | |
"# using the cutoff value cutoffs[i] will result in the precision\n", | |
"# value prec[i] and the recall value recall[i]. (prec and recall\n", | |
"# each have one extra final entry, precision 1 and recall 0, that\n", | |
"# has no corresponding cutoff.)\n", | |
"#\n", | |
"# The cutoff value is interpreted like this:\n", | |
"#\n", | |
"# has_label = (confidence_score >= cutoff),\n", | |
"#\n", | |
"# where confidence_score is the output value of the model.\n", | |
"# (In the above code, this value comes from calling predict_proba().)\n", | |
"\n", | |
"prec, recall, cutoffs = precision_recall_curve(y_test, y_conf)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([0.05897802, 0.06008123, 0.06088365, 0.06103349, 0.06228476]),\n", | |
" array([0.53897053, 0.61083472, 0.61240123, 0.65266008, 0.68443149]))" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# cutoffs increase; so we typically start with low precision and move up\n", | |
"# to higher precision (but not always).\n", | |
"cutoffs[:5], cutoffs[-5:]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([False, False, False, False, False, False, False, False, False,\n", | |
" False, False, False, False, False, False, False, False, False,\n", | |
" False, False, False, False, False, False, False, False, False,\n", | |
" False, False, False, False, False, False, False, False, False,\n", | |
" False, False, False, False, False, False, False, False, True,\n", | |
" True, True, True, True, True, True, True, True, True,\n", | |
" True, True, True, True, True, True, True, True, True,\n", | |
" True, True, True, True, True, True, True, True, True,\n", | |
" True, True, True, True, True, True, True, True, True])" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Generate a boolean mask where the precision is good enough for us.\n", | |
"prec >= 0.95" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.19990406471510586" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Arrive at exactly the cutoff value we want.\n", | |
"# This works because flatnonzero() returns all the indexes\n", | |
"# where (prec >= 0.95), in order. The first value returned\n", | |
"# by flatnonzero() will correspond with the highest recall\n", | |
"# because it is the most-inclusive cutoff (the lowest cutoff)\n", | |
"# within the prec >= 0.95 constraint.\n", | |
"idx = np.flatnonzero(prec >= 0.95)[0]\n", | |
"cutoff = cutoffs[idx]\n", | |
"cutoff" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.9722222222222222, 0.7142857142857143)" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Sanity check. This should match our intuition with the\n", | |
"# precision-recall plot above.\n", | |
"prec[idx], recall[idx]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.9722222222222222" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Double sanity check.\n", | |
"# Let's evaluate the precision by hand, under a new name so that\n", | |
"# the prec array from precision_recall_curve() is not overwritten.\n", | |
"\n", | |
"y_predict = np.array([int(proba[1] >= cutoff) for proba in model.predict_proba(X_test)])\n", | |
"prec_check = sum(y_test[y_predict == 1]) / sum(y_predict)\n", | |
"prec_check" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "generic", | |
"language": "python", | |
"name": "generic" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
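The cutoff selection in the notebook can also be sketched from scratch, which may help show why the lowest qualifying cutoff gives the highest recall. This is an illustrative reimplementation under stated assumptions, not the notebook's code: the function name `choose_cutoff`, the `min_precision` parameter, and the toy data are hypothetical, and it scans candidate cutoffs directly instead of calling `precision_recall_curve()`. It also returns `None` when no cutoff reaches the target, a case the notebook does not handle.

```python
def choose_cutoff(y_true, scores, min_precision=0.95):
    """Return (cutoff, precision, recall) for the cutoff with the highest
    recall among those whose precision is at least min_precision, or None
    if no candidate cutoff reaches that precision."""
    n_actual_pos = sum(y_true)
    # Candidate cutoffs are the distinct scores, lowest (most inclusive)
    # first. Recall can only stay flat or drop as the cutoff rises, so
    # the first cutoff that meets the precision target maximizes recall.
    for cutoff in sorted(set(scores)):
        predicted = [s >= cutoff for s in scores]
        n_pred = sum(predicted)  # never zero: the cutoff is itself a score
        true_pos = sum(1 for p, t in zip(predicted, y_true) if p and t == 1)
        precision = true_pos / n_pred
        if precision >= min_precision:
            recall = true_pos / n_actual_pos if n_actual_pos else 0.0
            return cutoff, precision, recall
    return None


# Toy data, invented for illustration only.
y_true = [0, 0, 1, 0, 1, 1]
scores = [0.1, 0.3, 0.4, 0.6, 0.7, 0.9]

strict = choose_cutoff(y_true, scores, min_precision=0.95)
relaxed = choose_cutoff(y_true, scores, min_precision=0.75)
print("95% precision target:", strict)
print("75% precision target:", relaxed)
```

Relaxing the precision target lets the function pick a lower cutoff, which recovers more of the positive examples.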