Last active
March 28, 2021 03:37
-
-
Save ShokuninSan/f2696379614933913e68b0327e2fd71d to your computer and use it in GitHub Desktop.
pymc3.10-classification-using-BART.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "pymc3.10-classification-using-BART.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyPKjf9aieEeS571emG2sBTJ", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ShokuninSan/f2696379614933913e68b0327e2fd71d/pymc3-10-bart-experiments.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Vi0lNfG_mLqQ", | |
"outputId": "584563bf-2ca8-4fd0-d7e3-7ce026ccd4a3" | |
}, | |
"source": [ | |
"!pip install pymc3==3.10" | |
], | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: pymc3==3.10 in /usr/local/lib/python3.6/dist-packages (3.10.0)\n", | |
"Requirement already satisfied: scipy>=0.18.1 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.4.1)\n", | |
"Requirement already satisfied: theano-pymc==1.0.11 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.0.11)\n", | |
"Requirement already satisfied: typing-extensions>=3.7.4 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (3.7.4.3)\n", | |
"Requirement already satisfied: numpy>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.18.5)\n", | |
"Requirement already satisfied: arviz>=0.9.0 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (0.10.0)\n", | |
"Requirement already satisfied: patsy>=0.5.1 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (0.5.1)\n", | |
"Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (0.3.3)\n", | |
"Requirement already satisfied: pandas>=0.18.0 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.1.5)\n", | |
"Requirement already satisfied: fastprogress>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.0.0)\n", | |
"Requirement already satisfied: contextvars; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (2.4)\n", | |
"Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (0.8)\n", | |
"Requirement already satisfied: matplotlib>=3.0 in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (3.2.2)\n", | |
"Requirement already satisfied: xarray>=0.16.1 in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (0.16.2)\n", | |
"Requirement already satisfied: netcdf4 in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (1.5.5)\n", | |
"Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (20.7)\n", | |
"Requirement already satisfied: setuptools>=38.4 in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (50.3.2)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from patsy>=0.5.1->pymc3==3.10) (1.15.0)\n", | |
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.18.0->pymc3==3.10) (2.8.1)\n", | |
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.18.0->pymc3==3.10) (2018.9)\n", | |
"Requirement already satisfied: immutables>=0.9 in /usr/local/lib/python3.6/dist-packages (from contextvars; python_version < \"3.7\"->pymc3==3.10) (0.14)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.0->arviz>=0.9.0->pymc3==3.10) (0.10.0)\n", | |
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.0->arviz>=0.9.0->pymc3==3.10) (1.3.1)\n", | |
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.0->arviz>=0.9.0->pymc3==3.10) (2.4.7)\n", | |
"Requirement already satisfied: cftime in /usr/local/lib/python3.6/dist-packages (from netcdf4->arviz>=0.9.0->pymc3==3.10) (1.3.0)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aI2ce03DDc9u" | |
}, | |
"source": [ | |
"import theano\n", | |
"import numpy as np\n", | |
"import pymc3 as pm\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"from sklearn.metrics import classification_report\n", | |
"from sklearn.datasets import load_breast_cancer\n", | |
"from sklearn.model_selection import train_test_split" | |
], | |
"execution_count": 23, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zxqywt4lah7Q" | |
}, | |
"source": [ | |
"X, Y = load_breast_cancer(return_X_y=True)" | |
], | |
"execution_count": 24, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "jy_YzYUboKWp" | |
}, | |
"source": [ | |
"X_train, X_test, Y_train, Y_test = train_test_split(X, Y, stratify=Y)" | |
], | |
"execution_count": 25, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Hn8t_culayDW", | |
"outputId": "66c68b36-4575-4e95-def6-2ea99812242f" | |
}, | |
"source": [ | |
"X_train.shape, Y_train.shape, X_test.shape, Y_test.shape" | |
], | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"((426, 30), (426,), (143, 30), (143,))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 26 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "JmnrHgOrwpjV" | |
}, | |
"source": [ | |
"## Using `BART`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VJABwZPvEtu6" | |
}, | |
"source": [ | |
"X_shared = theano.shared(X_train)" | |
], | |
"execution_count": 27, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 196 | |
}, | |
"id": "Ic6uQSGPaz11", | |
"outputId": "0dccd868-5d5b-41b0-d133-42e333bb6e1f" | |
}, | |
"source": [ | |
"with pm.Model() as model:\n", | |
" \n", | |
" x = pm.BART('x', X_shared.get_value(), Y_train)\n", | |
" y = pm.Bernoulli('y', p=pm.math.sigmoid(x), observed=Y_train)\n", | |
"\n", | |
" trace = pm.sample()" | |
], | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/pymc3/sampling.py:468: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n", | |
" FutureWarning,\n", | |
"The BART model is experimental. Use with caution.\n", | |
"Sequential sampling (2 chains in 1 job)\n", | |
"PGBART: [x]\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" <style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
" </style>\n", | |
" <progress value='2000' class='' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" 100.00% [2000/2000 05:08<00:00 Sampling chain 0, 0 divergences]\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" <style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
" </style>\n", | |
" <progress value='2000' class='' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" 100.00% [2000/2000 05:08<00:00 Sampling chain 1, 0 divergences]\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 617 seconds.\n", | |
"The estimated number of effective samples is smaller than 200 for some parameters.\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RaCv3m3iFuaX" | |
}, | |
"source": [ | |
"X_shared.set_value(X_test)" | |
], | |
"execution_count": 29, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 37 | |
}, | |
"id": "-ZB5MU5ltkHv", | |
"outputId": "39018170-798a-4d8d-941a-0ea14c60cdd1" | |
}, | |
"source": [ | |
"with model:\n", | |
" ppc = pm.sample_posterior_predictive(trace)" | |
], | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" <style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
" </style>\n", | |
" <progress value='2000' class='' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" 100.00% [2000/2000 00:01<00:00]\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "s29AtTLFlrVK" | |
}, | |
"source": [ | |
"posterior = ppc.get('y')[:, :X_test.shape[0]] # maybe it's the the value of `shape` which isn't correct?" | |
], | |
"execution_count": 31, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GeuruOfSc5n3", | |
"outputId": "71bf018c-af32-4d70-f8c5-e21bbff04ca7" | |
}, | |
"source": [ | |
"posterior.shape" | |
], | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(2000, 143)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 32 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "mcE6Fh7ADjdw", | |
"outputId": "8c25196e-52e4-4d5f-a4c0-8156f13f810a" | |
}, | |
"source": [ | |
"print(classification_report(Y_test, np.round(posterior.mean(0))))" | |
], | |
"execution_count": 33, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
" precision recall f1-score support\n", | |
"\n", | |
" 0 0.32 0.15 0.21 53\n", | |
" 1 0.62 0.81 0.70 90\n", | |
"\n", | |
" accuracy 0.57 143\n", | |
" macro avg 0.47 0.48 0.45 143\n", | |
"weighted avg 0.51 0.57 0.52 143\n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "tNI9UQK_sEbh" | |
}, | |
"source": [ | |
"## Using sklearn's `RandomForest`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "adSl86ZvHlI5" | |
}, | |
"source": [ | |
"from sklearn.ensemble import RandomForestClassifier" | |
], | |
"execution_count": 34, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "z6rNQ38isOEE" | |
}, | |
"source": [ | |
"rf = RandomForestClassifier()" | |
], | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ZOa48lOtsQov", | |
"outputId": "9ae8795a-1d14-4246-91cc-c5e92263013e" | |
}, | |
"source": [ | |
"rf.fit(X_train, Y_train)" | |
], | |
"execution_count": 36, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n", | |
" criterion='gini', max_depth=None, max_features='auto',\n", | |
" max_leaf_nodes=None, max_samples=None,\n", | |
" min_impurity_decrease=0.0, min_impurity_split=None,\n", | |
" min_samples_leaf=1, min_samples_split=2,\n", | |
" min_weight_fraction_leaf=0.0, n_estimators=100,\n", | |
" n_jobs=None, oob_score=False, random_state=None,\n", | |
" verbose=0, warm_start=False)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 36 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KFEYb--EsS6c" | |
}, | |
"source": [ | |
"y_pred = rf.predict(X_test)" | |
], | |
"execution_count": 37, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "WR8kaEcTsYuV", | |
"outputId": "a9973957-834f-4747-ad1d-08a16d0223c9" | |
}, | |
"source": [ | |
"print(classification_report(Y_test, y_pred))" | |
], | |
"execution_count": 38, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
" precision recall f1-score support\n", | |
"\n", | |
" 0 1.00 0.91 0.95 53\n", | |
" 1 0.95 1.00 0.97 90\n", | |
"\n", | |
" accuracy 0.97 143\n", | |
" macro avg 0.97 0.95 0.96 143\n", | |
"weighted avg 0.97 0.97 0.96 143\n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hwszgI7Jsdoo" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 38, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment