Skip to content

Instantly share code, notes, and snippets.

@ShokuninSan
Last active March 28, 2021 03:37
Show Gist options
  • Save ShokuninSan/f2696379614933913e68b0327e2fd71d to your computer and use it in GitHub Desktop.
Save ShokuninSan/f2696379614933913e68b0327e2fd71d to your computer and use it in GitHub Desktop.
pymc3.10-classification-using-BART.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "pymc3.10-classification-using-BART.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPKjf9aieEeS571emG2sBTJ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ShokuninSan/f2696379614933913e68b0327e2fd71d/pymc3-10-bart-experiments.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Vi0lNfG_mLqQ",
"outputId": "584563bf-2ca8-4fd0-d7e3-7ce026ccd4a3"
},
"source": [
"import sys\n",
"# Install into the interpreter that runs this kernel (a bare `pip` on PATH may belong to another Python).\n",
"!{sys.executable} -m pip install pymc3==3.10"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: pymc3==3.10 in /usr/local/lib/python3.6/dist-packages (3.10.0)\n",
"Requirement already satisfied: scipy>=0.18.1 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.4.1)\n",
"Requirement already satisfied: theano-pymc==1.0.11 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.0.11)\n",
"Requirement already satisfied: typing-extensions>=3.7.4 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (3.7.4.3)\n",
"Requirement already satisfied: numpy>=1.13.0 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.18.5)\n",
"Requirement already satisfied: arviz>=0.9.0 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (0.10.0)\n",
"Requirement already satisfied: patsy>=0.5.1 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (0.5.1)\n",
"Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (0.3.3)\n",
"Requirement already satisfied: pandas>=0.18.0 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.1.5)\n",
"Requirement already satisfied: fastprogress>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (1.0.0)\n",
"Requirement already satisfied: contextvars; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (2.4)\n",
"Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from pymc3==3.10) (0.8)\n",
"Requirement already satisfied: matplotlib>=3.0 in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (3.2.2)\n",
"Requirement already satisfied: xarray>=0.16.1 in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (0.16.2)\n",
"Requirement already satisfied: netcdf4 in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (1.5.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (20.7)\n",
"Requirement already satisfied: setuptools>=38.4 in /usr/local/lib/python3.6/dist-packages (from arviz>=0.9.0->pymc3==3.10) (50.3.2)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from patsy>=0.5.1->pymc3==3.10) (1.15.0)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.18.0->pymc3==3.10) (2.8.1)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.18.0->pymc3==3.10) (2018.9)\n",
"Requirement already satisfied: immutables>=0.9 in /usr/local/lib/python3.6/dist-packages (from contextvars; python_version < \"3.7\"->pymc3==3.10) (0.14)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.0->arviz>=0.9.0->pymc3==3.10) (0.10.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.0->arviz>=0.9.0->pymc3==3.10) (1.3.1)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=3.0->arviz>=0.9.0->pymc3==3.10) (2.4.7)\n",
"Requirement already satisfied: cftime in /usr/local/lib/python3.6/dist-packages (from netcdf4->arviz>=0.9.0->pymc3==3.10) (1.3.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aI2ce03DDc9u"
},
"source": [
"import theano\n",
"import numpy as np\n",
"import pymc3 as pm\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.datasets import load_breast_cancer\n",
"from sklearn.model_selection import train_test_split"
],
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zxqywt4lah7Q"
},
"source": [
"breast_cancer = load_breast_cancer()\n",
"X, Y = breast_cancer.data, breast_cancer.target"
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jy_YzYUboKWp"
},
"source": [
"X_train, X_test, Y_train, Y_test = train_test_split(X, Y, stratify=Y, random_state=42)  # fixed seed -> reproducible split"
],
"execution_count": 25,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Hn8t_culayDW",
"outputId": "66c68b36-4575-4e95-def6-2ea99812242f"
},
"source": [
"tuple(arr.shape for arr in (X_train, Y_train, X_test, Y_test))"
],
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((426, 30), (426,), (143, 30), (143,))"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JmnrHgOrwpjV"
},
"source": [
"## Using PyMC3's `BART`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "VJABwZPvEtu6"
},
"source": [
"X_shared = theano.shared(value=X_train)"
],
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 196
},
"id": "Ic6uQSGPaz11",
"outputId": "0dccd868-5d5b-41b0-d133-42e333bb6e1f"
},
"source": [
"with pm.Model() as model:\n",
"    # pm.BART receives a plain NumPy array of the training features (not the\n",
"    # Theano shared variable), so X_shared.set_value(...) cannot redirect it later.\n",
"    x = pm.BART('x', X_train, Y_train)\n",
"    y = pm.Bernoulli('y', p=pm.math.sigmoid(x), observed=Y_train)\n",
"\n",
"    # Fixed seed for reproducible chains; return_inferencedata=False keeps the\n",
"    # MultiTrace used below and silences the FutureWarning about the default.\n",
"    trace = pm.sample(random_seed=42, return_inferencedata=False)"
],
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/pymc3/sampling.py:468: FutureWarning: In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.\n",
" FutureWarning,\n",
"The BART model is experimental. Use with caution.\n",
"Sequential sampling (2 chains in 1 job)\n",
"PGBART: [x]\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='2000' class='' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [2000/2000 05:08<00:00 Sampling chain 0, 0 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='2000' class='' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [2000/2000 05:08<00:00 Sampling chain 1, 0 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 617 seconds.\n",
"The estimated number of effective samples is smaller than 200 for some parameters.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RaCv3m3iFuaX"
},
"source": [
"# NOTE: the BART variable was given a NumPy array of the training features, not\n",
"# X_shared, so this call does not switch the model to the test set: the\n",
"# posterior predictive below is still computed for the training rows.\n",
"X_shared.set_value(X_test)"
],
"execution_count": 29,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 37
},
"id": "-ZB5MU5ltkHv",
"outputId": "39018170-798a-4d8d-941a-0ea14c60cdd1"
},
"source": [
"with model:\n",
"    # Fixed seed so the posterior predictive draws are reproducible.\n",
"    ppc = pm.sample_posterior_predictive(trace, random_seed=42)"
],
"execution_count": 30,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='2000' class='' max='2000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [2000/2000 00:01<00:00]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "s29AtTLFlrVK"
},
"source": [
"# `y` is observed on Y_train, so `ppc['y']` holds one column per *training* row.\n",
"# The BART variable was built from a NumPy array of the training features, so\n",
"# X_shared.set_value(X_test) does not change this. Slicing the first\n",
"# X_test.shape[0] columns (as before) paired training-row predictions with\n",
"# unrelated test labels, so the previous report did not measure test performance.\n",
"posterior = ppc['y']"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GeuruOfSc5n3",
"outputId": "71bf018c-af32-4d70-f8c5-e21bbff04ca7"
},
"source": [
"assert posterior.shape[1] == Y_train.shape[0], posterior.shape\n",
"posterior.shape"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mcE6Fh7ADjdw",
"outputId": "8c25196e-52e4-4d5f-a4c0-8156f13f810a"
},
"source": [
"# In-sample fit on the training rows. These scores are NOT comparable with the\n",
"# random forest's held-out test scores below; this notebook does not obtain\n",
"# out-of-sample BART predictions (NOTE(review): newer PyMC / pymc-bart releases\n",
"# presumably support this - confirm before relying on it).\n",
"print(classification_report(Y_train, np.round(posterior.mean(0))))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "tNI9UQK_sEbh"
},
"source": [
"## Baseline: scikit-learn's `RandomForestClassifier`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "adSl86ZvHlI5"
},
"source": [
"from sklearn.ensemble import RandomForestClassifier"
],
"execution_count": 34,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "z6rNQ38isOEE"
},
"source": [
"rf = RandomForestClassifier(random_state=42)  # fixed seed -> reproducible forest"
],
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZOa48lOtsQov",
"outputId": "9ae8795a-1d14-4246-91cc-c5e92263013e"
},
"source": [
"rf.fit(X=X_train, y=Y_train)"
],
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n",
" criterion='gini', max_depth=None, max_features='auto',\n",
" max_leaf_nodes=None, max_samples=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=100,\n",
" n_jobs=None, oob_score=False, random_state=None,\n",
" verbose=0, warm_start=False)"
]
},
"metadata": {
"tags": []
},
"execution_count": 36
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KFEYb--EsS6c"
},
"source": [
"y_pred = rf.predict(X=X_test)"
],
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WR8kaEcTsYuV",
"outputId": "a9973957-834f-4747-ad1d-08a16d0223c9"
},
"source": [
"report = classification_report(y_true=Y_test, y_pred=y_pred)\n",
"print(report)"
],
"execution_count": 38,
"outputs": [
{
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.91 0.95 53\n",
" 1 0.95 1.00 0.97 90\n",
"\n",
" accuracy 0.97 143\n",
" macro avg 0.97 0.95 0.96 143\n",
"weighted avg 0.97 0.97 0.96 143\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hwszgI7Jsdoo"
},
"source": [
""
],
"execution_count": 38,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment