@jamm1985
Created April 20, 2022 13:12
Lab_16_intro_to_ML_decision_trees.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Lab_16_intro_to_ML_decision_trees.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyPV4wWPSiWyipiwnXU3RbL0",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jamm1985/3f92508c378022803fb94cc95a22d3e9/lab_16_intro_to_ml_decision_trees.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"Lab video: https://youtu.be/3N0Uz33_wFI\n",
"\n",
"TG: https://t.me/data_science_news\n",
"\n",
"\n",
"\n",
"---"
],
"metadata": {
"id": "HZsb7hnxGXX3"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AQtJeVqMHMjq"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from sklearn import tree\n",
"from sklearn.ensemble import BaggingRegressor\n",
"from sklearn.ensemble import BaggingClassifier\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.ensemble import GradientBoostingClassifier\n",
"\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.metrics import r2_score"
]
},
{
"cell_type": "markdown",
"source": [
"# Regression dataset\n",
"\n",
"[Lab 13](https://youtu.be/A3LE-ZmtVGs)\n",
"\n",
"Model Name | parameters | $r^2$ | Mean Squared Error|\n",
"----------------|------------|--------------|-------------------|\n",
"LR | $\\bf4$ |$0.89\\pm0.04$ |$3.07\\pm1.28$ |\n",
"LR poly 2 | $10$ |$0.98\\pm0.01$ |$0.44\\pm0.39$ |\n",
"LR poly 3 | $20$ |$\\bf0.99\\pm0.01$ |$\\bf0.31\\pm0.24$ |\n",
"NN | $185$ |$0.91\\pm1.61$ |$1.86\\pm1.49$ |\n"
],
"metadata": {
"id": "9UcoBihHRodL"
}
},
{
"cell_type": "code",
"source": [
"!wget https://raw.githubusercontent.com/nguyen-toan/ISLR/master/dataset/Advertising.csv\n",
"!head Advertising.csv"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hsVgffgRRbkR",
"outputId": "0c0c54fc-892e-4a9d-b2fa-4ad04998471c"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2022-04-20 05:24:38-- https://raw.githubusercontent.com/nguyen-toan/ISLR/master/dataset/Advertising.csv\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 5166 (5.0K) [text/plain]\n",
"Saving to: ‘Advertising.csv’\n",
"\n",
"\rAdvertising.csv 0%[ ] 0 --.-KB/s \rAdvertising.csv 100%[===================>] 5.04K --.-KB/s in 0s \n",
"\n",
"2022-04-20 05:24:38 (45.8 MB/s) - ‘Advertising.csv’ saved [5166/5166]\n",
"\n",
"\"\",\"TV\",\"Radio\",\"Newspaper\",\"Sales\"\n",
"\"1\",230.1,37.8,69.2,22.1\n",
"\"2\",44.5,39.3,45.1,10.4\n",
"\"3\",17.2,45.9,69.3,9.3\n",
"\"4\",151.5,41.3,58.5,18.5\n",
"\"5\",180.8,10.8,58.4,12.9\n",
"\"6\",8.7,48.9,75,7.2\n",
"\"7\",57.5,32.8,23.5,11.8\n",
"\"8\",120.2,19.6,11.6,13.2\n",
"\"9\",8.6,2.1,1,4.8\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"REGRESSION = pd.read_csv('Advertising.csv')\n",
"REGRESSION = REGRESSION.drop(columns=['Unnamed: 0'])\n",
"REGRESSION"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "9V6a71yRR4Bf",
"outputId": "33798014-aeaf-4390-dfc4-42b7dc6d1008"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" TV Radio Newspaper Sales\n",
"0 230.1 37.8 69.2 22.1\n",
"1 44.5 39.3 45.1 10.4\n",
"2 17.2 45.9 69.3 9.3\n",
"3 151.5 41.3 58.5 18.5\n",
"4 180.8 10.8 58.4 12.9\n",
".. ... ... ... ...\n",
"195 38.2 3.7 13.8 7.6\n",
"196 94.2 4.9 8.1 9.7\n",
"197 177.0 9.3 6.4 12.8\n",
"198 283.6 42.0 66.2 25.5\n",
"199 232.1 8.6 8.7 13.4\n",
"\n",
"[200 rows x 4 columns]"
],
"text/html": [
"\n",
" <div id=\"df-44e8b0a4-2e25-42f6-95bd-4fa83d4e4e51\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>TV</th>\n",
" <th>Radio</th>\n",
" <th>Newspaper</th>\n",
" <th>Sales</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>230.1</td>\n",
" <td>37.8</td>\n",
" <td>69.2</td>\n",
" <td>22.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>44.5</td>\n",
" <td>39.3</td>\n",
" <td>45.1</td>\n",
" <td>10.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>17.2</td>\n",
" <td>45.9</td>\n",
" <td>69.3</td>\n",
" <td>9.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>151.5</td>\n",
" <td>41.3</td>\n",
" <td>58.5</td>\n",
" <td>18.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>180.8</td>\n",
" <td>10.8</td>\n",
" <td>58.4</td>\n",
" <td>12.9</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>195</th>\n",
" <td>38.2</td>\n",
" <td>3.7</td>\n",
" <td>13.8</td>\n",
" <td>7.6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>94.2</td>\n",
" <td>4.9</td>\n",
" <td>8.1</td>\n",
" <td>9.7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>177.0</td>\n",
" <td>9.3</td>\n",
" <td>6.4</td>\n",
" <td>12.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>198</th>\n",
" <td>283.6</td>\n",
" <td>42.0</td>\n",
" <td>66.2</td>\n",
" <td>25.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>199</th>\n",
" <td>232.1</td>\n",
" <td>8.6</td>\n",
" <td>8.7</td>\n",
" <td>13.4</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>200 rows × 4 columns</p>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-44e8b0a4-2e25-42f6-95bd-4fa83d4e4e51')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-44e8b0a4-2e25-42f6-95bd-4fa83d4e4e51 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-44e8b0a4-2e25-42f6-95bd-4fa83d4e4e51');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"source": [
"X_reg = REGRESSION.loc[:, REGRESSION.columns != 'Sales'].to_numpy()\n",
"y_reg = REGRESSION['Sales'].to_numpy()\n",
"print(X_reg.shape)\n",
"print(y_reg.shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4ruwQzKfT4wq",
"outputId": "dca54b25-9367-42c4-c213-a4867d5eda63"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(200, 3)\n",
"(200,)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# split regression data\n",
"X_reg_train, X_reg_test, y_reg_train, y_reg_test = train_test_split(\n",
" X_reg, y_reg, test_size=0.33, random_state=0, shuffle=True)\n",
"X_reg_train.shape, X_reg_test.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NOFjExfQUEaH",
"outputId": "df658d59-ecbd-446a-ef9e-0a827df9b5fc"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((134, 3), (66, 3))"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"source": [
"# Classification dataset\n",
"\n",
"[Lab 15](https://youtu.be/lkiFy6LQnSk)\n",
"\n",
"```\n",
"Logit has 0.9961 OvR AUC with a standard deviation of 0.01\n",
"LDA has 0.9937 OvR AUC with a standard deviation of 0.01\n",
"QDA has 0.9888 OvR AUC with a standard deviation of 0.02\n",
"```"
],
"metadata": {
"id": "cpCrQE49SCQD"
}
},
{
"cell_type": "code",
"source": [
"# download dataset\n",
"# https://archive.ics.uci.edu/ml/datasets/seeds\n",
"!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt\n",
"!head seeds_dataset.txt"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "85CzXy3bSAlA",
"outputId": "63f4f619-42da-4973-fba3-19feb47988ec"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2022-04-20 05:26:13-- https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt\n",
"Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252\n",
"Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 9300 (9.1K) [application/x-httpd-php]\n",
"Saving to: ‘seeds_dataset.txt’\n",
"\n",
"seeds_dataset.txt 100%[===================>] 9.08K --.-KB/s in 0s \n",
"\n",
"2022-04-20 05:26:13 (68.0 MB/s) - ‘seeds_dataset.txt’ saved [9300/9300]\n",
"\n",
"15.26\t14.84\t0.871\t5.763\t3.312\t2.221\t5.22\t1\n",
"14.88\t14.57\t0.8811\t5.554\t3.333\t1.018\t4.956\t1\n",
"14.29\t14.09\t0.905\t5.291\t3.337\t2.699\t4.825\t1\n",
"13.84\t13.94\t0.8955\t5.324\t3.379\t2.259\t4.805\t1\n",
"16.14\t14.99\t0.9034\t5.658\t3.562\t1.355\t5.175\t1\n",
"14.38\t14.21\t0.8951\t5.386\t3.312\t2.462\t4.956\t1\n",
"14.69\t14.49\t0.8799\t5.563\t3.259\t3.586\t5.219\t1\n",
"14.11\t14.1\t0.8911\t5.42\t3.302\t2.7\t\t5\t\t1\n",
"16.63\t15.46\t0.8747\t6.053\t3.465\t2.04\t5.877\t1\n",
"16.44\t15.25\t0.888\t5.884\t3.505\t1.969\t5.533\t1\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# load the data\n",
"CLASSIFICATION = pd.read_csv(\n",
" \"https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt\",\n",
" sep='\\t',\n",
" header=None,\n",
" names=[\"area\", \"perimeter\", \"compactness\", \"length\", \"width\", \"asymmetry\", \"length_groove\", \"class\"],\n",
" on_bad_lines='skip')\n",
"CLASSIFICATION = CLASSIFICATION.dropna()\n",
"CLASSIFICATION[\"class\"] = CLASSIFICATION[\"class\"].astype('int32')\n",
"CLASSIFICATION"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "-PLzz6BESR8e",
"outputId": "c0e20253-8bcd-42e8-e7b0-7db1b2ed6d9c"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" area perimeter compactness length width asymmetry length_groove \\\n",
"0 15.26 14.84 0.8710 5.763 3.312 2.221 5.220 \n",
"1 14.88 14.57 0.8811 5.554 3.333 1.018 4.956 \n",
"2 14.29 14.09 0.9050 5.291 3.337 2.699 4.825 \n",
"3 13.84 13.94 0.8955 5.324 3.379 2.259 4.805 \n",
"4 16.14 14.99 0.9034 5.658 3.562 1.355 5.175 \n",
".. ... ... ... ... ... ... ... \n",
"205 12.19 13.20 0.8783 5.137 2.981 3.631 4.870 \n",
"206 11.23 12.88 0.8511 5.140 2.795 4.325 5.003 \n",
"207 13.20 13.66 0.8883 5.236 3.232 8.315 5.056 \n",
"208 11.84 13.21 0.8521 5.175 2.836 3.598 5.044 \n",
"209 12.30 13.34 0.8684 5.243 2.974 5.637 5.063 \n",
"\n",
" class \n",
"0 1 \n",
"1 1 \n",
"2 1 \n",
"3 1 \n",
"4 1 \n",
".. ... \n",
"205 3 \n",
"206 3 \n",
"207 3 \n",
"208 3 \n",
"209 3 \n",
"\n",
"[199 rows x 8 columns]"
],
"text/html": [
"\n",
" <div id=\"df-f7eb1d44-9a51-4936-bb6a-e7cff54a8561\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>area</th>\n",
" <th>perimeter</th>\n",
" <th>compactness</th>\n",
" <th>length</th>\n",
" <th>width</th>\n",
" <th>asymmetry</th>\n",
" <th>length_groove</th>\n",
" <th>class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>15.26</td>\n",
" <td>14.84</td>\n",
" <td>0.8710</td>\n",
" <td>5.763</td>\n",
" <td>3.312</td>\n",
" <td>2.221</td>\n",
" <td>5.220</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>14.88</td>\n",
" <td>14.57</td>\n",
" <td>0.8811</td>\n",
" <td>5.554</td>\n",
" <td>3.333</td>\n",
" <td>1.018</td>\n",
" <td>4.956</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>14.29</td>\n",
" <td>14.09</td>\n",
" <td>0.9050</td>\n",
" <td>5.291</td>\n",
" <td>3.337</td>\n",
" <td>2.699</td>\n",
" <td>4.825</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>13.84</td>\n",
" <td>13.94</td>\n",
" <td>0.8955</td>\n",
" <td>5.324</td>\n",
" <td>3.379</td>\n",
" <td>2.259</td>\n",
" <td>4.805</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>16.14</td>\n",
" <td>14.99</td>\n",
" <td>0.9034</td>\n",
" <td>5.658</td>\n",
" <td>3.562</td>\n",
" <td>1.355</td>\n",
" <td>5.175</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>205</th>\n",
" <td>12.19</td>\n",
" <td>13.20</td>\n",
" <td>0.8783</td>\n",
" <td>5.137</td>\n",
" <td>2.981</td>\n",
" <td>3.631</td>\n",
" <td>4.870</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>206</th>\n",
" <td>11.23</td>\n",
" <td>12.88</td>\n",
" <td>0.8511</td>\n",
" <td>5.140</td>\n",
" <td>2.795</td>\n",
" <td>4.325</td>\n",
" <td>5.003</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>207</th>\n",
" <td>13.20</td>\n",
" <td>13.66</td>\n",
" <td>0.8883</td>\n",
" <td>5.236</td>\n",
" <td>3.232</td>\n",
" <td>8.315</td>\n",
" <td>5.056</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>208</th>\n",
" <td>11.84</td>\n",
" <td>13.21</td>\n",
" <td>0.8521</td>\n",
" <td>5.175</td>\n",
" <td>2.836</td>\n",
" <td>3.598</td>\n",
" <td>5.044</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>209</th>\n",
" <td>12.30</td>\n",
" <td>13.34</td>\n",
" <td>0.8684</td>\n",
" <td>5.243</td>\n",
" <td>2.974</td>\n",
" <td>5.637</td>\n",
" <td>5.063</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>199 rows × 8 columns</p>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f7eb1d44-9a51-4936-bb6a-e7cff54a8561')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-f7eb1d44-9a51-4936-bb6a-e7cff54a8561 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-f7eb1d44-9a51-4936-bb6a-e7cff54a8561');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"X_class = CLASSIFICATION.loc[:, CLASSIFICATION.columns != 'class'].to_numpy()\n",
"y_class = CLASSIFICATION[\"class\"].to_numpy()\n",
"X_class.shape, y_class.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SS9X5AIhUr0Z",
"outputId": "510b9bac-c36e-4bdb-f9ba-d88d98eee250"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((199, 7), (199,))"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"source": [
"# split classification data\n",
"X_class_train, X_class_test, y_class_train, y_class_test = train_test_split(\n",
" X_class, y_class, test_size=0.33, random_state=0, shuffle=True)\n",
"X_class_train.shape, X_class_test.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YzQcEDH5VBEg",
"outputId": "7255c525-0a34-4b7a-961f-1d1b12648a6d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((133, 7), (66, 7))"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "markdown",
"source": [
"# Simple decision tree"
],
"metadata": {
"id": "EPPA6y6WSoQu"
}
},
{
"cell_type": "markdown",
"source": [
"Suppose we are given pairs of observations $(X_1,Y_1),...,(X_n,Y_n)$, where $X_i \\in R^p$ and $Y_i$ may be quantitative or qualitative.\n",
"\n",
"The idea: partition the space of independent variables (predictors) $X^1, X^2, ..., X^p$ into $J$ distinct, non-overlapping regions $R_1, R_2, ..., R_J$. Every observation that falls into a region $R_j$ receives the same single prediction of the dependent variable."
],
"metadata": {
"id": "V4rdGrAwqYOI"
}
},
{
"cell_type": "markdown",
"source": [
"Example of a CART decision tree (kyphosis data):\n",
"\n",
"https://en.wikipedia.org/wiki/Decision_tree_learning#/media/File:Cart_tree_kyphosis.png"
],
"metadata": {
"id": "mM--kn_zv2LW"
}
},
{
"cell_type": "markdown",
"source": [
"## Regression\n",
"\n",
"The goal of regression with decision trees is to find a set of regions $R_1, R_2, ..., R_J$ that minimizes the objective function:\n",
"\n",
"$RSS = \\sum_{j=1}^J\\sum_{i \\in R_j}(y_i - \\hat{y}_{R_j})^2$\n",
"\n",
"### Recursive binary splitting (a \"greedy\" algorithm)\n",
"\n",
"Choose a predictor $X^j$ and a cut point $s$ such that the space of **all** predictors is split into two regions: $\\{X|X^j<s\\}$ and $\\{X|X^j \\geq s\\}$. The predictor and the cut point are chosen so as to give the **largest** reduction in RSS at this step. To that end, all predictors $X^1,..,X^p$ and all candidate values of $s$ for each predictor are considered. \n",
"\n",
"That is, for each $j$ and $s$ we define the pair:\n",
"\n",
"$R_1(j,s)=\\{X|X^j<s\\}$ and $R_2(j,s)=\\{X|X^j \\geq s\\}$, and look for the $j$ and $s$ that minimize the objective function:\n",
"\n",
"$\\sum_{i:x_i \\in R_1(j,s)}(y_i - \\hat{y}_{R_1})^2 + \\sum_{i:x_i \\in R_2(j,s)}(y_i - \\hat{y}_{R_2})^2$, where $\\hat{y}_{R_{1,2}}$ is the sample mean of the response in $R_{1,2}(j,s)$.\n",
"\n",
"The process is then repeated, now splitting one of the regions $R_j$ obtained at the previous steps, until a stopping criterion is met. For example, the stopping criterion may be that each region contains no more than 10 observations."
],
"metadata": {
"id": "NA2USke2TmsR"
}
},
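{
"cell_type": "markdown",
"source": [
"A single step of the greedy search described above can be sketched as follows. This is only an illustration, not the scikit-learn implementation: the function name `best_split`, the toy arrays `X_toy` and `y_toy`, and the choice of candidate cut points (midpoints between neighbouring distinct values of a predictor) are assumptions made for this example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def best_split(X, y):\n",
"    # Try every predictor j and every candidate cut point s,\n",
"    # keep the pair with the smallest RSS over the two regions.\n",
"    best_j, best_s, best_rss = None, None, np.inf\n",
"    for j in range(X.shape[1]):\n",
"        values = np.unique(X[:, j])  # sorted distinct values of predictor j\n",
"        # assumed candidates: midpoints, so both regions are never empty\n",
"        for s in (values[:-1] + values[1:]) / 2:\n",
"            left = y[X[:, j] < s]    # responses in R_1(j, s)\n",
"            right = y[X[:, j] >= s]  # responses in R_2(j, s)\n",
"            rss = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()\n",
"            if rss < best_rss:\n",
"                best_j, best_s, best_rss = j, s, rss\n",
"    return best_j, best_s, best_rss\n",
"\n",
"\n",
"# toy data, made up for illustration: the response jumps once predictor 0 exceeds 3\n",
"X_toy = np.array([[1.0, 5.0], [2.0, 3.0], [3.0, 8.0], [4.0, 1.0], [5.0, 7.0], [6.0, 2.0]])\n",
"y_toy = np.array([1.0, 1.2, 0.8, 5.0, 5.2, 4.8])\n",
"print(best_split(X_toy, y_toy))\n",
"```\n",
"\n",
"On this toy data the search selects predictor 0 with the cut point 3.5, because the response jumps between the third and the fourth observation. A full tree would apply the same search recursively to each resulting region until a stopping criterion is met."
],
"metadata": {}
},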
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html#sklearn.tree.DecisionTreeRegressor\n",
"simple_dt_reg = tree.DecisionTreeRegressor(random_state=0, max_depth=3)\n",
"simple_dt_reg.fit(X_reg_train, y_reg_train)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7GNOIVLZSg7r",
"outputId": "51474eb6-7709-41b7-92d3-210e47e2c9f4"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DecisionTreeRegressor(max_depth=3, random_state=0)"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/tree.html\n",
"plt.rcParams['figure.figsize'] = [15, 15]\n",
"tree.plot_tree(simple_dt_reg, fontsize=10)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "T9PN5kqSW-l2",
"outputId": "6fe3dcaf-1ad0-4337-8dab-346d38197879"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[Text(0.5, 0.875, 'X[0] <= 130.25\\nsquared_error = 27.346\\nsamples = 134\\nvalue = 14.33'),\n",
" Text(0.25, 0.625, 'X[0] <= 33.3\\nsquared_error = 7.449\\nsamples = 55\\nvalue = 9.985'),\n",
" Text(0.125, 0.375, 'X[1] <= 31.45\\nsquared_error = 2.474\\nsamples = 16\\nvalue = 6.894'),\n",
" Text(0.0625, 0.125, 'squared_error = 1.528\\nsamples = 10\\nvalue = 6.08'),\n",
" Text(0.1875, 0.125, 'squared_error = 1.109\\nsamples = 6\\nvalue = 8.25'),\n",
" Text(0.375, 0.375, 'X[2] <= 49.0\\nsquared_error = 3.959\\nsamples = 39\\nvalue = 11.254'),\n",
" Text(0.3125, 0.125, 'squared_error = 1.996\\nsamples = 33\\nvalue = 10.67'),\n",
" Text(0.4375, 0.125, 'squared_error = 2.556\\nsamples = 6\\nvalue = 14.467'),\n",
" Text(0.75, 0.625, 'X[1] <= 26.85\\nsquared_error = 18.91\\nsamples = 79\\nvalue = 17.354'),\n",
" Text(0.625, 0.375, 'X[1] <= 9.7\\nsquared_error = 3.248\\nsamples = 36\\nvalue = 13.311'),\n",
" Text(0.5625, 0.125, 'squared_error = 0.853\\nsamples = 19\\nvalue = 11.921'),\n",
" Text(0.6875, 0.125, 'squared_error = 1.352\\nsamples = 17\\nvalue = 14.865'),\n",
" Text(0.875, 0.375, 'X[0] <= 210.75\\nsquared_error = 6.876\\nsamples = 43\\nvalue = 20.74'),\n",
" Text(0.8125, 0.125, 'squared_error = 2.269\\nsamples = 17\\nvalue = 18.724'),\n",
" Text(0.9375, 0.125, 'squared_error = 5.493\\nsamples = 26\\nvalue = 22.058')]"
]
},
"metadata": {},
"execution_count": 11
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1080x1080 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"source": [
"simple_dt_reg = tree.DecisionTreeRegressor(random_state=0)\n",
"simple_dt_reg.fit(X_reg_train, y_reg_train)\n",
"y_true = y_reg_test\n",
"y_pred = simple_dt_reg.predict(X_reg_test)\n",
"r2_score(y_true, y_pred), mean_squared_error(y_true, y_pred)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CxTyl7-6V0ET",
"outputId": "50135782-5434-45ca-85ac-0f5e6cf5b326"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(0.923650758203275, 1.9833333333333334)"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"source": [
"tree.plot_tree(simple_dt_reg)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "OoMPgDIJYB9O",
"outputId": "d0697ea4-d3ed-46ee-8906-ae6c5ab52fc1"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[Text(0.43275375939849625, 0.9583333333333334, 'X[0] <= 130.25\\nsquared_error = 27.346\\nsamples = 134\\nvalue = 14.33'),\n",
" Text(0.1756578947368421, 0.875, 'X[0] <= 33.3\\nsquared_error = 7.449\\nsamples = 55\\nvalue = 9.985'),\n",
" Text(0.04962406015037594, 0.7916666666666666, 'X[1] <= 31.45\\nsquared_error = 2.474\\nsamples = 16\\nvalue = 6.894'),\n",
" Text(0.02406015037593985, 0.7083333333333334, 'X[0] <= 18.0\\nsquared_error = 1.528\\nsamples = 10\\nvalue = 6.08'),\n",
" Text(0.012030075187969926, 0.625, 'X[0] <= 4.75\\nsquared_error = 0.938\\nsamples = 5\\nvalue = 5.08'),\n",
" Text(0.006015037593984963, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 3.2'),\n",
" Text(0.01804511278195489, 0.5416666666666666, 'X[0] <= 15.15\\nsquared_error = 0.067\\nsamples = 4\\nvalue = 5.55'),\n",
" Text(0.012030075187969926, 0.4583333333333333, 'X[2] <= 5.75\\nsquared_error = 0.036\\nsamples = 3\\nvalue = 5.433'),\n",
" Text(0.006015037593984963, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 5.7'),\n",
" Text(0.01804511278195489, 0.375, 'squared_error = -0.0\\nsamples = 2\\nvalue = 5.3'),\n",
" Text(0.02406015037593985, 0.4583333333333333, 'squared_error = -0.0\\nsamples = 1\\nvalue = 5.9'),\n",
" Text(0.03609022556390978, 0.625, 'X[2] <= 18.85\\nsquared_error = 0.118\\nsamples = 5\\nvalue = 7.08'),\n",
" Text(0.03007518796992481, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 7.6'),\n",
" Text(0.042105263157894736, 0.5416666666666666, 'X[0] <= 28.05\\nsquared_error = 0.062\\nsamples = 4\\nvalue = 6.95'),\n",
" Text(0.03609022556390978, 0.4583333333333333, 'X[2] <= 36.35\\nsquared_error = 0.029\\nsamples = 3\\nvalue = 6.833'),\n",
" Text(0.03007518796992481, 0.375, 'X[0] <= 23.45\\nsquared_error = 0.023\\nsamples = 2\\nvalue = 6.75'),\n",
" Text(0.02406015037593985, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 6.6'),\n",
" Text(0.03609022556390978, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 6.9'),\n",
" Text(0.042105263157894736, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 7.0'),\n",
" Text(0.0481203007518797, 0.4583333333333333, 'squared_error = -0.0\\nsamples = 1\\nvalue = 7.3'),\n",
" Text(0.07518796992481203, 0.7083333333333334, 'X[0] <= 14.45\\nsquared_error = 1.109\\nsamples = 6\\nvalue = 8.25'),\n",
" Text(0.06015037593984962, 0.625, 'X[1] <= 37.9\\nsquared_error = 0.123\\nsamples = 2\\nvalue = 6.95'),\n",
" Text(0.05413533834586466, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 7.3'),\n",
" Text(0.06616541353383458, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 6.6'),\n",
" Text(0.09022556390977443, 0.625, 'X[1] <= 38.3\\nsquared_error = 0.335\\nsamples = 4\\nvalue = 8.9'),\n",
" Text(0.07819548872180451, 0.5416666666666666, 'X[0] <= 22.35\\nsquared_error = 0.16\\nsamples = 2\\nvalue = 8.4'),\n",
" Text(0.07218045112781955, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 8.0'),\n",
" Text(0.08421052631578947, 0.4583333333333333, 'squared_error = -0.0\\nsamples = 1\\nvalue = 8.8'),\n",
" Text(0.10225563909774436, 0.5416666666666666, 'X[0] <= 21.4\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 9.4'),\n",
" Text(0.0962406015037594, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 9.3'),\n",
" Text(0.10827067669172932, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 9.5'),\n",
" Text(0.3016917293233083, 0.7916666666666666, 'X[2] <= 49.0\\nsquared_error = 3.959\\nsamples = 39\\nvalue = 11.254'),\n",
" Text(0.2244360902255639, 0.7083333333333334, 'X[1] <= 12.85\\nsquared_error = 1.996\\nsamples = 33\\nvalue = 10.67'),\n",
" Text(0.16466165413533834, 0.625, 'X[0] <= 98.95\\nsquared_error = 0.964\\nsamples = 16\\nvalue = 9.794'),\n",
" Text(0.1368421052631579, 0.5416666666666666, 'X[0] <= 66.5\\nsquared_error = 0.521\\nsamples = 12\\nvalue = 9.367'),\n",
" Text(0.12030075187969924, 0.4583333333333333, 'X[1] <= 4.75\\nsquared_error = 0.25\\nsamples = 2\\nvalue = 8.1'),\n",
" Text(0.11428571428571428, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 7.6'),\n",
" Text(0.12631578947368421, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 8.6'),\n",
" Text(0.15338345864661654, 0.4583333333333333, 'X[1] <= 9.6\\nsquared_error = 0.19\\nsamples = 10\\nvalue = 9.62'),\n",
" Text(0.13834586466165413, 0.375, 'X[1] <= 0.4\\nsquared_error = 0.085\\nsamples = 7\\nvalue = 9.429'),\n",
" Text(0.13233082706766916, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 8.8'),\n",
" Text(0.1443609022556391, 0.2916666666666667, 'X[0] <= 85.3\\nsquared_error = 0.022\\nsamples = 6\\nvalue = 9.533'),\n",
" Text(0.13233082706766916, 0.20833333333333334, 'X[1] <= 5.05\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 9.35'),\n",
" Text(0.12631578947368421, 0.125, 'squared_error = 0.0\\nsamples = 1\\nvalue = 9.4'),\n",
" Text(0.13834586466165413, 0.125, 'squared_error = -0.0\\nsamples = 1\\nvalue = 9.3'),\n",
" Text(0.15639097744360902, 0.20833333333333334, 'X[1] <= 3.2\\nsquared_error = 0.007\\nsamples = 4\\nvalue = 9.625'),\n",
" Text(0.15037593984962405, 0.125, 'X[1] <= 1.45\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 9.55'),\n",
" Text(0.1443609022556391, 0.041666666666666664, 'squared_error = 0.0\\nsamples = 1\\nvalue = 9.5'),\n",
" Text(0.15639097744360902, 0.041666666666666664, 'squared_error = -0.0\\nsamples = 1\\nvalue = 9.6'),\n",
" Text(0.162406015037594, 0.125, 'squared_error = 0.0\\nsamples = 2\\nvalue = 9.7'),\n",
" Text(0.16842105263157894, 0.375, 'X[1] <= 10.35\\nsquared_error = 0.149\\nsamples = 3\\nvalue = 10.067'),\n",
" Text(0.162406015037594, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.6'),\n",
" Text(0.17443609022556392, 0.2916666666666667, 'X[0] <= 71.2\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 9.8'),\n",
" Text(0.16842105263157894, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 9.7'),\n",
" Text(0.18045112781954886, 0.20833333333333334, 'squared_error = -0.0\\nsamples = 1\\nvalue = 9.9'),\n",
" Text(0.1924812030075188, 0.5416666666666666, 'X[2] <= 40.0\\nsquared_error = 0.107\\nsamples = 4\\nvalue = 11.075'),\n",
" Text(0.18646616541353384, 0.4583333333333333, 'X[2] <= 13.35\\nsquared_error = 0.02\\nsamples = 3\\nvalue = 10.9'),\n",
" Text(0.18045112781954886, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.7'),\n",
" Text(0.1924812030075188, 0.375, 'squared_error = -0.0\\nsamples = 2\\nvalue = 11.0'),\n",
" Text(0.19849624060150375, 0.4583333333333333, 'squared_error = -0.0\\nsamples = 1\\nvalue = 11.6'),\n",
" Text(0.28421052631578947, 0.625, 'X[0] <= 75.8\\nsquared_error = 1.565\\nsamples = 17\\nvalue = 11.494'),\n",
" Text(0.24060150375939848, 0.5416666666666666, 'X[0] <= 46.4\\nsquared_error = 0.409\\nsamples = 10\\nvalue = 10.79'),\n",
" Text(0.21654135338345865, 0.4583333333333333, 'X[1] <= 33.0\\nsquared_error = 0.226\\nsamples = 5\\nvalue = 10.36'),\n",
" Text(0.20451127819548873, 0.375, 'X[1] <= 26.3\\nsquared_error = 0.062\\nsamples = 2\\nvalue = 9.85'),\n",
" Text(0.19849624060150375, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 9.6'),\n",
" Text(0.21052631578947367, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.1'),\n",
" Text(0.22857142857142856, 0.375, 'X[1] <= 39.8\\nsquared_error = 0.047\\nsamples = 3\\nvalue = 10.7'),\n",
" Text(0.22255639097744362, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.4'),\n",
" Text(0.23458646616541354, 0.2916666666666667, 'X[1] <= 40.7\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 10.85'),\n",
" Text(0.22857142857142856, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.9'),\n",
" Text(0.24060150375939848, 0.20833333333333334, 'squared_error = -0.0\\nsamples = 1\\nvalue = 10.8'),\n",
" Text(0.2646616541353383, 0.4583333333333333, 'X[1] <= 18.65\\nsquared_error = 0.222\\nsamples = 5\\nvalue = 11.22'),\n",
" Text(0.25263157894736843, 0.375, 'X[1] <= 16.5\\nsquared_error = 0.04\\nsamples = 2\\nvalue = 10.7'),\n",
" Text(0.24661654135338346, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.5'),\n",
" Text(0.2586466165413534, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.9'),\n",
" Text(0.27669172932330827, 0.375, 'X[0] <= 66.4\\nsquared_error = 0.042\\nsamples = 3\\nvalue = 11.567'),\n",
" Text(0.2706766917293233, 0.2916666666666667, 'X[1] <= 39.9\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 11.7'),\n",
" Text(0.2646616541353383, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.8'),\n",
" Text(0.27669172932330827, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.6'),\n",
" Text(0.28270676691729324, 0.2916666666666667, 'squared_error = -0.0\\nsamples = 1\\nvalue = 11.3'),\n",
" Text(0.32781954887218046, 0.5416666666666666, 'X[1] <= 28.0\\nsquared_error = 1.497\\nsamples = 7\\nvalue = 12.5'),\n",
" Text(0.31278195488721805, 0.4583333333333333, 'X[1] <= 16.1\\nsquared_error = 0.058\\nsamples = 5\\nvalue = 11.74'),\n",
" Text(0.3007518796992481, 0.375, 'X[2] <= 8.15\\nsquared_error = 0.047\\nsamples = 3\\nvalue = 11.6'),\n",
" Text(0.29473684210526313, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.9'),\n",
" Text(0.3067669172932331, 0.2916666666666667, 'X[1] <= 14.4\\nsquared_error = 0.003\\nsamples = 2\\nvalue = 11.45'),\n",
" Text(0.3007518796992481, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.5'),\n",
" Text(0.31278195488721805, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.4'),\n",
" Text(0.324812030075188, 0.375, 'X[1] <= 22.45\\nsquared_error = 0.003\\nsamples = 2\\nvalue = 11.95'),\n",
" Text(0.318796992481203, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.9'),\n",
" Text(0.3308270676691729, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.0'),\n",
" Text(0.34285714285714286, 0.4583333333333333, 'X[1] <= 37.65\\nsquared_error = 0.04\\nsamples = 2\\nvalue = 14.4'),\n",
" Text(0.3368421052631579, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 14.2'),\n",
" Text(0.34887218045112783, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 14.6'),\n",
" Text(0.37894736842105264, 0.7083333333333334, 'X[0] <= 91.1\\nsquared_error = 2.556\\nsamples = 6\\nvalue = 14.467'),\n",
" Text(0.3669172932330827, 0.625, 'X[1] <= 35.4\\nsquared_error = 0.082\\nsamples = 3\\nvalue = 12.933'),\n",
" Text(0.3609022556390977, 0.5416666666666666, 'X[1] <= 30.25\\nsquared_error = 0.023\\nsamples = 2\\nvalue = 12.75'),\n",
" Text(0.3548872180451128, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.9'),\n",
" Text(0.3669172932330827, 0.4583333333333333, 'squared_error = -0.0\\nsamples = 1\\nvalue = 12.6'),\n",
" Text(0.37293233082706767, 0.5416666666666666, 'squared_error = -0.0\\nsamples = 1\\nvalue = 13.3'),\n",
" Text(0.39097744360902253, 0.625, 'X[0] <= 101.85\\nsquared_error = 0.327\\nsamples = 3\\nvalue = 16.0'),\n",
" Text(0.3849624060150376, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 15.3'),\n",
" Text(0.3969924812030075, 0.5416666666666666, 'X[0] <= 110.25\\nsquared_error = 0.122\\nsamples = 2\\nvalue = 16.35'),\n",
" Text(0.39097744360902253, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 16.7'),\n",
" Text(0.4030075187969925, 0.4583333333333333, 'squared_error = -0.0\\nsamples = 1\\nvalue = 16.0'),\n",
" Text(0.6898496240601504, 0.875, 'X[1] <= 26.85\\nsquared_error = 18.91\\nsamples = 79\\nvalue = 17.354'),\n",
" Text(0.5503759398496241, 0.7916666666666666, 'X[1] <= 9.7\\nsquared_error = 3.248\\nsamples = 36\\nvalue = 13.311'),\n",
" Text(0.48270676691729325, 0.7083333333333334, 'X[1] <= 3.25\\nsquared_error = 0.853\\nsamples = 19\\nvalue = 11.921'),\n",
" Text(0.4330827067669173, 0.625, 'X[0] <= 174.8\\nsquared_error = 0.426\\nsamples = 5\\nvalue = 10.84'),\n",
" Text(0.42105263157894735, 0.5416666666666666, 'X[1] <= 1.7\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 10.2'),\n",
" Text(0.4150375939849624, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.1'),\n",
" Text(0.4270676691729323, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.3'),\n",
" Text(0.44511278195488724, 0.5416666666666666, 'X[2] <= 22.45\\nsquared_error = 0.249\\nsamples = 3\\nvalue = 11.267'),\n",
" Text(0.43909774436090226, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 10.6'),\n",
" Text(0.45112781954887216, 0.4583333333333333, 'X[1] <= 2.7\\nsquared_error = 0.04\\nsamples = 2\\nvalue = 11.6'),\n",
" Text(0.44511278195488724, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.8'),\n",
" Text(0.45714285714285713, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.4'),\n",
" Text(0.5323308270676692, 0.625, 'X[1] <= 7.65\\nsquared_error = 0.439\\nsamples = 14\\nvalue = 12.307'),\n",
" Text(0.5052631578947369, 0.5416666666666666, 'X[0] <= 235.95\\nsquared_error = 0.21\\nsamples = 10\\nvalue = 11.98'),\n",
" Text(0.48721804511278194, 0.4583333333333333, 'X[2] <= 16.25\\nsquared_error = 0.097\\nsamples = 6\\nvalue = 11.7'),\n",
" Text(0.4691729323308271, 0.375, 'X[0] <= 151.35\\nsquared_error = 0.042\\nsamples = 3\\nvalue = 11.467'),\n",
" Text(0.4631578947368421, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.2'),\n",
" Text(0.47518796992481205, 0.2916666666666667, 'X[1] <= 5.25\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 11.6'),\n",
" Text(0.4691729323308271, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.5'),\n",
" Text(0.48120300751879697, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.7'),\n",
" Text(0.5052631578947369, 0.375, 'X[1] <= 4.75\\nsquared_error = 0.042\\nsamples = 3\\nvalue = 11.933'),\n",
" Text(0.4992481203007519, 0.2916666666666667, 'X[2] <= 67.3\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 11.8'),\n",
" Text(0.4932330827067669, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 11.7'),\n",
" Text(0.5052631578947369, 0.20833333333333334, 'squared_error = -0.0\\nsamples = 1\\nvalue = 11.9'),\n",
" Text(0.5112781954887218, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.2'),\n",
" Text(0.5233082706766917, 0.4583333333333333, 'X[2] <= 14.0\\nsquared_error = 0.085\\nsamples = 4\\nvalue = 12.4'),\n",
" Text(0.5172932330827068, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.8'),\n",
" Text(0.5293233082706766, 0.375, 'X[0] <= 251.35\\nsquared_error = 0.042\\nsamples = 3\\nvalue = 12.267'),\n",
" Text(0.5233082706766917, 0.2916666666666667, 'X[1] <= 4.6\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 12.4'),\n",
" Text(0.5172932330827068, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.3'),\n",
" Text(0.5293233082706766, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.5'),\n",
" Text(0.5353383458646617, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.0'),\n",
" Text(0.5593984962406015, 0.5416666666666666, 'X[0] <= 216.35\\nsquared_error = 0.077\\nsamples = 4\\nvalue = 13.125'),\n",
" Text(0.5473684210526316, 0.4583333333333333, 'X[0] <= 191.95\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 12.85'),\n",
" Text(0.5413533834586466, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.8'),\n",
" Text(0.5533834586466165, 0.375, 'squared_error = -0.0\\nsamples = 1\\nvalue = 12.9'),\n",
" Text(0.5714285714285714, 0.4583333333333333, 'X[2] <= 32.6\\nsquared_error = 0.0\\nsamples = 2\\nvalue = 13.4'),\n",
" Text(0.5654135338345865, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 13.4'),\n",
" Text(0.5774436090225564, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 13.4'),\n",
" Text(0.6180451127819548, 0.7083333333333334, 'X[0] <= 140.3\\nsquared_error = 1.352\\nsamples = 17\\nvalue = 14.865'),\n",
" Text(0.5834586466165413, 0.625, 'X[1] <= 14.4\\nsquared_error = 0.276\\nsamples = 3\\nvalue = 12.933'),\n",
" Text(0.5774436090225564, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 12.2'),\n",
" Text(0.5894736842105263, 0.5416666666666666, 'X[1] <= 16.85\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 13.3'),\n",
" Text(0.5834586466165413, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 13.4'),\n",
" Text(0.5954887218045113, 0.4583333333333333, 'squared_error = -0.0\\nsamples = 1\\nvalue = 13.2'),\n",
" Text(0.6526315789473685, 0.625, 'X[1] <= 22.45\\nsquared_error = 0.612\\nsamples = 14\\nvalue = 15.279'),\n",
" Text(0.631578947368421, 0.5416666666666666, 'X[0] <= 227.75\\nsquared_error = 0.296\\nsamples = 12\\nvalue = 15.083'),\n",
" Text(0.6075187969924812, 0.4583333333333333, 'X[0] <= 181.8\\nsquared_error = 0.117\\nsamples = 6\\nvalue = 14.7'),\n",
" Text(0.5954887218045113, 0.375, 'X[2] <= 24.9\\nsquared_error = 0.04\\nsamples = 2\\nvalue = 14.3'),\n",
" Text(0.5894736842105263, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 14.1'),\n",
" Text(0.6015037593984962, 0.2916666666666667, 'squared_error = -0.0\\nsamples = 1\\nvalue = 14.5'),\n",
" Text(0.6195488721804512, 0.375, 'X[2] <= 57.8\\nsquared_error = 0.035\\nsamples = 4\\nvalue = 14.9'),\n",
" Text(0.6135338345864662, 0.2916666666666667, 'X[0] <= 188.15\\nsquared_error = 0.007\\nsamples = 3\\nvalue = 14.8'),\n",
" Text(0.6075187969924812, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 14.7'),\n",
" Text(0.6195488721804512, 0.20833333333333334, 'X[0] <= 207.8\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 14.85'),\n",
" Text(0.6135338345864662, 0.125, 'squared_error = 0.0\\nsamples = 1\\nvalue = 14.9'),\n",
" Text(0.6255639097744361, 0.125, 'squared_error = -0.0\\nsamples = 1\\nvalue = 14.8'),\n",
" Text(0.6255639097744361, 0.2916666666666667, 'squared_error = -0.0\\nsamples = 1\\nvalue = 15.2'),\n",
" Text(0.6556390977443609, 0.4583333333333333, 'X[1] <= 12.25\\nsquared_error = 0.182\\nsamples = 6\\nvalue = 15.467'),\n",
" Text(0.643609022556391, 0.375, 'X[2] <= 13.9\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 14.9'),\n",
" Text(0.637593984962406, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 15.0'),\n",
" Text(0.649624060150376, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 14.8'),\n",
" Text(0.6676691729323309, 0.375, 'X[2] <= 24.55\\nsquared_error = 0.027\\nsamples = 4\\nvalue = 15.75'),\n",
" Text(0.6616541353383458, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 2\\nvalue = 15.9'),\n",
" Text(0.6736842105263158, 0.2916666666666667, 'X[1] <= 16.2\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 15.6'),\n",
" Text(0.6676691729323309, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 15.7'),\n",
" Text(0.6796992481203008, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 15.5'),\n",
" Text(0.6736842105263158, 0.5416666666666666, 'X[2] <= 25.1\\nsquared_error = 0.902\\nsamples = 2\\nvalue = 16.45'),\n",
" Text(0.6676691729323309, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 17.4'),\n",
" Text(0.6796992481203008, 0.4583333333333333, 'squared_error = 0.0\\nsamples = 1\\nvalue = 15.5'),\n",
" Text(0.8293233082706767, 0.7916666666666666, 'X[0] <= 210.75\\nsquared_error = 6.876\\nsamples = 43\\nvalue = 20.74'),\n",
" Text(0.7684210526315789, 0.7083333333333334, 'X[1] <= 43.35\\nsquared_error = 2.269\\nsamples = 17\\nvalue = 18.724'),\n",
" Text(0.7308270676691729, 0.625, 'X[2] <= 4.8\\nsquared_error = 0.699\\nsamples = 13\\nvalue = 18.062'),\n",
" Text(0.724812030075188, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 19.6'),\n",
" Text(0.7368421052631579, 0.5416666666666666, 'X[0] <= 192.4\\nsquared_error = 0.544\\nsamples = 12\\nvalue = 17.933'),\n",
" Text(0.7097744360902256, 0.4583333333333333, 'X[1] <= 36.2\\nsquared_error = 0.272\\nsamples = 8\\nvalue = 17.537'),\n",
" Text(0.6917293233082706, 0.375, 'X[2] <= 28.45\\nsquared_error = 0.027\\nsamples = 4\\nvalue = 17.15'),\n",
" Text(0.6857142857142857, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 2\\nvalue = 17.3'),\n",
" Text(0.6977443609022557, 0.2916666666666667, 'X[0] <= 170.15\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 17.0'),\n",
" Text(0.6917293233082706, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 16.9'),\n",
" Text(0.7037593984962406, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 17.1'),\n",
" Text(0.7278195488721805, 0.375, 'X[2] <= 52.2\\nsquared_error = 0.217\\nsamples = 4\\nvalue = 17.925'),\n",
" Text(0.7218045112781954, 0.2916666666666667, 'X[2] <= 37.4\\nsquared_error = 0.142\\nsamples = 3\\nvalue = 17.733'),\n",
" Text(0.7157894736842105, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 2\\nvalue = 18.0'),\n",
" Text(0.7278195488721805, 0.20833333333333334, 'squared_error = -0.0\\nsamples = 1\\nvalue = 17.2'),\n",
" Text(0.7338345864661654, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 18.5'),\n",
" Text(0.7639097744360902, 0.4583333333333333, 'X[2] <= 42.35\\nsquared_error = 0.147\\nsamples = 4\\nvalue = 18.725'),\n",
" Text(0.7518796992481203, 0.375, 'X[2] <= 24.0\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 18.35'),\n",
" Text(0.7458646616541353, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 18.4'),\n",
" Text(0.7578947368421053, 0.2916666666666667, 'squared_error = -0.0\\nsamples = 1\\nvalue = 18.3'),\n",
" Text(0.7759398496240602, 0.375, 'X[2] <= 60.8\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 19.1'),\n",
" Text(0.7699248120300752, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 19.0'),\n",
" Text(0.7819548872180451, 0.2916666666666667, 'squared_error = -0.0\\nsamples = 1\\nvalue = 19.2'),\n",
" Text(0.806015037593985, 0.625, 'X[2] <= 58.85\\nsquared_error = 1.317\\nsamples = 4\\nvalue = 20.875'),\n",
" Text(0.8, 0.5416666666666666, 'X[0] <= 190.15\\nsquared_error = 0.509\\nsamples = 3\\nvalue = 21.433'),\n",
" Text(0.793984962406015, 0.4583333333333333, 'X[0] <= 183.75\\nsquared_error = 0.062\\nsamples = 2\\nvalue = 20.95'),\n",
" Text(0.7879699248120301, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 21.2'),\n",
" Text(0.8, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 20.7'),\n",
" Text(0.806015037593985, 0.4583333333333333, 'squared_error = -0.0\\nsamples = 1\\nvalue = 22.4'),\n",
" Text(0.8120300751879699, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 19.2'),\n",
" Text(0.8902255639097745, 0.7083333333333334, 'X[1] <= 35.3\\nsquared_error = 5.493\\nsamples = 26\\nvalue = 22.058'),\n",
" Text(0.8300751879699249, 0.625, 'X[0] <= 219.1\\nsquared_error = 0.721\\nsamples = 11\\nvalue = 19.809'),\n",
" Text(0.8240601503759398, 0.5416666666666666, 'squared_error = 0.0\\nsamples = 1\\nvalue = 18.0'),\n",
" Text(0.8360902255639098, 0.5416666666666666, 'X[1] <= 27.6\\nsquared_error = 0.433\\nsamples = 10\\nvalue = 19.99'),\n",
" Text(0.8195488721804511, 0.4583333333333333, 'X[0] <= 252.1\\nsquared_error = 0.18\\nsamples = 3\\nvalue = 19.2'),\n",
" Text(0.8135338345864662, 0.375, 'squared_error = 0.0\\nsamples = 2\\nvalue = 18.9'),\n",
" Text(0.825563909774436, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 19.8'),\n",
" Text(0.8526315789473684, 0.4583333333333333, 'X[0] <= 229.35\\nsquared_error = 0.159\\nsamples = 7\\nvalue = 20.329'),\n",
" Text(0.837593984962406, 0.375, 'X[1] <= 33.35\\nsquared_error = 0.062\\nsamples = 2\\nvalue = 19.85'),\n",
" Text(0.8315789473684211, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 20.1'),\n",
" Text(0.843609022556391, 0.2916666666666667, 'squared_error = -0.0\\nsamples = 1\\nvalue = 19.6'),\n",
" Text(0.8676691729323308, 0.375, 'X[0] <= 268.2\\nsquared_error = 0.07\\nsamples = 5\\nvalue = 20.52'),\n",
" Text(0.8556390977443609, 0.2916666666666667, 'X[2] <= 10.6\\nsquared_error = 0.056\\nsamples = 3\\nvalue = 20.367'),\n",
" Text(0.849624060150376, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 20.7'),\n",
" Text(0.8616541353383459, 0.20833333333333334, 'X[2] <= 18.1\\nsquared_error = 0.0\\nsamples = 2\\nvalue = 20.2'),\n",
" Text(0.8556390977443609, 0.125, 'squared_error = 0.0\\nsamples = 1\\nvalue = 20.2'),\n",
" Text(0.8676691729323308, 0.125, 'squared_error = 0.0\\nsamples = 1\\nvalue = 20.2'),\n",
" Text(0.8796992481203008, 0.2916666666666667, 'X[2] <= 30.75\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 20.75'),\n",
" Text(0.8736842105263158, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 20.7'),\n",
" Text(0.8857142857142857, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 20.8'),\n",
" Text(0.9503759398496241, 0.625, 'X[0] <= 258.9\\nsquared_error = 2.566\\nsamples = 15\\nvalue = 23.707'),\n",
" Text(0.924812030075188, 0.5416666666666666, 'X[1] <= 46.45\\nsquared_error = 1.534\\nsamples = 10\\nvalue = 22.95'),\n",
" Text(0.9097744360902256, 0.4583333333333333, 'X[2] <= 22.85\\nsquared_error = 0.273\\nsamples = 7\\nvalue = 22.229'),\n",
" Text(0.9037593984962407, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 23.2'),\n",
" Text(0.9157894736842105, 0.375, 'X[2] <= 36.7\\nsquared_error = 0.136\\nsamples = 6\\nvalue = 22.067'),\n",
" Text(0.9037593984962407, 0.2916666666666667, 'X[2] <= 29.6\\nsquared_error = 0.116\\nsamples = 3\\nvalue = 21.833'),\n",
" Text(0.8977443609022556, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 22.3'),\n",
" Text(0.9097744360902256, 0.20833333333333334, 'X[0] <= 220.75\\nsquared_error = 0.01\\nsamples = 2\\nvalue = 21.6'),\n",
" Text(0.9037593984962407, 0.125, 'squared_error = 0.0\\nsamples = 1\\nvalue = 21.7'),\n",
" Text(0.9157894736842105, 0.125, 'squared_error = -0.0\\nsamples = 1\\nvalue = 21.5'),\n",
" Text(0.9278195488721804, 0.2916666666666667, 'X[2] <= 54.4\\nsquared_error = 0.047\\nsamples = 3\\nvalue = 22.3'),\n",
" Text(0.9218045112781955, 0.20833333333333334, 'squared_error = 0.0\\nsamples = 1\\nvalue = 22.6'),\n",
" Text(0.9338345864661655, 0.20833333333333334, 'X[1] <= 37.15\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 22.15'),\n",
" Text(0.9278195488721804, 0.125, 'squared_error = 0.0\\nsamples = 1\\nvalue = 22.2'),\n",
" Text(0.9398496240601504, 0.125, 'squared_error = 0.0\\nsamples = 1\\nvalue = 22.1'),\n",
" Text(0.9398496240601504, 0.4583333333333333, 'X[0] <= 215.55\\nsquared_error = 0.429\\nsamples = 3\\nvalue = 24.633'),\n",
" Text(0.9338345864661655, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 23.8'),\n",
" Text(0.9458646616541353, 0.375, 'X[0] <= 231.75\\nsquared_error = 0.123\\nsamples = 2\\nvalue = 25.05'),\n",
" Text(0.9398496240601504, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 24.7'),\n",
" Text(0.9518796992481203, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 25.4'),\n",
" Text(0.9759398496240601, 0.5416666666666666, 'X[1] <= 40.8\\nsquared_error = 1.194\\nsamples = 5\\nvalue = 25.22'),\n",
" Text(0.9639097744360903, 0.4583333333333333, 'X[2] <= 78.35\\nsquared_error = 0.09\\nsamples = 2\\nvalue = 24.1'),\n",
" Text(0.9578947368421052, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 24.4'),\n",
" Text(0.9699248120300752, 0.375, 'squared_error = -0.0\\nsamples = 1\\nvalue = 23.8'),\n",
" Text(0.98796992481203, 0.4583333333333333, 'X[1] <= 46.35\\nsquared_error = 0.536\\nsamples = 3\\nvalue = 25.967'),\n",
" Text(0.9819548872180451, 0.375, 'X[1] <= 42.9\\nsquared_error = 0.002\\nsamples = 2\\nvalue = 25.45'),\n",
" Text(0.9759398496240601, 0.2916666666666667, 'squared_error = 0.0\\nsamples = 1\\nvalue = 25.5'),\n",
" Text(0.98796992481203, 0.2916666666666667, 'squared_error = -0.0\\nsamples = 1\\nvalue = 25.4'),\n",
" Text(0.9939849624060151, 0.375, 'squared_error = 0.0\\nsamples = 1\\nvalue = 27.0')]"
]
},
"metadata": {},
"execution_count": 13
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1080x1080 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"source": [
"simple_dt_reg = tree.DecisionTreeRegressor(random_state=0)\n",
"\n",
"scores = cross_val_score(simple_dt_reg, X_reg, y_reg, cv=5)\n",
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n",
"\n",
"scores = cross_val_score(simple_dt_reg, X_reg, y_reg, cv=5, scoring='neg_mean_squared_error')\n",
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ry9vRQTlZKq2",
"outputId": "b11a2215-cc60-467f-f79b-c1bec260792d"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.95 R^2 with a standard deviation of 0.01\n",
"-1.23 MSE with a standard deviation of 0.26\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"plt.rcParams['figure.figsize'] = [8, 8]\n",
"simple_dt_reg = tree.DecisionTreeRegressor(random_state=0)\n",
"simple_dt_reg.fit(X_reg_train, y_reg_train)\n",
"importances = simple_dt_reg.feature_importances_\n",
"feature_names = [\"TV\", \"Radio\", \"Newspaper\"]\n",
"forest_importances = pd.Series(importances, index=feature_names)\n",
"#std = np.std([tree.feature_importances_ for tree in simple_dt_reg.estimators_], axis=0)\n",
"fig, ax = plt.subplots()\n",
"forest_importances.plot.bar( ax=ax)\n",
"ax.set_title(\"Feature importances\")\n",
"fig.tight_layout()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 585
},
"id": "-2Wb3_CEqXat",
"outputId": "c3d1fcc6-6b43-4050-8a76-19b0bfad71c5"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 576x576 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Классификация\n",
"\n",
"Аналогично задаче регрессии выполняется жадный алгоритм бинарного деления, но целевая функция соответсвует качественным данных.\n",
"\n",
"Варианты целевой функции:\n",
"1. Gini index $G = \\sum_{k=1}^K \\hat{p}_{mk}(1-\\hat{p}_{mk})$\n",
"2. Entropy $D = -\\sum_{k=1}^K \\hat{p}_{mk}\\log(\\hat{p}_{mk})$\n",
"\n",
"где $\\hat{p}_{mk}$ это отношение частоты наблюдений класса $k$ к общему числу наблюдений в области $m$.\n"
],
"metadata": {
"id": "Cd2PaHLWZ32U"
}
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html\n",
"simple_dt_class = tree.DecisionTreeClassifier(random_state=0, max_depth=3)\n",
"simple_dt_class.fit(X_class_train, y_class_train)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "X3bMnPMHZwP3",
"outputId": "262f06b3-bce6-4aea-a1cc-b682ab40bec4"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DecisionTreeClassifier(max_depth=3, random_state=0)"
]
},
"metadata": {},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"source": [
"plt.rcParams['figure.figsize'] = [15, 15]\n",
"tree.plot_tree(simple_dt_class, fontsize=10)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "S6My912WcOV1",
"outputId": "19e07c77-1612-4464-84e4-67b3f3aa7d7e"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[Text(0.625, 0.875, 'X[6] <= 5.615\\ngini = 0.664\\nsamples = 133\\nvalue = [44, 49, 40]'),\n",
" Text(0.5, 0.625, 'X[0] <= 13.115\\ngini = 0.51\\nsamples = 85\\nvalue = [44, 1, 40]'),\n",
" Text(0.25, 0.375, 'X[6] <= 4.826\\ngini = 0.227\\nsamples = 46\\nvalue = [6, 0, 40]'),\n",
" Text(0.125, 0.125, 'gini = 0.278\\nsamples = 6\\nvalue = [5, 0, 1]'),\n",
" Text(0.375, 0.125, 'gini = 0.049\\nsamples = 40\\nvalue = [1, 0, 39]'),\n",
" Text(0.75, 0.375, 'X[4] <= 3.544\\ngini = 0.05\\nsamples = 39\\nvalue = [38, 1, 0]'),\n",
" Text(0.625, 0.125, 'gini = 0.0\\nsamples = 37\\nvalue = [37, 0, 0]'),\n",
" Text(0.875, 0.125, 'gini = 0.5\\nsamples = 2\\nvalue = [1, 1, 0]'),\n",
" Text(0.75, 0.625, 'gini = 0.0\\nsamples = 48\\nvalue = [0, 48, 0]')]"
]
},
"metadata": {},
"execution_count": 17
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1080x1080 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
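{
"cell_type": "markdown",
"source": [
"As a worked check of the Gini formula, take the root node of the tree plotted above, which reports `samples = 133` and `value = [44, 49, 40]`. Using the equivalent form $G = 1 - \\sum_k \\hat{p}_{mk}^2$ (valid because the class proportions sum to one):\n",
"\n",
"$$G = 1 - \\frac{44^2 + 49^2 + 40^2}{133^2} = 1 - \\frac{5937}{17689} \\approx 0.664$$\n",
"\n",
"which agrees with the `gini = 0.664` printed for that node. The same calculation for its left child, `value = [44, 1, 40]` with 85 samples, gives $1 - 3537/7225 \\approx 0.51$."
],
"metadata": {}
},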
{
"cell_type": "code",
"source": [
"simple_dt_class = tree.DecisionTreeClassifier(random_state=0)\n",
"scores = cross_val_score(simple_dt_class, X_class, y_class, cv=5, scoring='roc_auc_ovr')\n",
"print(\"Simple DT has %0.4f OvR AUC with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w7ixK22Hcc47",
"outputId": "6f587233-226c-4b4e-8acc-b6f65d11d16c"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Simple DT has 0.9135 OvR AUC with a standard deviation of 0.03\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"plt.rcParams['figure.figsize'] = [8, 8]\n",
"simple_dt_class = tree.DecisionTreeClassifier(random_state=0)\n",
"simple_dt_class.fit(X_class_train, y_class_train)\n",
"importances = simple_dt_class.feature_importances_\n",
"feature_names = [\"area\", \"perimeter\", \"compactness\", \"length\", \"width\", \"asymmetry\", \"length_groove\"]\n",
"forest_importances = pd.Series(importances, index=feature_names)\n",
"#std = np.std([tree.feature_importances_ for tree in simple_dt_reg.estimators_], axis=0)\n",
"fig, ax = plt.subplots()\n",
"forest_importances.plot.bar( ax=ax)\n",
"ax.set_title(\"Feature importances\")\n",
"fig.tight_layout()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 585
},
"id": "UOvCQhJDt2Vb",
"outputId": "af078d15-972e-467d-b11c-cbfdc599829a"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 576x576 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Ансабли моделей"
],
"metadata": {
"id": "AgBq6oZhdZHf"
}
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/ensemble.html"
],
"metadata": {
"id": "hPvjJhqUfrhp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Bagging\n"
],
"metadata": {
"id": "JgCcijQlfftr"
}
},
{
"cell_type": "markdown",
"source": [
"Алгоритм:\n",
"\n",
"1. Симулировать $B$ различны тренирововчных данных (на основе ECDF).\n",
"2. Построить $B$ моделей $\\hat{f}^1(x),\\hat{f}^2(x),...,\\hat{f}^B(x)$ с использование $B$ наборов данных методом бинарного деления.\n",
"3. Взять среднее значение (для регрессии) или принять класс большинством предсказаний (классификация): $\\hat{f}^{bag}(x)=\\frac{1}{B}\\sum_{b=1}^B\\hat{f}^b(x)$\n",
"\n"
],
"metadata": {
"id": "PcioAa3T6ZqX"
}
},
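{
"cell_type": "markdown",
"source": [
"The three steps above can be sketched by hand. This is a minimal illustration on synthetic data and makes no claim about how `BaggingRegressor` is implemented internally; the names `X_demo`, `y_demo`, `n_bootstrap` and `models` are made up for this sketch."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"# Synthetic regression data (hypothetical, for illustration only)\n",
"X_demo = rng.uniform(0, 10, size=(200, 1))\n",
"y_demo = np.sin(X_demo[:, 0]) + rng.normal(scale=0.3, size=200)\n",
"\n",
"n_bootstrap = 25  # plays the role of B\n",
"models = []\n",
"for b in range(n_bootstrap):\n",
"    # Step 1: bootstrap sample, n draws with replacement from the training set\n",
"    idx = rng.integers(0, len(X_demo), size=len(X_demo))\n",
"    # Step 2: fit one tree on this bootstrap sample\n",
"    models.append(DecisionTreeRegressor(random_state=b).fit(X_demo[idx], y_demo[idx]))\n",
"\n",
"# Step 3: average the predictions of the individual trees\n",
"X_new = np.linspace(0, 10, 5).reshape(-1, 1)\n",
"y_bag = np.mean([m.predict(X_new) for m in models], axis=0)\n",
"print(y_bag.shape)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},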
{
"cell_type": "markdown",
"source": [
"### Регрессия"
],
"metadata": {
"id": "IxPQszz9f8D9"
}
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.BaggingRegressor.html\n",
"bagging_reg = BaggingRegressor(n_estimators=10, random_state=0)\n",
"scores = cross_val_score(bagging_reg, X_reg, y_reg, cv=5)\n",
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n",
"\n",
"scores = cross_val_score(bagging_reg, X_reg, y_reg, cv=5, scoring='neg_mean_squared_error')\n",
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Xv6jGaQbgee3",
"outputId": "f95991b4-e78b-424d-9294-97af974bb75e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.97 R^2 with a standard deviation of 0.01\n",
"-0.84 MSE with a standard deviation of 0.28\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Классификация"
],
"metadata": {
"id": "gRMrVl5Hgf93"
}
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.BaggingClassifier.html\n",
"bagging_class = BaggingClassifier(random_state=0)\n",
"scores = cross_val_score(bagging_class, X_class, y_class, cv=5, scoring='roc_auc_ovr')\n",
"print(\"bagging has %0.4f OvR AUC with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eEfj_ADGf55i",
"outputId": "c54d5a85-b718-4fe7-9a6a-61e0f51fe029"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"bagging has 0.9814 OvR AUC with a standard deviation of 0.01\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Random forest\n"
],
"metadata": {
"id": "3IxqI6VifSb-"
}
},
{
"cell_type": "markdown",
"source": [
"Algorithm:\n",
"\n",
"1. Draw $B$ bootstrap training sets by sampling the observations with replacement (that is, sampling from the ECDF).\n",
"2. Fit $B$ trees $\\hat{f}^1(x),\\hat{f}^2(x),...,\\hat{f}^B(x)$, one per bootstrap set, by recursive binary splitting. **At each split, only a random subset of $m$ of the $p$ predictors, $m \\leq p$, is considered as split candidates**; with $m = p$ the procedure reduces to bagging.\n",
"3. Average the predictions (regression) or take the majority vote (classification): $\\hat{f}^{rf}(x)=\\frac{1}{B}\\sum_{b=1}^B\\hat{f}^b(x)$"
],
"metadata": {
"id": "k9ohp-QvBwVv"
}
},
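{
"cell_type": "markdown",
"source": [
"The three steps above can be sketched by hand. This is a minimal illustration, not how `RandomForestRegressor` is implemented: the helper names `fit_forest` and `predict_forest` and the synthetic `X_demo`, `y_demo` data are assumptions made for this example, and `DecisionTreeRegressor(max_features=m)` is used to sample $m$ candidate predictors at each split.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"\n",
"def fit_forest(X, y, B=50, m=1, seed=0):\n",
"    # Fit B trees, each on its own bootstrap sample of the rows.\n",
"    rng = np.random.default_rng(seed)\n",
"    n = X.shape[0]\n",
"    trees = []\n",
"    for b in range(B):\n",
"        # Step 1: bootstrap sample, n row indices drawn with replacement\n",
"        idx = rng.integers(0, n, size=n)\n",
"        # Step 2: grow a tree that considers m random predictors at each split\n",
"        est = DecisionTreeRegressor(max_features=m, random_state=b)\n",
"        trees.append(est.fit(X[idx], y[idx]))\n",
"    return trees\n",
"\n",
"\n",
"def predict_forest(trees, X):\n",
"    # Step 3: average the predictions of the individual trees\n",
"    return np.mean([est.predict(X) for est in trees], axis=0)\n",
"\n",
"\n",
"# Synthetic data (hypothetical, only for this sketch)\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.uniform(0, 1, size=(200, 3))\n",
"y_demo = 3 * X_demo[:, 0] + np.sin(6 * X_demo[:, 1]) + rng.normal(0, 0.1, size=200)\n",
"\n",
"trees = fit_forest(X_demo, y_demo, B=50, m=1)\n",
"y_hat = predict_forest(trees, X_demo)\n",
"print(y_hat.shape)\n",
"```\n",
"\n",
"With `m` equal to the number of columns of `X_demo`, the same code performs plain bagging."
],
"metadata": {}
},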
{
"cell_type": "markdown",
"source": [
"### Regression\n"
],
"metadata": {
"id": "29WRwXwEicqR"
}
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html#sklearn.ensemble.RandomForestRegressor\n",
"random_forest_reg = RandomForestRegressor(random_state=0)\n",
"scores = cross_val_score(random_forest_reg, X_reg, y_reg, cv=5)\n",
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n",
"\n",
"# scoring='neg_mean_squared_error' returns the negated MSE, so flip the sign before reporting\n",
"scores = cross_val_score(random_forest_reg, X_reg, y_reg, cv=5, scoring='neg_mean_squared_error')\n",
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (-scores.mean(), scores.std()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6HcICOj6ijMH",
"outputId": "0086baad-c1b0-41f7-c0ad-37b23070d07e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.98 R^2 with a standard deviation of 0.01\n",
"0.64 MSE with a standard deviation of 0.23\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"plt.rcParams['figure.figsize'] = [8, 8]\n",
"random_forest_reg = RandomForestRegressor(random_state=0)\n",
"random_forest_reg.fit(X_reg_train, y_reg_train)\n",
"# Impurity-based importances, averaged over the trees of the forest\n",
"importances = random_forest_reg.feature_importances_\n",
"feature_names = [\"TV\", \"Radio\", \"Newspaper\"]\n",
"forest_importances = pd.Series(importances, index=feature_names)\n",
"# Spread of each importance across the individual trees, drawn as error bars\n",
"std = np.std([est.feature_importances_ for est in random_forest_reg.estimators_], axis=0)\n",
"fig, ax = plt.subplots()\n",
"forest_importances.plot.bar(yerr=std, ax=ax)\n",
"ax.set_title(\"Feature importances\")\n",
"ax.set_ylabel(\"Mean decrease in impurity\")\n",
"fig.tight_layout()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 585
},
"id": "VmqEKMjXunks",
"outputId": "2e3166cf-b5ae-4d41-fe52-51b9fc9c0460"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 576x576 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"### Classification"
],
"metadata": {
"id": "iS0NxlALigUc"
}
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html#sklearn.ensemble.RandomForestClassifier\n",
"random_forest_class = RandomForestClassifier(random_state=0)\n",
"scores = cross_val_score(random_forest_class, X_class, y_class, cv=5, scoring='roc_auc_ovr')\n",
"print(\"Random Forest has %0.4f OvR AUC with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2K52DSRxfY3A",
"outputId": "25093335-3e10-48a1-8b93-485d5f93a498"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Random Forest has 0.9832 OvR AUC with a standard deviation of 0.02\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"plt.rcParams['figure.figsize'] = [8, 8]\n",
"random_forest_class = RandomForestClassifier(random_state=0)\n",
"random_forest_class.fit(X_class_train, y_class_train)\n",
"# Impurity-based importances, averaged over the trees of the forest\n",
"importances = random_forest_class.feature_importances_\n",
"feature_names = [\"area\", \"perimeter\", \"compactness\", \"length\", \"width\", \"asymmetry\", \"length_groove\"]\n",
"forest_importances = pd.Series(importances, index=feature_names)\n",
"# Spread of each importance across the individual trees, drawn as error bars\n",
"std = np.std([est.feature_importances_ for est in random_forest_class.estimators_], axis=0)\n",
"fig, ax = plt.subplots()\n",
"forest_importances.plot.bar(yerr=std, ax=ax)\n",
"ax.set_title(\"Feature importances\")\n",
"ax.set_ylabel(\"Mean decrease in impurity\")\n",
"fig.tight_layout()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 585
},
"id": "zI_rRXmMvNC_",
"outputId": "b8c6a45e-1a78-4608-883c-655b4ce2ad59"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 576x576 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Boosting"
],
"metadata": {
"id": "vPB9QDX7fWM2"
}
},
{
"cell_type": "markdown",
"source": [
"Algorithm:\n",
"\n",
"1. Initialize $f_0(x)=\\mathrm{argmin}_\\gamma\\sum_{i=1}^N L(y_i,\\gamma)$\n",
"2. For $m=1,2,...,M$:\n",
"\n",
"(a) compute the pseudo-residuals: $r_{im}=-\\left[ \\frac{\\partial L(y_i, f(x_i))}{\\partial f(x_i)} \\right]_{f=f_{m-1}}$ for $i=1,2,...,N$\n",
"\n",
"(b) fit a regression tree to $\\{(x_i,r_{im}) \\}_{i=1}^N$; its leaves define the regions $R_{jm}$, $j=1,2,...,J_m$\n",
"\n",
"(c) for each region $j=1,2,...,J_m$ compute the leaf value $\\gamma_{jm}$:\n",
"\n",
"$\\gamma_{jm}=\\mathrm{argmin}_\\gamma \\sum_{x_i \\in R_{jm}}L(y_i,f_{m-1}(x_i)+\\gamma)$\n",
"\n",
"(d) update the model with learning rate $\\alpha$: $f_m(x)=f_{m-1}(x)+\\alpha\\sum_{j=1}^{J_m}\\gamma_{jm} I(x \\in R_{jm})$\n",
"\n",
"3. Return $\\hat{f}(x)=f_M(x)$"
],
"metadata": {
"id": "5VfPPsgwCx4P"
}
},
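{
"cell_type": "markdown",
"source": [
"For squared-error loss $L(y,f)=\\frac{1}{2}(y-f)^2$ the algorithm above simplifies: the pseudo-residuals are the ordinary residuals $y_i-f_{m-1}(x_i)$, and the optimal leaf value is the mean residual in that leaf, which is what a regression tree fitted to the residuals already predicts. The sketch below works under that assumption; the names `fit_boosting` and `predict_boosting` and the synthetic data are hypothetical, and this is not scikit-learn's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"\n",
"def fit_boosting(X, y, M=100, alpha=0.1, max_depth=3):\n",
"    f0 = y.mean()  # step 1: the constant that minimises squared loss\n",
"    pred = np.full(len(y), f0)\n",
"    trees = []\n",
"    for _ in range(M):\n",
"        residuals = y - pred  # (a) pseudo-residuals for squared loss\n",
"        # (b) and (c): the tree defines the regions, its leaf means are the leaf values\n",
"        est = DecisionTreeRegressor(max_depth=max_depth, random_state=0).fit(X, residuals)\n",
"        pred = pred + alpha * est.predict(X)  # (d) shrunken update\n",
"        trees.append(est)\n",
"    return f0, alpha, trees\n",
"\n",
"\n",
"def predict_boosting(model, X):\n",
"    f0, alpha, trees = model\n",
"    return f0 + alpha * sum(est.predict(X) for est in trees)\n",
"\n",
"\n",
"# Synthetic data (hypothetical, only for this sketch)\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.uniform(0, 1, size=(200, 3))\n",
"y_demo = 3 * X_demo[:, 0] + np.sin(6 * X_demo[:, 1]) + rng.normal(0, 0.1, size=200)\n",
"\n",
"model = fit_boosting(X_demo, y_demo, M=100, alpha=0.1)\n",
"mse_start = np.mean((y_demo - y_demo.mean()) ** 2)\n",
"mse_end = np.mean((y_demo - predict_boosting(model, X_demo)) ** 2)\n",
"print('training MSE: %.3f -> %.3f' % (mse_start, mse_end))\n",
"```\n",
"\n",
"A smaller `alpha` makes each tree contribute less, so more rounds `M` are usually needed to reach the same training error."
],
"metadata": {}
},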
{
"cell_type": "markdown",
"source": [
"### Regression"
],
"metadata": {
"id": "ceoFCPtAmL6u"
}
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingRegressor.html#sklearn.ensemble.GradientBoostingRegressor\n",
"boosting_reg = GradientBoostingRegressor(random_state=0, criterion='squared_error', learning_rate=0.1)\n",
"scores = cross_val_score(boosting_reg, X_reg, y_reg, cv=5)\n",
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n",
"\n",
"# scoring='neg_mean_squared_error' returns the negated MSE, so flip the sign before reporting\n",
"scores = cross_val_score(boosting_reg, X_reg, y_reg, cv=5, scoring='neg_mean_squared_error')\n",
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (-scores.mean(), scores.std()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "O_b-JNbNmShW",
"outputId": "d9586316-436c-45d6-8749-0db6ea745136"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.98 R^2 with a standard deviation of 0.01\n",
"0.59 MSE with a standard deviation of 0.24\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Classification"
],
"metadata": {
"id": "YgMZnPFOmOVZ"
}
},
{
"cell_type": "code",
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html#sklearn.ensemble.GradientBoostingClassifier\n",
"boosting_class = GradientBoostingClassifier(random_state=0, n_estimators=1000, max_depth=3, learning_rate=0.9)\n",
"scores = cross_val_score(boosting_class, X_class, y_class, cv=5, scoring='roc_auc_ovr')\n",
"print(\"Boosting has %0.4f OvR AUC with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Jxi9Vg4SdGq8",
"outputId": "d968ee29-33e5-4c2c-ba51-4b2bbfdcaac9"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Boosting has 0.9842 OvR AUC with a standard deviation of 0.03\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Model selection\n"
],
"metadata": {
"id": "_br7bMzZprvW"
}
},
{
"cell_type": "markdown",
"source": [
"## Regression\n",
"\n"
],
"metadata": {
"id": "-CuS7pElpvUV"
}
},
{
"cell_type": "markdown",
"source": [
"Model Name      | $R^2$            | Mean Squared Error|\n",
"----------------|------------------|-------------------|\n",
"LR              |$0.89\\pm0.04$    |$3.07\\pm1.28$     |\n",
"LR poly 2       |$0.98\\pm0.01$    |$0.44\\pm0.39$     |\n",
"LR poly 3       |$\\bf0.99\\pm0.01$ |$\\bf0.31\\pm0.24$  |\n",
"NN              |$0.91\\pm1.61$    |$1.86\\pm1.49$     |\n",
"Simple DT       |$0.95\\pm0.1$     |$1.23\\pm0.26$     |\n",
"Bagging         |$0.97\\pm0.01$    |$0.84\\pm0.28$     |\n",
"Random Forest   |$0.98\\pm0.01$    |$0.64\\pm0.23$     |\n",
"Boosting        |$0.98\\pm0.01$    |$0.59\\pm0.24$     |"
],
"metadata": {
"id": "I5x60DCFp52H"
}
},
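{
"cell_type": "markdown",
"source": [
"A comparison table like the one above can also be assembled programmatically. The sketch below is only an illustration: it uses synthetic data from `make_regression` (an assumption, not the lab's dataset) and the hypothetical names `models` and `summary`. It evaluates both metrics in a single `cross_validate` call and flips the sign of the negated MSE.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"from sklearn.datasets import make_regression\n",
"from sklearn.ensemble import BaggingRegressor, GradientBoostingRegressor, RandomForestRegressor\n",
"from sklearn.model_selection import cross_validate\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"\n",
"# Synthetic stand-in data (hypothetical, only for this sketch)\n",
"X_demo, y_demo = make_regression(n_samples=200, n_features=3, noise=1.0, random_state=0)\n",
"\n",
"models = {\n",
"    'Simple DT': DecisionTreeRegressor(random_state=0),\n",
"    'Bagging': BaggingRegressor(n_estimators=10, random_state=0),\n",
"    'Random Forest': RandomForestRegressor(random_state=0),\n",
"    'Boosting': GradientBoostingRegressor(random_state=0),\n",
"}\n",
"\n",
"rows = []\n",
"for name, model in models.items():\n",
"    res = cross_validate(model, X_demo, y_demo, cv=5, scoring=['r2', 'neg_mean_squared_error'])\n",
"    r2 = res['test_r2']\n",
"    mse = -res['test_neg_mean_squared_error']  # undo the sign flip applied by the scorer\n",
"    rows.append({'Model': name, 'R2 mean': r2.mean(), 'R2 std': r2.std(),\n",
"                 'MSE mean': mse.mean(), 'MSE std': mse.std()})\n",
"\n",
"summary = pd.DataFrame(rows).set_index('Model')\n",
"print(summary.round(2))\n",
"```\n",
"\n",
"Reporting the fold-to-fold standard deviation next to each mean shows whether the gap between two models is larger than the variation across folds."
],
"metadata": {}
},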
{
"cell_type": "markdown",
"source": [
"## Classification\n",
"\n",
"```\n",
"Logit has 0.9961 ± 0.01 OvR AUC\n",
"LDA has 0.9937 ± 0.01 OvR AUC\n",
"QDA has 0.9888 ± 0.02 OvR AUC\n",
"Simple DT has 0.9135 ± 0.03 OvR AUC\n",
"bagging has 0.9814 ± 0.01 OvR AUC\n",
"Random Forest has 0.9832 ± 0.02 OvR AUC\n",
"Boosting has 0.9842 ± 0.03 OvR AUC\n",
"```"
],
"metadata": {
"id": "GYj6zH9Bxhrh"
}
}
]
}