@tylerneylon
Created May 30, 2019 20:33
Example of visualizing and extracting classifier cutoff values to meet certain precision-recall objectives.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example: Finding a Cutoff for 95% Precision\n",
"\n",
"The purpose of this notebook is to demonstrate programmatically\n",
"choosing a cutoff value for a binary classifier such that:\n",
"(a) we achieve 95% or higher precision if possible, and (b)\n",
"within that constraint, we maximize the recall."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import average_precision_score\n",
"from sklearn.metrics import confusion_matrix\n",
"from sklearn.metrics import precision_recall_curve\n",
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set the Random Seed\n",
"\n",
"This is so that anyone can re-run this notebook\n",
"and arrive at the same results."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load in Sample Data\n",
"\n",
"We'll work with the classic iris data set.\n",
"The problem here is to use some measurements of\n",
"different parts of a flower to identify which\n",
"species of iris it is.\n",
"\n",
"I'll filter the data down to focus on only two\n",
"classes, so that we can look at this in the context\n",
"of binary classifiers."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((150, 4), (150,), array([0, 1, 2]))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris_data = load_iris()\n",
"X = iris_data['data']\n",
"y = iris_data['target']\n",
"\n",
"# See some basic info about the raw dataset.\n",
"X.shape, y.shape, np.unique(y)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# We have 150 data points with 4 features and 3 classes (0, 1, and 2)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((100, 4), (100,))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Filter down to two classes.\n",
"X = X[y > 0]\n",
"y = y[y > 0] - 1 # Map [1, 2] to [0, 1].\n",
"X.shape, y.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 1])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check that we have the 2 standard class labels 0 and 1.\n",
"np.unique(y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Split into Train / Test Data"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.95)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a Basic Classifier\n",
"\n",
"I'll use logistic regression to keep this similar to our own case."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"model = LogisticRegression(solver='lbfgs').fit(X_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check the Accuracy, Precision, and Recall\n",
"\n",
"I'm not yet choosing a cutoff, but showing how well the classifier is doing."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"y_predict = np.array([int(proba[1] > 0.5) for proba in model.predict_proba(X_test)])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[46, 0],\n",
" [44, 5]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Each row here is a true class, and each column is a predicted class.\n",
"# The best possible outcome is a matrix with zeros everywhere except\n",
"# along the diagonal, which corresponds to perfect classification.\n",
"confusion_matrix(y_test, y_predict)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5368421052631579"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Check the accuracy.\n",
"sum(y_predict == y_test) / len(y_test)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot the precision / recall curve.\n",
"\n",
"y_conf = np.array([proba[1] for proba in model.predict_proba(X_test)])\n",
"\n",
"prec, recall, _ = precision_recall_curve(y_test, y_conf)\n",
"avg_prec = average_precision_score(y_test, y_conf)\n",
"\n",
"plt.fill_between(recall, prec, alpha=0.2, color='b')\n",
"\n",
"plt.xlabel('Recall')\n",
"plt.ylabel('Precision')\n",
"plt.ylim([0.0, 1.05])\n",
"plt.xlim([0.0, 1.0])\n",
"plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(avg_prec))\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# This visualization is less useful for understanding the performance\n",
"# quality, but can help show how well the classifier tends to\n",
"# differentiate across classes.\n",
"\n",
"n = len(y_test)\n",
"\n",
"plt.figure(figsize=(12, 2))\n",
"plt.step(np.arange(n), y_conf, label='confidence scores')\n",
"plt.step(np.arange(n), y_test, label='true labels')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choose a Cutoff Value\n",
"\n",
"This is the example described at the top of the notebook."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Here is the output of precision_recall_curve():\n",
"#\n",
"# * prec - An array (ndarray) of precision values.\n",
"# * recall - An array of recall values.\n",
"# * cutoffs - An array of cutoff values.\n",
"#\n",
"# They are related in that, at any given index i, using the\n",
"# cutoff value cutoffs[i] will result in the precision value\n",
"# prec[i] and the recall value recall[i].\n",
"#\n",
"# The cutoff value is interepreted like this:\n",
"#\n",
"# has_label = (confidence_score >= cutoff),\n",
"#\n",
"# where confidence_score is the output value of the model.\n",
"# (In the above code, this value comes from calling predict_proba().)\n",
"\n",
"prec, recall, cutoffs = precision_recall_curve(y_test, y_conf)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.05897802, 0.06008123, 0.06088365, 0.06103349, 0.06228476]),\n",
" array([0.53897053, 0.61083472, 0.61240123, 0.65266008, 0.68443149]))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# cutoffs increase; so we typically start with low precision and move up\n",
"# to higher precision (but not always).\n",
"cutoffs[:5], cutoffs[-5:]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, True,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Generate a boolean mask where the precision is good enough for us.\n",
"prec > 0.95"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.19990406471510586"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Arrive at exactly the cutoff value we want.\n",
"# This works because flatnonzero() returns all the indexes\n",
"# where (prec > 0.95), in order. The first value returned\n",
"# by flatnonzero() will correspond with the highest recall\n",
"# because it is the most-inclusive cutoff (the lowest cutoff)\n",
"# within the prec > 0.95 constraint.\n",
"idx = np.flatnonzero(prec > 0.95)[0]\n",
"cutoff = cutoffs[idx]\n",
"cutoff"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.9722222222222222, 0.7142857142857143)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check. This should match our intuition with the\n",
"# precision-recall plot above.\n",
"prec[idx], recall[idx]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9722222222222222"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Double sanity check.\n",
"# Let's hand evaluate the precision.\n",
"\n",
"y_predict = np.array([int(proba[1] >= cutoff) for proba in model.predict_proba(X_test)])\n",
"prec = sum(y_test[y_predict == 1]) / sum(y_predict)\n",
"prec"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "generic",
"language": "python",
"name": "generic"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
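
The cutoff selection in the notebook can also be written out by hand, which may help show what the objective means. The following is a minimal sketch, not the notebook's method: it uses plain Python instead of `precision_recall_curve()`, and the function name `choose_cutoff` and the toy labels and scores are hypothetical. It treats "95% or higher" as `>=`, and it assumes the labels are 0 or 1 with at least one positive example.

```python
def choose_cutoff(y_true, y_score, min_precision=0.95):
    """Hypothetical helper: pick the cutoff with the highest recall among
    cutoffs whose precision is at least min_precision.

    A score is labeled positive when score >= cutoff.
    Returns (cutoff, precision, recall), or None if no cutoff qualifies.
    Assumes y_true holds 0/1 labels with at least one positive.
    """
    total_pos = sum(y_true)
    best = None
    # Every distinct score is a candidate cutoff, tried from lowest to highest.
    for cutoff in sorted(set(y_score)):
        predicted = [s >= cutoff for s in y_score]
        n_predicted = sum(predicted)  # At least 1, since cutoff is one of the scores.
        true_pos = sum(1 for p, t in zip(predicted, y_true) if p and t == 1)
        precision = true_pos / n_predicted
        recall = true_pos / total_pos
        # Keep the qualifying cutoff with strictly higher recall; on ties
        # the lowest (most inclusive) cutoff seen first is kept.
        if precision >= min_precision and (best is None or recall > best[2]):
            best = (cutoff, precision, recall)
    return best


# Toy data (made up for illustration): higher scores are mostly positives.
toy_labels = [0, 0, 1, 0, 1, 1, 1, 1]
toy_scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

result = choose_cutoff(toy_labels, toy_scores, min_precision=0.95)
print(result)
```

Returning `None` when no cutoff reaches the target makes the "if possible" part of the objective explicit, so the caller can decide what to do in that case.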