Skip to content

Instantly share code, notes, and snippets.

@xiaohk
Last active November 6, 2018 23:41
Show Gist options
  • Save xiaohk/698ba7c174768a519d147aaea67dc1a0 to your computer and use it in GitHub Desktop.
Save xiaohk/698ba7c174768a519d147aaea67dc1a0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Precision-recall curve and Average Precision"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"scikit-learn version: 0.20.0\n",
"matplotlib version: 2.1.2\n"
]
}
],
"source": [
"import numpy as np\n",
"import sklearn\n",
"import matplotlib\n",
"from sklearn import metrics\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.utils.fixes import signature\n",
"\n",
"print('scikit-learn version: {}'.format(sklearn.__version__))\n",
"print('matplotlib version: {}'.format(matplotlib.__version__))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20 20 20\n"
]
}
],
"source": [
"model_1_score = [0.4325, 0.3498, 0.2368, 0.2601, 0.1698, 0.211 , 0.1913, 0.8441,\n",
" 0.098 , 0.1682, 0.1844, 0.3937, 0.1746, 0.295 , 0.3164, 0.1856,\n",
" 0.1353, 0.2147, 0.1898, 0.2257]\n",
"model_2_score = [0.502, 0.502, 0.502, 0.502, 0.502, 0.502, 0.502, 0.502,\n",
" 0.502, 0.502, 0.502, 0.502, 0.502, 0.502, 0.502, 0.502,\n",
" 0.502, 0.502, 0.502, 0.502]\n",
"y_true = [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"\n",
"print(len(model_1_score), len(model_2_score), len(y_true))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model_1 AP = 0.30556\n",
"Model_2 AP = 0.10000\n"
]
}
],
"source": [
"model_1_ap = metrics.average_precision_score(y_true, model_1_score)\n",
"model_2_ap = metrics.average_precision_score(y_true, model_2_score)\n",
"print(\"Model_1 AP = {:.5f}\".format(model_1_ap))\n",
"print(\"Model_2 AP = {:.5f}\".format(model_2_ap))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1. 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.\n",
" 0. ]\n",
"[0.11111111 0.05882353 0.0625 0.06666667 0.07142857 0.07692308\n",
" 0.08333333 0.09090909 0.1 0.11111111 0.125 0.14285714\n",
" 0.16666667 0.2 0.25 0.33333333 0.5 0.\n",
" 1. ]\n"
]
}
],
"source": [
"precision_1, recall_1, thresholds_1 = metrics.precision_recall_curve(y_true, model_1_score)\n",
"print(recall_1)\n",
"print(precision_1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1. 0.]\n",
"[0.1 1. ]\n"
]
}
],
"source": [
"precision_2, recall_2, thresholds_2 = metrics.precision_recall_curve(y_true, model_2_score)\n",
"print(recall_2)\n",
"print(precision_2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use `line plot` to plot PR-curves"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl0nOV96PHvb/ZNo12yLFmWbXmRN2wjMGYxi21CCIUkkIADYW24bW/anpum96bn3tOkuT3p7e1pTpte0hYChEBYkkBTJ4EQbOzYGAw2GIx3y6tGi7VY6yxan/vHOxaysa2xpdHMSL/POTrWO/POzPPK0vt7n+f5vb9HjDEopZRSALZUN0AppVT60KCglFJqiAYFpZRSQzQoKKWUGqJBQSml1BANCkoppYZoUFBKKTVEg4JSSqkhGhSUUkoNcaS6AReroKDAVFRUpLoZSimVUd5///0WY0zhSPtlXFCoqKhgx44dqW6GUkplFBE5nsh+OnyklFJqiAYFpZRSQzQoKKWUGpJxcwpKqcmlr6+PUChELBZLdVMygsfjoaysDKfTeUmv16CglEproVCIrKwsKioqEJFUNyetGWNobW0lFAoxY8aMS3qPpA0fichTItIkIrvP87yIyA9EpEZEdonIsmS1RSmVuWKxGPn5+RoQEiAi5Ofnj6pXlcw5hR8Dt1zg+c8Cs+NfjwL/msS2KKUymAaExI32Z5W0oGCM2QycusAudwA/MZZtQI6IlCSrPS3NHWzZ20hf/0CyPkIppTJeKrOPSoHaYduh+GOfIiKPisgOEdnR3Nx8SR9WcyrG1h4HT+xvZ19bD7o2tVJKfVpGpKQaYx43xlQbY6oLC0e8S/ucrvIZ7g034LHb+M9jXTxf00FTtH+MW6qUmohEhPvuu29ou7+/n8LCQm677baLep+KigpaWlpGtc/DDz9MUVERCxcuvKjPTlQqg0IdMG3Ydln8saSZNtDDg7Oz+Mw0Py3RAZ7e387varuJ9g8m82OVUhnO7/eze/duotEoAG+88Qalpecc2Ei6Bx98kN/+9rdJe/9UpqSuA74uIi8Cy4EOY0xDsj/UJsLSAi/zcty81Rjhg+YYe9t6WFniY0mBB5tOaCmVvmpOQHdkbN8z4IPK8hF3u/XWW/nNb37DXXfdxQsvvMDatWvZsmULAKdOneLhhx/myJEj+Hw+Hn/8cRYvXkxraytr166lrq6OFStWnDFs/dxzz/GDH/yA3t5eli9fzg9/+EPsdvuI7Vi5ciXHjh275MMdSTJTUl8A3gHmikhIRB4RkT8SkT+K7/IqcASoAZ4A/iRZbTkXr8PGmrIAD83Locjr4HehME/vb+dEV994NkMplSHuueceXnzxRWKxGLt27WL58uVDz337299m6dKl7Nq1i+9973vcf//9APzN3/wN1157LXv27OELX/gCJ06cAGDfvn289NJLbN26lQ8//BC73c5Pf/rTlBzX2ZLWUzDGrB3heQP812R9fqKKvA7WVgY50NHLm6Ewz9d0MC/HxY2lfrJdI0dtpdQ4SuCKPlkWL17MsWPHeOGFF7j11lvPeO6tt97i5ZdfBuCmm26itbWVzs5ONm/ezCuvvALA5z73OXJzcwHYsGED77//PldccQUA0WiUoqKicTya89M7mrEmkebluJkVdPHuySjbTkao6ejlqmIfy4u9OG06pKSUgttvv51vfvObbNq0idbW1kt+H2MMDzzwAH/3d383hq0bGxmRfTRenDbh2hIfX5ufS2W2i7caIzyxr40D7ZrCqpSyMn++/e1vs2jRojMev+6664aGfzZt2kRBQQHBYJCVK1fy/PPPA/Daa6/R1tYGwKpVq/jFL35BU1MTYM1JHD+e0HIHSadB4RyyXXY+PyPI2sogbpvwH0e7eLGmk2ZNYVVqUisrK+PP/uzPPvX4d77zHd5//30WL17Mt771LZ555hnAmmvYvH
kzCxYs4JVXXqG83Br+mj9/Pn/7t3/LzTffzOLFi1mzZg0NDYnl2axdu5YVK1Zw4MABysrKePLJJ8fuAAHJtCvg6upqc0krr9U2wpEQXLsUEpjhP23QGHa2xNjSEKFnwLCs0MN1U3x4HBpPlRoP+/bto6qqKtXNyCjn+pmJyPvGmOqRXqtzCiOwiXB5oZf5uW42N3ySwnp9iZ/F+W5NYVVKTSgaFBLkddj4zLQAS/I9rK/r5re13exsibKmLEBZ4NLqliul1Lm0trayatWqTz2+YcMG8vPzk/rZGhQuUrHPwVcqs9nf3subdWGeO9TB/Fw3N071kaUprEqpMZCfn8+HH36Yks/WoHAJRISqXCuFdVtThHdPRjnU0cOKYh9XFnlxaAqrUipDaVAYBZddWFniZ3GehzfrwmxuiLCrNcaqMj+VQZfWgFdKZRxNoRkDOW47X5wZ5J5ZQRw24eUjXfzscCctMU1hVUplFg0KY6gi6OKheTmsLvVTH+nnqX3tbAh1ExvQKqxKqcygQWGM2UWoLvLyX6pyWZzvYXtzjMf3tvFRa0zvilYqQ6XLegq1tbXceOONzJ8/nwULFvDP//zPF/X5idCgkCQ+p41bygM8ODeHXLed105088zBDurCWoVVqUyTLuspOBwO/vEf/5G9e/eybds2HnvsMfbu3Tu2nzGm76Y+ZYrPwX2zs9nb1sPG+gjPHuxgYZ6bG6b6CTg1Jit1MdaHujk5xuVmir0OVpcFRtwvHdZTKCkpoaTEWso+KyuLqqoq6urqmD9//ih+AmfSs9I4EBEW5Hl4tCqXFcVe9rX18PjeNradjNA/qENKSmWCdFtP4dixY+zcufOMdowF7SmMI5dduH6qn8X5HjbUhdlUH+Gj1hirSgNUZrtS3Tyl0l4iV/TJkk7rKXR3d3PnnXfyT//0TwSDwbE4vCEaFFIg123nrplBjnT2sj4U5hdHOpkVdLKqNECeR++KVipdpcN6Cn19fdx5553ce++9fPGLX7zkNpyPDh+l0Mygi0fm5XBTqZ9Qdz8/2t/GxrowPZrCqlRaSvV6CsYYHnnkEaqqqvjGN74xloc2RHsKKWa3CVcWeVmQ6+b39WHebYqy+1SMG6b6WZjn1ruilUojF1pP4eGHH2bx4sX4fL4z1lNYu3YtCxYs4Oqrrz7negqDg4M4nU4ee+wxpk+ffsHP37p1K88++yyLFi1iyZIlAHzve9/71HDWaOh6CmmmIdzHG6Ew9ZF+pvocrCnzU+LXKqxq8tL1FC7eaNZT0OGjNFPid/LVOdl8rjxAR+8Azxzs4DfHuwj36ZCSUir5dPgoDYkIi/I9zMlx8XZjlO3NUQ6293L1FC/VhV7sWoVVqQlN11NQ5+S227ix1M9l+R421HWzsT7CR609rC7zMzOoKaxq8jDGTKr5tdGspzDaKQEdPsoAeR47X5qVzV0zgxgMPzvcyS8Od9LWM5DqpimVdB6Ph9bWVq0dlgBjDK2trXg8nkt+D+0pZJDKbBcVWbnsaI7ydmOUH+1r44oiL1cX+3DZJ89VlJpcysrKCIVCNDc3p7opGcHj8VBWVnbJr9egkGEcNuGqYh8L8zxsqg+z7WSU3ad6uHGqj/m5msKqJh6n08mMGTNS3YxJQ4ePMlTAaeO26Vl8dU42AaeNXx3v5rlDHTRGdGEfpdSl06CQ4Ur9Th6Yk81nywO09Qzw4wPtvHaii4imsCqlLoEOH00AIsJl+R7m5rjY2hDh/eYY+9t7uXaKj2WFHuw6pKSUSpD2FCYQj93GqrIAD1flMNXnYENdmKf3t3OsszfVTVNKZYikBgURuUVEDohIjYh86xzPl4vIRhHZKSK7RGTsCnhMYgUeB1+eFeTOmVn0DxpePNzJK0c6adcUVqXUCJIWFETEDjwGfBaYD6wVkbOXB/pfwM+MMUuBe4AfJqs9k42IMDvbzR9W5XJ9iY+jXb08sa+NzfVhegc031spdW7JnFO4EqgxxhwBEJEXgT
uA4QuKGuD0ChHZQH0S2zMpOWzCiik+FuS52VQf4e2TUT4+1cONpX6qclyawqqUOkMyh49Kgdph26H4Y8N9B7hPRELAq8CfJrE9k1rQZef2iizunZ2NzyGsO9bFTw91cFJTWJVSw6R6onkt8GNjTBlwK/CsiHyqTSLyqIjsEJEdelfj6EwLOHlgbg63TAvQGk9hfb22m0i/prAqpZIbFOqAacO2y+KPDfcI8DMAY8w7gAcoOPuNjDGPG2OqjTHVhYWFSWru5GETYUmBh/9SlcuyQg8ftsR4fG8b7zdHGdT6MkpNaskMCtuB2SIyQ0RcWBPJ687a5wSwCkBEqrCCgnYFxonHYWNNWYCH5+VQ7HXwRshKYT3epSmsSk1WSQsKxph+4OvA68A+rCyjPSLyXRG5Pb7bXwBfE5GPgBeAB42WQhx3hV4H91QG+cKMLHoHDS/UdPIfRzvp6NUUVqUmm6Te0WyMeRVrAnn4Y3897Pu9wDXJbINKjIgwN8fNzKCL95qivNMY4XBHL8uLvVxV7MOpC/soNSlomQt1BqdNuGaKj4V5bjbVhdnaGOXj1h5uKvUzV1NYlZrwUp19pNJUtsvOHTOCfKUyG7dd+OWxLl6o6aQpqimsSk1kGhTUBZVnOXloXg43l/lpivbz9P52flfbTVRTWJWakHT4SI3IJsKyQi9VuW62NETY2RJjX1sPK6f6uCzfg02HlJSaMLSnoBLmddi4eVqAh+blUOh18HptmB8faOdEd1+qm6aUGiMaFNRFK/I6WFsZ5PMVWcT6Dc8f6uA/j3bSqSmsSmU8HT5Sl0REmJfrZla2i20nI7x7MkpNZy8rin1cWeTFoSmsSmUkDQpqVJw24boSP4vyPGysD7O5IcJHrTFWlfqZna0prEplGh0+UmMix23nCzOC3FMZxGkTXjnaxUuHO2nRFFalMooGBTWmKrJcPDwvhzVlfhoi/Ty5v531oW5imsKqVEbQ4SM15mwiXB5PYd1cH2FHc4y9bT1cX+JnUb5bU1iVSmPaU1BJ43PYuKU8wINzc8hz23mttpufHOggpCmsSqUtDQoq6ab4HNw7O5vbp2cR7h/kuUMd/OpYF119msKqVLrR4SM1LkSE+XluKk+nsDZFOdjRw9XFPq7QFFal0oYGBTWuXHZh5VQ/i/I9vFkX5venU1jL/FQGNYVVqVTT4SOVErluO3fODHL3rCB2EV4+0sXPD3fSGtMUVqVSSYOCSqkZQRcPV+WwqtRPXbifJ/e182ZdmJ4BTWFVKhV0+EilnF2EK4q8zM918/uGMO81RdlzKsb1U/0synPrkJJS40h7Cipt+J02bi3P4oG52eS47bx6opufHOygPqwprEqNFw0KKu2U+JzcNzub26YH6Ood5CcHO/j18S66+3RISalk0+EjlZZEhIV5HmZnu3inMcp7zVEOtvdyzRQv1YVe7JrCqlRSaE9BpTW33cYNpX7+cF4u0wIONtZHeHJ/O4c7elPdNKUmJA0KKiPkeex8aVY2X5oZBODnRzr5+eEOTsX0rmilxpIOH6mMMivbRUWWkx3NUbY2RvnR/jauLPSyYooXt12vcZQaLQ0KKuPYbcLyYh8L8jz8vj7MtqYou0/1cEOpjwW5msKq1GjopZXKWAGnjc9Nz+L+OdlkuWz8+ng3zx3qoCGiKaxKXSoNCirjTfU7uX9ONreWB2jvGeCZAx28eqKLsKawKnXRdPhITQgiwuJ8D3NzXGxtjLKjKcqBtl6uKfFxeaEHuw4pKZUQ7SmoCcVtt3FTqZ9HqnIo9Tt4sy7MU/vbOdqpKaxKJUKDgpqQ8j0OvjQryF0zgwwMGl463MnLRzpp79EUVqUuRIeP1IQlIlRmu6jIymV7U5S3T0Z4Yl8vVxZ5WVHsw2XXISWlzpbUnoKI3CIiB0SkRkS+dZ59viwie0Vkj4g8n8z2qMnJYRNWTPHx6Pxc5uW4eedklCf2tbHnVAxjTKqbp1RaSVpPQUTswGPAGiAEbB
eRdcaYvcP2mQ38FXCNMaZNRIqS1R6lspx2/qAii6UFHtaHwvzqeDc7W2KsLgswxaedZqUguT2FK4EaY8wRY0wv8CJwx1n7fA14zBjTBmCMaUpie5QCoCzg5P652Xx2WoBTPQP8+EA7vz3RTURTWJVKrKcgIm7gTqBi+GuMMd+9wMtKgdph2yFg+Vn7zIm//1bADnzHGPPbc3z+o8CjAOXl5Yk0WakLsolwWYGVwvpWY4T3m2Psa+/huhIfywo82DSFVU1SifYU/hPrKr8fCA/7Gi0HMBu4AVgLPCEiOWfvZIx53BhTbYypLiwsHIOPVcricdhYXRbgkXk5lPgcrA9ZKazHujSFVU1OiQ6klhljbrnI964Dpg1/j/hjw4WAd40xfcBRETmIFSS2X+RnKTUqBV4Hd88Kcqijlw11YV6s6WROtoubSv3kuO2pbp5S4ybRnsLbIrLoIt97OzBbRGaIiAu4B1h31j6/xOolICIFWMNJRy7yc5QaEyLCnBw3X6vKZWWJj6NdvfxoXxtbGsL0DWqWkpocEu0pXAs8KCJHgR5AAGOMWXy+Fxhj+kXk68DrWPMFTxlj9ojId4Edxph18eduFpG9wADwl8aY1lEcj1Kj5rAJV0/xsTDPzab6CFsbo3zc2sONpX7m5bi0Cqua0CSRPG0RmX6ux40xx8e8RSOorq42O3bsuPgX1jbCkRBcuxTsOhygElfb3ccboW6aogNMCzhYUxagyKsprCqziMj7xpjqkfZLaPgofvLPAf4g/pWTioCgVCpMCzh5cG4On5nmpyU6wNP72/ldbTfRfk1hVRNPQkFBRP4c+ClQFP96TkT+NJkNUyqd2ERYWuDl0fm5LCv0sLMlxr/vbeOD5iiDele0mkAS7QM/Aiw3xoQBROTvgXeAf0lWw5RKR16HjTVlAS7Lt+6K/l0ozM6WGGvKApRnOVPdPKVGLdHsI8GaCD5tIP6YUpNSkdfB2sogn6/IomfA8HxNB7882klHr1ZhVZkt0Z7C08C7IvIf8e3PA08mp0lKZQYRYV6um1nZLt49GWXbyQg1Hb1cVexjebEXp02vm1TmSSgoGGO+LyKbsFJTAR4yxuxMWquUyiBOm3BtiY9F+W421oV5qzHCrlMxbir1MzdbU1hVZrlgUBCRoDGmU0TygGPxr9PP5RljTiW3eUpljmyXnc/PCHK8q5f1oTC/PNrF9ICT1WV+CjWFVWWIkX5TnwduA94HhqdYSHx7ZpLapVTGmp7l4qF5Tna2xNjSEOGp/e0sK/Rw3RQfHocudqjS2wWDgjHmtvi/M8anOUpNDDYRLi/0UpXrZktDhA+aY+w91cP1U/0szndrFVaVthK9T+EaEfHHv79PRL4vIlrDWqkR+Bw2PjMtwINzc8j32PltbTfPHGgn1N2X6qYpdU6J9mX/FYiIyGXAXwCHgWeT1iqlJphin4N7Z2dzR0UWkX7Dc4c6WHesiy5NYVVpJtHZr35jjBGRO4D/Z4x5UkQeSWbDlJpoRISqXDezgi62nYzwblOUQx09rCj2cWWRF4emsKo0kGhQ6BKRvwLuA1aKiA3Q2zeVugQuu7Byqp/F+R7erAuzuSHCrtYYq8r8VAY1hVWlVqLDR3djlcx+xBjTiLVgzj8krVVKZZquMByuhcHEi+TluO18cWaQe2YFcdiEl4908bPDnbTE+pPYUKUuLNGb1xqB7w/bPgH8JFmNUipjDAzCsToInbS2C3IhO3BRb1ERdPFQlpMPmmO81RjhqX3tXF7o4ZoSHx67prCq8TXSzWtvGWOuFZEuznGfgjEmmNTWKZXO2rvg4DGI9kCW3+otXCK7CFcUeVmQ6+b3DWG2N8fY0xZPYc1z65CSGjcj3adwbfzfrPFpjlIZoH/AWrCpoRk8Llg8B4yBjw+N+q19ThufLc9iaYGXN0LdvHaiO16F1U+pX6fxVPIlep/CVSKSNWw7S0SWJ69ZSqWp1nbYsdsKCKXFUL0Acse+wzzF5+C+2d
n8wfQA3X2DPHuwg18f76K7Txf2UcmVaPbRvwLLhm2Hz/GYUhNXXx/U1ELTKfB5YOksCF7c3MHFEhEW5HmYne3mnZMR3muKcrC9l6uneKku1BRWlRyJBgUxwxZzNsYMiohW+FITnzHQ3AY1J6xho+klUF4CtvGbAHbZJV4ew8OGujCb6iN81BpjVWmAymzXuLVDTQ6J/mYfEZE/ExFn/OvPgSPJbJhSKdfTC3sOw74j1tzBsiqoKB3XgDBcrtvOXTODfHlWEEH4xZFOfn64g1MxvStajZ1Ef7v/CLgaqANCwHLg0WQ1SqmUMsaaM9i+B9o6YGYZLK2CgC/VLQNgZtDFI/NyuKnUT213Pz/a38bGujA9AzrfoEYv0fsUmoB7ktwWpVIvGoODx6100+wAzK0AryfVrfoUu0248nQKa32Yd5ui7D4V44apfhZqCqsahUSzj+aIyAYR2R3fXiwi/yu5TVNqHBkDtY2wY691v8Hs6XDZ3LQMCMP5nTZunZ7F/XOyyXbZ+c2Jbp492EFDWKuwqkuT6PDRE8BfAX0AxphdaM9BTRThKOzcb917kJMFVyyEqYWQQVfbU/1Ovjonm8+VB+joHeCZgx385ngXYU1hVRcp0QwinzHmvbO6pFqgRWW2wUE40QgnGsBhh6oZUJiXUcFgOBFhUb6HOTku3m6Msr35zBRWu6awqgQkGhRaRGQW8VIXInIX0JC0VimVbJ3d1txBOApFeTBrGrgmxh3DbruNG0v9XJbvYUNdNxvrI3zU2sPqMj8zg5rCqi4s0aDwX4HHgXkiUgccBe5NWquUSpaBAThWbxWwczlhQSUU5KS6VUmR57HzpVnZ1HT0sqGum58d7qQy6GJVmZ9ctz3VzVNpasSgEF87odoYszq+JKfNGNOV/KYpNcbaOq3eQawHSgphZik4Jv49mJXZLiqyctnRHOXtxig/2tfGFUVeri724bLrkJI604h/EfG7l/878DNjzKWXgVQqVfr74wXsWsDjtgrYJaFeUTpz2ISrin0szPOwqT7MtpNRdp/q4capPubnagqr+kSi2UfrReSbIjJNRPJOf430IhG5RUQOiEiNiHzrAvvdKSJGRKoTbrlSiWhpt25Ca2iBsmKonj/pAsJwAaeN26Zn8dU52QScNn51vJvnDnXQGNG8EWVJtO98N9Yk85+c9fjM871AROzAY8AarLugt4vIOmPM3rP2ywL+HHg30UYrNaLePmsltKZT4PdacwdBf6pblTZK/U4emJPNrlM9/L4+zI8PtHNZvpvrS/z4nLqwz2SWaFCYjxUQrsUKDluAfxvhNVcCNcaYIwAi8iJwB7D3rP3+N/D3wF8m2Balzs8YKxDU1FqTytOnQvmUlNUrSmciwmX5HubmuNjaEOH95hj723u5doqPZYUe7DqkNCkl+pfyDFAF/AD4F6wg8cwIrykFaodth+KPDRGRZcA0Y8xvEmyHUucX64XdNbD/KHjdcPl8qJiqAWEEHruNVWUBHq7KYarPwYa6ME/vb+dYZ2+qm6ZSINGewkJjzPxh2xtF5Owr/osSz2r6PvBgAvs+SrwAX3l5+Wg+Vk1ExlhzBkdqrX7srDJrARy90r0oBR4HX54V5FBHL2/WhXnxcCdzsl3cVOonR1NYJ41EL6E+EJGrTm/EV13bMcJr6oBpw7bL4o+dlgUsBDaJyDHgKmDduSabjTGPG2OqjTHVhYWFCTZZTQqRGHx0AA4dt9ZJrl4AZVM0IFwiEWFOjps/rMplZYmPo129PLGvjc31YXoHzMhvoDJeoj2Fy4G3ReREfLscOCAiHwPGGLP4HK/ZDswWkRlYweAe4CunnzTGdAAFp7dFZBPwTWPMSMFGKat3EDoJx+pAbDBnOkwp0GAwRhw24eopPhbmudlUH+Htk1E+PtXDjaV+qnJcmsI6gSUaFG652Dc2xvSLyNeB1wE78JQxZo+IfBfYYYxZd7HvqRQA3RE4eAy6IpCfA7PLwa3lG5Ih6LJze0
UWSws8rA91s+5YFx/4HawpC1Dsm/g3/k1Gia6ncPxS3twY8yrw6lmP/fV59r3hUj5DTSKDg3C8wSpx7bBD1UwozNXewTiYFnDywNwcdrV+ksK6pMDDdSU+fA6dyJ9INNSrzNDZDQeOWXMIRXlQWQ5O/fUdTzYRlhR4mJfjYktjhA+aY+xr6+G6Eh9LCzzYNDhPCPpXpdLbwAAcrYO6JnA7YWGlNWSkUsbjsLGmLMCSfA/rQ2HeCIX5sCXG6jI/07N0GC/TaVBQ6aut05o7iPVai97MKLOGjVRaKPQ6uKcyyMGOXjbUhXmhppO5OVYKa7ZL/58ylQYFlX76++FwCBpbrJvQLptrrYim0o6IMDfHzcygi/eaorzTGOFwRy/Li71cVezDqQv7ZBwNCiq9tLTBoRNW7aJpU6wyFXadyEx3TptwTTyFdWNdmK2NUT5u7eGmUj9zNYU1o2hQUOmhtw9qTkBzm1XAbmGldTOayijZLjufnxHkRFcfb4S6+eWxLsoDTlaX+Sny6ukmE+j/kkqtoQJ2J2Bg0KpVNE0L2GW68iwnD83L4cOWGJsbIjy9v52l8RRWr6awpjUNCip1Yj1WeYpTnVZZ6zkVVi9BTQg2EZYVeqnKdbOlIcLOFiuFdeVUH5flawprutKgoMafMVDfDEdD8QJ206C0SG9Cm6C8Dhs3TwuwpMDDG6FuXq8Ns7MlxuqyAOUBZ6qbp86iQUGNr0jMSjPt6LYyiuZUWBlGasIr8jr4SmU2+9t72VgX5vlDHVTluLix1E9QU1jThgYFNT6MscpTHKu3sonmVkBxvvYOJhkRoSrXTWW2i20nI7x7MkpNZy9XFftYXuTFoSmsKadBQSVfd8QqUdEdgYIcq0SFFrCb1Jw24boSP4vyPGysD7OlIcKu1hg3lfqZk60prKmkQUElz+AgHK+HE41WnaL5s6wCdkrF5bjtfGFGkGNdvawPhfmPo11UZDlZXeqnQFNYU0J/6io5OrqtuYNIzBommjVNC9ip86rIcvHwPCcftMTY0hDhyf3tXF7o4dopPjyawjqu9K9Uja0zCti5YNFsyMtOdatUBrCJUF3oZX6Om80NEXY0x9jb1sP1JX4W5bs1hXWcaFBQY+dUBxw8Dj0Fzlg1AAASVUlEQVS9MLUIZpRqATt10XxOG7eUWyms60PdvFbbHU9h9VOmKaxJp0FBjV5fPxyuhZOt4PXAkrmQrQXs1OhM8Tm4d3Y2+9p62Vgf5rlDHSzIdXNDqY8sp15sJIsGBTU6zW3WXcl9/VZ5ioqpWqJCjRkRYX7esBTWpigHO3q4utjHFZrCmhQaFNSl6e2zqpm2tEHAC4vmQJYv1a1SE5TLLqyc6mdRvoc368L8viHCR60xVpX5qQxqCutY0qCgLo4x1jDR4VqrgN2MUigr1t6BGhe5bjt3zgxytNNKYX35SBczs5ysKvOT79HT2VjQn6JKXKzHmkhu64RgAOZOB58WsFPjb0bQxcNVTj5ojvFWQ4Qn97VTXeTlmile3Lr+xqhoUFAjMwbqm+BInbVdWW4tj6lddpVCdhGuKPIyP9fN7xvCvNcUZc+pGNdP9bMoz61DSpdIg4K6sEgUDhyHzm7IDcKc6eDRAnYqffidNm4tz2JpgYf1oTCvnrBSWNeU+Znq1xTWi6VBQZ3b4CDUnrTKVGgBO5UBSnxO7pudzZ62HjbVRfjJwQ4W5rm5YaqfgFOHlBKlQUF9WlfYKlHRHYWCXJhdDi694lLpT0RYmOdhdraLdxqjvNcc5WB7L9dM8VJd6MWuKawj0qCgPjEQL2BX22gFAS1gpzKU227jhlI/i/M9bKjrZmN9hI9ae1hV6mdWtlbovRANCsrS0WWVt472wJR8mKkF7FTmy/PY+dKsbA539LK+rpufH+lkVtDJqtIAeR69K/pc9K9+susfsJbFrG8GjxawUxPTrGwXFVm57GiOsrUxyo/2t3FloZcVmsL6KRoUJrPhBexK4wXs7Hr1pCYmu0
1YXuxjQZ6HTfVhtjVF2X2qhxtKfSzI1RTW0zQoTEbDC9j5PLBkHmQHUt0qpcZFwGnjtulZLCvw8EYozK+Pf1KFtcSnCRUaFCYTY6xaRYdOWMNG5SUwvURLVKhJaarfyf1zsvn4VA+b6sM8c6CDxfluri/x45/EKaxJPXIRuUVEDohIjYh86xzPf0NE9orILhHZICLTk9meSa2nF/Ychr1HrMVvllVZw0UaENQkJiIszvfw6Pxcrizysru1h8f3tvFeU5QBY1LdvJRIWk9BROzAY8AaIARsF5F1xpi9w3bbCVQbYyIi8sfA/wXuTlabJiVjoDFewM7EC9hNm6I3oSk1jMdu46ZSP5flu1kfCvNmXZiPWmOsLvUzIzi5UliTeZl4JVBjjDlijOkFXgTuGL6DMWajMSYS39wGlCWxPZNPtAd2HbRuRAt44fIF1pCRBgSlzinf4+DLs4LcOTOLgUHDS4c7eflIJ+09A6lu2rhJ5pxCKVA7bDsELL/A/o8Ar53rCRF5FHgUoLy8fKzaN3EZY62RfLQOBOuO5BItYKdUIkSE2dluZmS52N4U5e2TEZ7Y18uVRV5WFPtw2Sf231FaTDSLyH1ANXD9uZ43xjwOPA5QXV09OQf6EhWOWjehdYUhLwizK6z7D5RSF8VhE1ZM8bEwz82m+gjvnIynsE71MX8Cp7AmMyjUAdOGbZfFHzuDiKwG/idwvTGmJ4ntmdgGB63yFMcbrHsN5s2AojztHSg1SlkuO39QYVVhfSPUza+GUlgDTPGlxXX1mErmEW0HZovIDKxgcA/wleE7iMhS4N+BW4wxTUlsy8TWFbZ6B+GoVauoUgvYKTXWygJOHpibw8etPWxqCPPjA+0syfewssSHbwKlsCYtKBhj+kXk68DrgB14yhizR0S+C+wwxqwD/gEIAD+Pd8VOGGNuT1abJpyBQThWB6GTVhBYMMuqaqqUSgqbCJcVeJib4+KtxgjvN8fY197DdSU+lhV4sE2AnnlS+z7GmFeBV8967K+Hfb86mZ8/obV3WVlF0R6YUgCzysAx8bqySqUjj8PG6rIAS/I9rK8Lsz4U5sP4XdEVWZk9h6dnkUzTPwBHQtDQbK2AtniOtSKaUmrcFXgd3D0ryKGOXjbUhXmxppM52S5uKvWT487MOmIaFDJJazscOg49fVBWDBVTtYCdUikmIszJcTMj6OK9pijbTkb40b5elhd7uarYhzPDFvbRoJAJ+vqgphaaTlkF7JbOgqAWsFMqnThtwjVTfCzKc7OxLszWxigft/ZwY6mfeTmujElh1aCQzoyB5jaoiRewm15i3ZGs9YqUSltBl507ZgRZ2t3H+lA3/3msiw8CDtaUBSjypv8pN/1bOFn19FrVTFvbIcsHcyog4Et1q5RSCSoPOHlwbg4ftcbYXB/h6f3tLC3wcF2JD68jfS/sNCikG2OgsQUOh6wCdjPLrPmDDOl6KqU+YRNhaYGXeTlutjRE2NkSY29bDytLfCxJ0xRWDQrpJBqzVkJr77IWvZlbAV5PqlullBolr8PGzdMCLCnwsD4U5nehMDtbYqwpC1CelV43mmpQSAfGWDegHau3egSzp0NJgfYOlJpgirwO1lYGOdDey5t1YZ6v6WBejosbS/1ku9Ijk1CDQqqdUcAuG+ZMtxbBUUpNSCLCvFw3s7JdvHvSSmGt6ejlqmIfy4u9KU9h1aCQKoODcKIRTjSAww5VM6BQC9gpNVk4bcK1JT4W5VsprG81Rth1KsZNpX7mZqcuhVWDQip0dltzB+GoVcl01jQtYKfUJJXtsvP5GUGOd/WyPhTml0e7mB5wsrrMT2EKUlg1KIyngQFr3mCogF0lFOSkulVKqVQwxhox6BuA/n6mD/TzUIFhZ4dhS3cvT+3vZZmjn+tsETwD1j5MK7EqISeRBoXx0tZp9Q5iPdYqaDNLtYCdUhPB4KB1c2lfv3XiHvp+4NzbfcMeN2euGWYDLgeqxMYWdw4fmCz2ksX1hF
nsANs4zDfoWSnZ+vvjBexarAJ2l82BHC1gp1RaMebTJ/azT+DnO9EPDl74vR32+JfD+jfg+uR7hx2cjk+2nda/PoeDz9htLIkO8Eaom9+Gbex02bnZ46c0yT8KDQrJ1BIvYNerBeyUSjpjrCHa/oFhJ/OzTuzn2x4YuPB722zgHHZi97ghy//JyX74c2dvj2LCuNjn4N7Z2eyPp7B2940QgMaABoVk6I0XsGs+BX6vNXcQ9Ke6VUqlv9Pj7CMNwZxvSOZCRM64GsfltP4+h1/FO89zYk9hvTERoSrXTWW2C8c4JCRpUBhLxliVTGtqrSuP6VOhfIoWsFOTz+kT+0hX6uc60Z81zv4pQyfs+Enb4z7zBD78xD5822bL6JTv8bp/QYPCWIn1WkNFpzqsbuXcCusqRKlMdXqcfaTJ0v7+Yc/F9xtpnN1uP/OE7fNeeAjm9LZ9dMMxamQaFEbLGGsS+UgtGKxlMUu1gJ1KE8ZYa3knNAQz7MR++vkLsdnOPLF73Oc+oZ9rSEb/PtKWBoXRiMSsdZI7uiEnyypv7XWnulVqIhocPP+V+hkn+rOv4j+d9niG0+PsjmHj7D7PyJOnTocOi05QGhQuxVABuzoQm1WvaIoWsFMjMGaEIZjzDcn0w+BI4+xnnbg9rgtfqQ8/sevvrRpGg8LF6o5YvYOuCOTnwOxyLWCnPnEkZF1pn+vKfWCkcXbbmVfmp6/Yzzt5OizXXU/saoxoUEjU4CAcb4DaxngBu5nW7eb6x6jAujAQsepaxXo+OYG7XeC/0OTpsBO7DseoNKBBIRGd3VZ560jMKmBXWW79MSt1mt8L1y3TiwSV8fTMdiEDA3C0DuqawO2EhZXWkJFS56IBQU0AGhTOp63TmjuI9cLUQphRZnXxlVJqAtOgcLb+fjgcgsYWK730srlWuqlSSk0CGhSGa2mDQyes2kXTplhlKuw6+aeUmjw0KEC8gN0JaG6zJgwXVlqlKpRSapKZ3EFhqIDdCSuHvGKq1UPQ1ECl1CSV1LOfiNwiIgdEpEZEvnWO590i8lL8+XdFpCKZ7TlDrAd2H4L9R62bhC6fbw0XaUBQSk1iSespiIgdeAxYA4SA7SKyzhizd9hujwBtxphKEbkH+Hvg7mS1CbCK1tU1wdFQvIDdNCgt0nRCpZQiuT2FK4EaY8wRY0wv8CJwx1n73AE8E//+F8AqkSSfnXcdtIaLggG4YoG1IpoGBKWUApI7p1AK1A7bDgHLz7ePMaZfRDqAfKAlaa2Kxqy1DorzNRgopdRZMmKiWUQeBR4FKC8vv7Q3Kcixqk6WFmkBO6WUOo9kDh/VAdOGbZfFHzvnPiLiALKB1rPfyBjzuDGm2hhTXVhYeGmt8XpgZpkGBKWUuoBkBoXtwGwRmSEiLuAeYN1Z+6wDHoh/fxfwpjEjLdCqlFIqWZI2fBSfI/g68DpgB54yxuwRke8CO4wx64AngWdFpAY4hRU4lFJKpUhS5xSMMa8Cr5712F8P+z4GfCmZbVBKKZU4vVNLKaXUEA0KSimlhmhQUEopNUSDglJKqSEaFJRSSg2RTLstQESageOX+PICkllCIz3pMU8OesyTw2iOeboxZsS7fzMuKIyGiOwwxlSnuh3jSY95ctBjnhzG45h1+EgppdQQDQpKKaWGTLag8HiqG5ACesyTgx7z5JD0Y55UcwpKKaUubLL1FJRSSl3AhAwKInKLiBwQkRoR+dY5nneLyEvx598VkYrxb+XYSuCYvyEie0Vkl4hsEJHpqWjnWBrpmIftd6eIGBHJ+EyVRI5ZRL4c/7/eIyLPj3cbx1oCv9vlIrJRRHbGf79vTUU7x4qIPCUiTSKy+zzPi4j8IP7z2CUiy8a0AcaYCfWFVab7MDATcAEfAfPP2udPgH+Lf38P8FKq2z0Ox3wj4It//8eT4Zjj+2UBm4FtQHWq2z0O/8+zgZ1Abny7KNXtHodjfhz44/j384FjqW73KI95JbAM2H2e528FXgMEuA
p4dyw/fyL2FK4EaowxR4wxvcCLwB1n7XMH8Ez8+18Aq0QyesHmEY/ZGLPRGBOJb27DWgkvkyXy/wzwv4G/B2Lj2bgkSeSYvwY8ZoxpAzDGNI1zG8daIsdsgGD8+2ygfhzbN+aMMZux1pc5nzuAnxjLNiBHRErG6vMnYlAoBWqHbYfij51zH2NMP9AB5I9L65IjkWMe7hGsK41MNuIxx7vV04wxvxnPhiVRIv/Pc4A5IrJVRLaJyC3j1rrkSOSYvwPcJyIhrPVb/nR8mpYyF/v3flGSusiOSj8ich9QDVyf6rYkk4jYgO8DD6a4KePNgTWEdANWb3CziCwyxrSntFXJtRb4sTHmH0VkBdZqjguNMYOpblgmmog9hTpg2rDtsvhj59xHRBxYXc7WcWldciRyzIjIauB/ArcbY3rGqW3JMtIxZwELgU0icgxr7HVdhk82J/L/HALWGWP6jDFHgYNYQSJTJXLMjwA/AzDGvAN4sGoETVQJ/b1fqokYFLYDs0Vkhoi4sCaS1521zzrggfj3dwFvmvgMToYa8ZhFZCnw71gBIdPHmWGEYzbGdBhjCowxFcaYCqx5lNuNMTtS09wxkcjv9i+xegmISAHWcNKR8WzkGEvkmE8AqwBEpAorKDSPayvH1zrg/ngW0lVAhzGmYazefMINHxlj+kXk68DrWJkLTxlj9ojId4Edxph1wJNYXcwarAmde1LX4tFL8Jj/AQgAP4/PqZ8wxtyeskaPUoLHPKEkeMyvAzeLyF5gAPhLY0zG9oITPOa/AJ4Qkf+GNen8YCZf5InIC1iBvSA+T/JtwAlgjPk3rHmTW4EaIAI8NKafn8E/O6WUUmNsIg4fKaWUukQaFJRSSg3RoKCUUmqIBgWllFJDNCgopZQaokFBqXEkIhWnq1+KyA0i8utUt0mp4TQoKJWA+I1C+veiJjz9JVfqPOJX9QdE5CfAbuCrIvKOiHwgIj8XkUB8vytE5G0R+UhE3hORrPhrt8T3/UBErk7t0SiVmAl3R7NSY2w2VkmUGuAVYLUxJiwi/wP4hoj8H+Al4G5jzHYRCQJRoAlYY4yJichs4AWsQoRKpTUNCkpd2HFjzDYRuQ1rAZet8TIhLuAdYC7QYIzZDmCM6QQQET/w/0RkCVa5iTmpaLxSF0uDglIXFo7/K8Abxpi1w58UkUXned1/A04Cl2EN006ERX7UJKBzCkolZhtwjYhUgtUTEJE5wAGgRESuiD+eNawce0O8pv9XsYq5KZX2NCgolQBjTDPWgj0viMgurKGjefElIu8G/kVEPgLewCrd/EPggfhj8/ikx6FUWtMqqUoppYZoT0EppdQQDQpKKaWGaFBQSik1RIOCUkqpIRoUlFJKDdGgoJRSaogGBaWUUkM0KCillBry/wEYMY492EdtVwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x106626ba8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(recall_1, precision_1, \"pink\", label=\"Model_1\")\n",
"plt.plot(recall_2, precision_2, \"skyblue\", label=\"Model_2\")\n",
"plt.xlabel('recall')\n",
"plt.ylabel('precision')\n",
"plt.legend()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use [step_plot](http://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html#plot-the-precision-recall-curve) to plot PR-curves\n",
"\n",
"This code is found on the sklearn documentation.\n",
"\n",
"Especially, we notice we are using the `post` argument. From the [`step()`](https://matplotlib.org/api/_as_gen/matplotlib.pyplot.step.html) documentation:\n",
"\n",
 'post':">
"> 'post': The y value is continued constantly to the right from every x position, i.e. the interval `[x[i], x[i+1])` has the value `y[i]`.\n",
"\n",
"Because we only have two points in order: $\\left(1, 0.1\\right), \\left(0, 1\\right)$, the interval between $x \\in [0, 1)$ will have the same y value $0.1$."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x10663b0f0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.utils.fixes import signature\n",
"\n",
"# In matplotlib < 1.5, plt.fill_between does not have a 'step' argument\n",
"step_kwargs = ({'step': 'post'}\n",
" if 'step' in signature(plt.fill_between).parameters\n",
" else {})\n",
"\n",
"plt.step(recall_1, precision_1, color='blue', alpha=0.2, where='post')\n",
"plt.fill_between(recall_1, precision_1, alpha=0.2, color='blue', label=\"Model_1\", **step_kwargs)\n",
"\n",
"plt.step(recall_2, precision_2, color='red', alpha=0.2, where='post')\n",
"plt.fill_between(recall_2, precision_2, alpha=0.2, color='red', label=\"Model_2\", **step_kwargs)\n",
"\n",
"plt.xlabel('Recall')\n",
"plt.ylabel('Precision')\n",
"plt.ylim([0.0, 1.05])\n",
"plt.xlim([0.0, 1.0])\n",
"plt.legend()\n",
"plt.title('Precision-Recall curve: AP_model_1={:.3f}, AP_Model_2={:.3f}'.format(model_1_ap,\n",
" model_2_ap))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Computing AP using 1D Array vs. 2D Array"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20 20 20\n",
"Model_1 AP = 0.30556\n",
"Model_2 AP = 0.10000\n"
]
}
],
"source": [
"model_1_score = [0.4325, 0.3498, 0.2368, 0.2601, 0.1698, 0.211 , 0.1913, 0.8441,\n",
" 0.098 , 0.1682, 0.1844, 0.3937, 0.1746, 0.295 , 0.3164, 0.1856,\n",
" 0.1353, 0.2147, 0.1898, 0.2257]\n",
"model_2_score = [0.502, 0.502, 0.502, 0.502, 0.502, 0.502, 0.502, 0.502,\n",
" 0.502, 0.502, 0.502, 0.502, 0.502, 0.502, 0.502, 0.502,\n",
" 0.502, 0.502, 0.502, 0.502]\n",
"y_true = [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"\n",
"print(len(model_1_score), len(model_2_score), len(y_true))\n",
"\n",
"model_1_ap = metrics.average_precision_score(y_true, model_1_score)\n",
"model_2_ap = metrics.average_precision_score(y_true, model_2_score)\n",
"print(\"Model_1 AP = {:.5f}\".format(model_1_ap))\n",
"print(\"Model_2 AP = {:.5f}\".format(model_2_ap))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0 1]\n",
" [1 0]\n",
" [1 0]\n",
" [1 0]\n",
" [1 0]]\n",
"[[0.5675 0.4325]\n",
" [0.6502 0.3498]\n",
" [0.7632 0.2368]\n",
" [0.7399 0.2601]\n",
" [0.8302 0.1698]]\n",
"[[0.498 0.502]\n",
" [0.498 0.502]\n",
" [0.498 0.502]\n",
" [0.498 0.502]\n",
" [0.498 0.502]]\n"
]
}
],
"source": [
"y_true_2d = np.stack([[0, 1] if i else [1, 0] for i in y_true], axis=0)\n",
"model_1_score_2d = np.stack([[1-i, i] for i in model_1_score], axis=0)\n",
"model_2_score_2d = np.stack([[1-i, i] for i in model_2_score], axis=0)\n",
"\n",
"print(y_true_2d[:5])\n",
"print(model_1_score_2d[:5])\n",
"print(model_2_score_2d[:5])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model_1 AP = 0.60384\n",
"Model_2 AP = 0.50000\n"
]
}
],
"source": [
"model_1_ap_2d = metrics.average_precision_score(y_true_2d, model_1_score_2d)\n",
"model_2_ap_2d = metrics.average_precision_score(y_true_2d, model_2_score_2d)\n",
"\n",
"print(\"Model_1 AP = {:.5f}\".format(model_1_ap_2d))\n",
"print(\"Model_2 AP = {:.5f}\".format(model_2_ap_2d))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Feeding `average_precision_score()` with the raw 2D probability(score) gives different AP value.\n",
"\n",
"The documentation says that `y_true` and `y_score` can have shape `[n_samples, n_classes]`. However, it does not mean `average_precision_score()` expects the 2D array for binary cases."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model_1 AP = [0.90212363 0.30555556]\n",
"Model_2 AP = [0.9 0.1]\n"
]
}
],
"source": [
"model_1_ap_2d = metrics.average_precision_score(y_true_2d, model_1_score_2d, average=None)\n",
"model_2_ap_2d = metrics.average_precision_score(y_true_2d, model_2_score_2d, average=None)\n",
"\n",
"print(\"Model_1 AP = {}\".format(model_1_ap_2d))\n",
"print(\"Model_2 AP = {}\".format(model_2_ap_2d))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`average_precision_score()` is treating each column as one binary class, we can get the same AP for the positive column if not using `average`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Unique Score and PR and AUC\n",
"\n",
"Sometimes the results of PR and AUC can be very confusing if our `y_score` has only one unique element."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted score: [0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15]\n",
"Predicted label: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"True label: [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n"
]
}
],
"source": [
"model_1_score = [0.15 for i in range(20)]\n",
"model_1_predict_label = [1 if i >= 0.5 else 0 for i in model_1_score]\n",
"\n",
"y_true = [0] + [1 for i in range(19)]\n",
"\n",
"print(\"Predicted score: {}\".format(model_1_score))\n",
"print(\"Predicted label: {}\".format(model_1_predict_label))\n",
"print(\"True label: {}\".format(y_true))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Acc = 0.05, AP = 0.95, AUCROC= 0.5\n"
]
}
],
"source": [
"print(\"Acc = {}, AP = {}, AUCROC= {}\".format(\n",
" metrics.accuracy_score(y_true, model_1_predict_label),\n",
" metrics.average_precision_score(y_true, model_1_score),\n",
" metrics.roc_auc_score(y_true, model_1_score)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even through the accuracy is extremely bad (as expected), we still get high AP and AUCROC."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"precisions = [0.95 1. ]\n",
"recalls = [1. 0.]\n",
"thresholds = [0.15]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x1067afe10>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"precisions, recalls, thresholds_pr = metrics.precision_recall_curve(y_true, model_1_score)\n",
"print(\"precisions = {}\".format(precisions))\n",
"print(\"recalls = {}\".format(recalls))\n",
"print(\"thresholds = {}\".format(thresholds_pr))\n",
"\n",
"# In matplotlib < 1.5, plt.fill_between does not have a 'step' argument\n",
"step_kwargs = ({'step': 'post'}\n",
" if 'step' in signature(plt.fill_between).parameters\n",
" else {})\n",
"\n",
"plt.step(recalls, precisions, color='blue', alpha=0.2, where='post')\n",
"plt.fill_between(recalls, precisions, alpha=0.2, color='blue', label=\"Model_1\", **step_kwargs)\n",
"\n",
"\n",
"plt.xlabel('Recall')\n",
"plt.ylabel('Precision')\n",
"plt.ylim([0.0, 1.05])\n",
"plt.xlim([0.0, 1.0])\n",
"plt.legend()\n",
"plt.title('Precision-Recall curve: AP={:.3f}'.format(metrics.average_precision_score(y_true, model_1_score)))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because there is only one unique prediction score, we only get one threshold which is exactly that unique score. At this threshold, all scores ($0.15$) are translated to the positive label ($1$). Since there is only one negative label in `y_true`, the precision at this threshold is $\\frac{\\text{TP}}{\\text{TP} + \\text{FP}} = \\frac{19}{19+1} = 0.95$\n",
"\n",
"Therefore, the AP score is simply given by $\\left(1 - 0\\right) \\times \\text{Precision} = 0.95$.\n",
"\n",
"This tells us when `y_score` has only one unique value, PR curve and AP only depend on the number of positives in `y_true`. Given this condition,\n",
"\n",
"$$\\text{Precision} = \\frac{\\text{TP}}{\\text{TP} + \\text{FP}} = \\frac{N_p}{N_p + \\left(N - N_p\\right)} = \\frac{N_p}{N}$$\n",
"\n",
"Interestingly, \n",
"\n",
"$$\\text{ACC} = \\frac{\\text{TP} + \\text{TN}}{\\text{N}}$$\n",
"\n",
"If the value of `y_score` $\\geq 0.5$, then all predictions in ACC becomes $1$.\n",
"\n",
"$$\\text{ACC} = \\frac{N_p}{N} = \\text{Precision}$$\n",
"\n",
"If the value of `y_score` $< 0.5$, then all predictions in ACC becomes $0$.\n",
"\n",
"$$\\text{ACC} = \\frac{N_n}{N} = 1 - \\frac{N_p}{N} = 1 - \\text{Precision}$$\n",
"\n",
"$N$ is the size of `y_true`, $N_p$ is the number of positives in `y_true`, and $N_n$ is number of negatives in `y_true`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FP Rates = [0.95 1. ]\n",
"TP Rates = [1. 0.]\n",
"thresholds = [0.15]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x101a40710>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fprs, tprs, thresholds_roc = metrics.roc_curve(y_true, model_1_score)\n",
"print(\"FP Rates = {}\".format(precisions))\n",
"print(\"TP Rates = {}\".format(recalls))\n",
"print(\"thresholds = {}\".format(thresholds_pr))\n",
"\n",
"# In matplotlib < 1.5, plt.fill_between does not have a 'step' argument\n",
"step_kwargs = ({'step': 'post'}\n",
" if 'step' in signature(plt.fill_between).parameters\n",
" else {})\n",
"\n",
"plt.plot(fprs, tprs)\n",
"plt.fill_between(fprs, tprs, alpha=0.2, color='blue', label=\"Model_1\")\n",
"\n",
"plt.xlabel('FP Rates')\n",
"plt.ylabel('TP Rates')\n",
"plt.ylim([0.0, 1.05])\n",
"plt.xlim([0.0, 1.0])\n",
"plt.legend()\n",
"plt.title('ROC curve: AUCROC={:.3f}'.format(metrics.roc_auc_score(y_true, model_1_score)))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly to the PR curve, we only have one threshold $\\left(0.15\\right)$. The threshold $1.15$ is arbitrarily added as $\\max\\left(\\text{y_score}\\right) + 1$.\n",
"\n",
"At this threshold, all predictions are positive. Therefore,\n",
"\n",
"$$\\text{TPR} = \\frac{\\text{TP}}{\\text{TP} + \\text{FN}} = \\frac{19}{19 + 0} = 1$$\n",
"\n",
"$$\\text{FPR} = \\frac{\\text{FP}}{\\text{FP} + \\text{TN}} = \\frac{1}{1 + 0} = 1$$\n",
"\n",
"Then, sklearn added an arbitrary point $\\left(0, 0\\right)$ to make sure the curve starts from there (see the [code](https://github.com/scikit-learn/scikit-learn/blob/bac89c253b35a8f1a3827389fbee0f5bebcbc985/sklearn/metrics/ranking.py#L633), where our `tps = [19]` and `fps = [1]`).\n",
"\n",
"Eventually, we always can get a line linking $\\left(0,0\\right), \\left(1,1\\right)$ , and its AUC approximation is 0.5.\n",
"\n",
"This result holds as long as `y_score` has only one unique value, regardless of the values of `y_true` and `y_score`."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment