Skip to content

Instantly share code, notes, and snippets.

@oguiza
Created May 3, 2020 18:30
Show Gist options
  • Save oguiza/6165459fae6e8e59880bb125abea970b to your computer and use it in GitHub Desktop.
Save oguiza/6165459fae6e8e59880bb125abea970b to your computer and use it in GitHub Desktop.
Proba_Metrics.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
},
"colab_type": "code",
"id": "JIpVVG_gMx50",
"outputId": "13e48141-ffda-4d3f-ad03-f2d7dd4042cf",
"trusted": false
},
"cell_type": "code",
"source": "# Install fastcore and fastai2 straight from their GitHub master branches (unpinned: results may drift as master changes)\n!pip install -q git+https://github.com/fastai/fastcore.git@master\n!pip install -q git+https://github.com/fastai/fastai2.git@master",
"execution_count": 1,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": " Building wheel for fastcore (setup.py) ... \u001b[?25l\u001b[?25hdone\n Building wheel for fastai2 (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
}
]
},
{
"metadata": {
"colab": {},
"colab_type": "code",
"id": "F-vR9U_vRzbB",
"trusted": false
},
"cell_type": "code",
"source": "from fastai2.vision.all import *\nfrom fastai2.metrics import *  # fastai2 metrics (accuracy, APScore, RocAuc); `fastai.metrics` is the v1 package and can shadow these\nfrom sklearn import metrics as skm\nset_seed(2)  # seed the RNGs so the training run is repeatable",
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 16
},
"colab_type": "code",
"id": "uajPS5RlxjrI",
"outputId": "d6efb30d-43c6-4869-b658-bbf24303aa6f",
"trusted": false
},
"cell_type": "code",
"source": "# Fetch and extract the MNIST_TINY dataset, then build train/valid DataLoaders from its folder layout\npath = untar_data(URLs.MNIST_TINY)\ndls = ImageDataLoaders.from_folder(path)",
"execution_count": 14,
"outputs": [
{
"data": {
"text/html": "",
"text/plain": "<IPython.core.display.HTML object>"
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 197
},
"colab_type": "code",
"id": "IYb6cStePQ9c",
"outputId": "0ca63966-8415-4154-9b4c-1d4cb6722b3e",
"trusted": false
},
"cell_type": "code",
"source": "# Train a resnet18 from scratch; APScore() and RocAuc() are logged every epoch and re-checked with sklearn below\nlearn = cnn_learner(\n    dls, resnet18, pretrained=False,\n    metrics=[accuracy, APScore(), RocAuc()],\n)\nlearn.fit_one_cycle(5, 0.1)  # 5 epochs, max learning rate 0.1",
"execution_count": 18,
"outputs": [
{
"data": {
"text/html": "<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>epoch</th>\n <th>train_loss</th>\n <th>valid_loss</th>\n <th>accuracy</th>\n <th>average_precision_score</th>\n <th>roc_auc_score</th>\n <th>time</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <td>0</td>\n <td>1.592153</td>\n <td>279784.062500</td>\n <td>0.505007</td>\n <td>0.505007</td>\n <td>0.500000</td>\n <td>00:01</td>\n </tr>\n <tr>\n <td>1</td>\n <td>0.858801</td>\n <td>360.128876</td>\n <td>0.494993</td>\n <td>0.505007</td>\n <td>0.500000</td>\n <td>00:01</td>\n </tr>\n <tr>\n <td>2</td>\n <td>0.638751</td>\n <td>20.427900</td>\n <td>0.590844</td>\n <td>0.598958</td>\n <td>0.594901</td>\n <td>00:01</td>\n </tr>\n <tr>\n <td>3</td>\n <td>0.585736</td>\n <td>0.060126</td>\n <td>0.991416</td>\n <td>0.983287</td>\n <td>0.991329</td>\n <td>00:01</td>\n </tr>\n <tr>\n <td>4</td>\n <td>0.457349</td>\n <td>0.040091</td>\n <td>0.995708</td>\n <td>0.991573</td>\n <td>0.995665</td>\n <td>00:01</td>\n </tr>\n </tbody>\n</table>",
"text/plain": "<IPython.core.display.HTML object>"
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 16
},
"colab_type": "code",
"id": "JVldLRaEM8sE",
"outputId": "ac5857f3-ca67-4681-c9d1-af570df616b9",
"trusted": false
},
"cell_type": "code",
"source": "# with_decoded=True makes get_preds also return the decoded class predictions as a third element\nvalid_probas, valid_targets, valid_preds = learn.get_preds(\n    dl=dls.valid, with_decoded=True)",
"execution_count": 22,
"outputs": [
{
"data": {
"text/html": "",
"text/plain": "<IPython.core.display.HTML object>"
},
"metadata": {
"tags": []
},
"output_type": "display_data"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 70
},
"colab_type": "code",
"id": "BaG-LDJwQncp",
"outputId": "95274aea-5969-4131-cb36-43db6c33ef97",
"trusted": false
},
"cell_type": "code",
"source": "# sklearn reference values: accuracy from the hard predictions, APScore/RocAuc from the class-1 probabilities (valid_probas[:, 1])\nfor label, score_fn, scores in (('accuracy', skm.accuracy_score, valid_preds),\n                                ('avg precision', skm.average_precision_score, valid_probas[:, 1]),\n                                ('roc auc', skm.roc_auc_score, valid_probas[:, 1])):\n    print(f'{label} : {score_fn(valid_targets, scores):8.6f}')",
"execution_count": 20,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "accuracy : 0.995708\navg precision : 0.999944\nroc auc : 0.999943\n"
}
]
},
{
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 70
},
"colab_type": "code",
"id": "8hgemM8exHj9",
"outputId": "6883a657-7e61-42d4-ef67-e313ffc13598",
"trusted": false
},
"cell_type": "code",
"source": "# Same metrics computed from the hard predictions (valid_preds) only - these are the values reported during training\nfor label, score_fn in (('accuracy', skm.accuracy_score),\n                        ('avg precision', skm.average_precision_score),\n                        ('roc auc', skm.roc_auc_score)):\n    print(f'{label} : {score_fn(valid_targets, valid_preds):8.6f}')",
"execution_count": 23,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "accuracy : 0.995708\navg precision : 0.991573\nroc auc : 0.995665\n"
}
]
},
{
"metadata": {
"colab_type": "text",
"id": "J3KNheAIyduc"
},
"cell_type": "markdown",
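"source": "The metrics reported during training match the values computed from the hard (decoded) predictions, `valid_preds`. That is correct for accuracy, but APScore and RocAuc must be computed from the predicted probabilities (`valid_probas`), so the reported values for those two metrics are wrong."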
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "Proba_Metrics.ipynb",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"base_numbering": 1,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
},
"language_info": {
"name": "python",
"version": "3.7.3",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"gist": {
"id": "",
"data": {
"description": "Proba_Metrics.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment