Skip to content

Instantly share code, notes, and snippets.

@avidale
Last active May 25, 2023 21:23
Show Gist options
  • Save avidale/2aa018e6264a3e3887132d522b783561 to your computer and use it in GitHub Desktop.
Save avidale/2aa018e6264a3e3887132d522b783561 to your computer and use it in GitHub Desktop.
micro_segments.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "micro_segments.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPGpgoCSwJAdfRpy9ck68my",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"d5704653ebc8401181381c3a1d770285": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_509388a136054f9e9e4f880f6e1c08c5",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_6727b6dbc88a48e986b220cbd4c09c47",
"IPY_MODEL_3d486747308a4264b75131b5c8106554"
]
}
},
"509388a136054f9e9e4f880f6e1c08c5": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"6727b6dbc88a48e986b220cbd4c09c47": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_c4f04b9d42994e25a4f33bdc418e8db0",
"_dom_classes": [],
"description": "100%",
"_model_name": "FloatProgressModel",
"bar_style": "success",
"max": 131,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 131,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_ea70bcd7a2504ce295ce001201b2e67c"
}
},
"3d486747308a4264b75131b5c8106554": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_03db337abd384a07a426b5171947a9df",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "​",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 131/131 [11:16<00:00, 5.16s/it]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_62f89de11c794b859bbb58cf1764b1ac"
}
},
"c4f04b9d42994e25a4f33bdc418e8db0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"ea70bcd7a2504ce295ce001201b2e67c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"03db337abd384a07a426b5171947a9df": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"62f89de11c794b859bbb58cf1764b1ac": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avidale/2aa018e6264a3e3887132d522b783561/micro_segments.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I-KZFqkfUk-H"
},
"source": [
"В этом блокноте решается следующая задача: разделить пользователей на интерпретируемые микро-сегменты в зависимости от того, насколько в них выражена определённая характеристика. Делать это будем с помощью деревьев решений. \n",
"\n",
"Данные для примера возьмём из репозитория UCI: https://archive.ics.uci.edu/ml/datasets/Bank+Marketing.\n",
"\n",
"Они описывают маркетинговую кампанию банка. Целевая переменная - удалось ли впарить клиенту депозит. \n",
"\n",
"Получается, что мы должны выделить интерпретируемые сегментов, такие, которые или очень хорошо, или очень плохо откликаются на кампанию. Примеры этих сегментов могут выглядеть, например, так:\n",
"\n",
"* \"если исход прошлой маркетинговой кампании для данного клиента был успешным, и разговор длится больше 8 минут, то с вероятностью 80% клиент откроет депозит\"\n",
"* \"если у клиента есть ипотека, разговор шёл меньше 4 минут и был не в мае, то клиент откроет депозит только с вероятностью 0.3%\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "S_6F9eJdkwXE"
},
"source": [
"# Обучение классификаторов"
]
},
{
"cell_type": "code",
"metadata": {
"id": "F2AvqtgwUOsp",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a9f4a744-04ed-4150-f754-e541e878630f"
},
"source": [
"!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00222/bank.zip\n",
"!unzip bank.zip"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"--2021-04-10 22:26:02-- https://archive.ics.uci.edu/ml/machine-learning-databases/00222/bank.zip\n",
"Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252\n",
"Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 579043 (565K) [application/x-httpd-php]\n",
"Saving to: ‘bank.zip’\n",
"\n",
"bank.zip 100%[===================>] 565.47K 1.72MB/s in 0.3s \n",
"\n",
"2021-04-10 22:26:03 (1.72 MB/s) - ‘bank.zip’ saved [579043/579043]\n",
"\n",
"Archive: bank.zip\n",
" inflating: bank-full.csv \n",
" inflating: bank-names.txt \n",
" inflating: bank.csv \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "T74eJHdpUZ5z"
},
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import sklearn\n",
"from sklearn.model_selection import train_test_split, cross_val_score\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import roc_auc_score\n",
"import copy\n",
"from tqdm.auto import tqdm, trange"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "273VyZtjUdzm"
},
"source": [
"df = pd.read_csv('bank-full.csv', sep=';')"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
},
"id": "xW57fr-DUybF",
"outputId": "054cff36-c613-46a4-f27a-1d7ee1ddcf6b"
},
"source": [
"df.sample(3)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>job</th>\n",
" <th>marital</th>\n",
" <th>education</th>\n",
" <th>default</th>\n",
" <th>balance</th>\n",
" <th>housing</th>\n",
" <th>loan</th>\n",
" <th>contact</th>\n",
" <th>day</th>\n",
" <th>month</th>\n",
" <th>duration</th>\n",
" <th>campaign</th>\n",
" <th>pdays</th>\n",
" <th>previous</th>\n",
" <th>poutcome</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>33063</th>\n",
" <td>32</td>\n",
" <td>services</td>\n",
" <td>single</td>\n",
" <td>secondary</td>\n",
" <td>no</td>\n",
" <td>49</td>\n",
" <td>yes</td>\n",
" <td>no</td>\n",
" <td>cellular</td>\n",
" <td>17</td>\n",
" <td>apr</td>\n",
" <td>184</td>\n",
" <td>4</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>unknown</td>\n",
" <td>no</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21261</th>\n",
" <td>51</td>\n",
" <td>blue-collar</td>\n",
" <td>married</td>\n",
" <td>secondary</td>\n",
" <td>no</td>\n",
" <td>193</td>\n",
" <td>no</td>\n",
" <td>no</td>\n",
" <td>cellular</td>\n",
" <td>18</td>\n",
" <td>aug</td>\n",
" <td>146</td>\n",
" <td>5</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>unknown</td>\n",
" <td>no</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37061</th>\n",
" <td>42</td>\n",
" <td>blue-collar</td>\n",
" <td>divorced</td>\n",
" <td>secondary</td>\n",
" <td>no</td>\n",
" <td>-580</td>\n",
" <td>yes</td>\n",
" <td>yes</td>\n",
" <td>cellular</td>\n",
" <td>13</td>\n",
" <td>may</td>\n",
" <td>172</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>0</td>\n",
" <td>unknown</td>\n",
" <td>no</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age job marital education ... pdays previous poutcome y\n",
"33063 32 services single secondary ... -1 0 unknown no\n",
"21261 51 blue-collar married secondary ... -1 0 unknown no\n",
"37061 42 blue-collar divorced secondary ... -1 0 unknown no\n",
"\n",
"[3 rows x 17 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l-2dUQT0aEvG"
},
"source": [
"Из 45 тысяч клиентов 5 в итоге окрыли депозит. Крутая конверсия!"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "T62tISn7U5iY",
"outputId": "c8c60cd0-ab46-4683-fb30-43703d8984d5"
},
"source": [
"df.y.value_counts()"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"no 39922\n",
"yes 5289\n",
"Name: y, dtype: int64"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "XPVmX_OeV-Fx"
},
"source": [
"X = df.drop('y', axis=1)\n",
"y = (df.y == 'yes').astype(int)\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KfP5ydUcU7n4"
},
"source": [
"class DummyTransformer(BaseEstimator, TransformerMixin):\n",
" def fit(self, X, y=None):\n",
" assert isinstance(X, pd.DataFrame)\n",
" self.columns_ = pd.get_dummies(X).columns\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" X = pd.get_dummies(X)\n",
" return X.reindex(self.columns_, axis=1).fillna(0)"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AIcvjfrmV7L3"
},
"source": [
"dt = DummyTransformer()\n",
"X_train2 = dt.fit_transform(X_train)\n",
"X_test2 = dt.transform(X_test)"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "YABrczaxZHSK"
},
"source": [
"Обучаем большой Random Forest. Его точность классификации - это примерно бейзлайн, на который мы можем равняться, оценивая качество решающего дерева. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CZYbGGLzW5Cu",
"outputId": "07dfea09-0efc-420c-c966-cbab905b0079"
},
"source": [
"clf = RandomForestClassifier(n_estimators=300, min_samples_leaf=5, random_state=1)\n",
"clf.fit(X_train2, y_train)\n",
"print(roc_auc_score(y_test, clf.predict_proba(X_test2)[:, 1]))"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"0.9324686481623802\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dfq7pS_dZQES"
},
"source": [
"Сложность деревьев можно контролировать параметром min_samples_leaf, контролирующим размер минимально возможный размер листа. \n",
"\n",
"Чем этот параметр больше, тем раньше будет дерево прекращать отращивать новые ветки. Это сокращает возможность дерева подогнаться под обучающую выборку, но зато и сокращает возможность переобучения. \n",
"\n",
"Конкретно на нашем датасете оптимальная сложность дерева наступает при примерно 150 наблюдениях в листе. Качество всё ещё ниже, чем у random forest, но не сильно - 91% против 93%. Ради интерпретируемости на такую жертву пойти можно. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ErkhVLs2X0Kq",
"outputId": "501fbb10-5ad2-4d21-fe50-2be401c35657"
},
"source": [
"for msl in np.logspace(0, 3, num=16):\n",
" msl = int(msl)\n",
" clf = DecisionTreeClassifier(min_samples_leaf=msl)\n",
" clf.fit(X_train2, y_train)\n",
" print('{}\\t {:2.4f}\\t {:2.4f}'.format(\n",
" msl, \n",
" roc_auc_score(y_test, clf.predict_proba(X_test2)[:, 1]),\n",
" roc_auc_score(y_train, clf.predict_proba(X_train2)[:, 1])\n",
" ))"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"1\t 0.7278\t 1.0000\n",
"1\t 0.7197\t 1.0000\n",
"2\t 0.7545\t 0.9976\n",
"3\t 0.7790\t 0.9932\n",
"6\t 0.8279\t 0.9802\n",
"10\t 0.8569\t 0.9689\n",
"15\t 0.8749\t 0.9611\n",
"25\t 0.8955\t 0.9507\n",
"39\t 0.8992\t 0.9432\n",
"63\t 0.9056\t 0.9355\n",
"100\t 0.9069\t 0.9284\n",
"158\t 0.9096\t 0.9222\n",
"251\t 0.9062\t 0.9136\n",
"398\t 0.8954\t 0.9006\n",
"630\t 0.8883\t 0.8930\n",
"1000\t 0.8665\t 0.8737\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IvZKlBFKfINf"
},
"source": [
"Нарисуем оптимальное дерево. Оно получается довольно сложным, несмотря на ограничения. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "IcTDnlo6fG-n"
},
"source": [
"clf = DecisionTreeClassifier(min_samples_leaf=150)\n",
"clf.fit(X_train2, y_train);"
],
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 357
},
"id": "t1oFa0NTfHoz",
"outputId": "1d0c676d-7bd7-4c36-f309-5bf9d136736f"
},
"source": [
"plt.figure(figsize=(12, 6))\n",
"sklearn.tree.plot_tree(clf, feature_names=dt.columns_, filled=True, );"
],
"execution_count": 12,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x432 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I-dsFdPzk1C-"
},
"source": [
"# Извлечение листьев"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gnYmGgcgYlAi"
},
"source": [
"Теперь извлечём из нашего дерева листья.\n",
"\n",
"Структура дерева хранится в `clf.tree_` в виде списков cвойств узлов:\n",
"* `feature` - номер признака, по которому шло ветвление\n",
"* `threshold` - порог ветвления\n",
"* `children_left` - левые дочерние узлы\n",
"* `children_right` - правые дочерние узлы\n",
"* `value` - частоты классов в этом узле\n",
"\n",
"Листья - это те узлы, у которых нет дочерних. \n",
"\n",
"Оказывается, у нас 131 лист, и значит потенциально 131 микросегмент. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OCXD9TMvcF6M",
"outputId": "a8947e09-b0f8-4d5d-a797-75531df5411c"
},
"source": [
"\n",
"print('nodes:', len(clf.tree_.children_left))\n",
"leaf_indices = np.where(clf.tree_.children_left == -1)[0]\n",
"leaf_indices\n",
"print('leaves:', len(leaf_indices))"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"nodes: 261\n",
"leaves: 131\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OA45v39Yd7t0"
},
"source": [
"Очень много листьев с около-нулевой конверсией, но есть и много листьев с очень высокой конверсией. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "KH6leNvrcJan",
"outputId": "221c9ff7-e2f5-40c5-9799-48c20f48daab"
},
"source": [
"node_conversion = clf.tree_.value[:, 0, 1] / clf.tree_.value[:, 0, :].sum(axis=-1)\n",
"leaf_conversion = node_conversion[leaf_indices]\n",
"root_conversion = node_conversion[0]\n",
"plt.hist(leaf_conversion, bins=20)\n",
"plt.vlines([root_conversion], *plt.ylim())\n",
"plt.title('Распределение листьев по конверсии в них')\n",
"plt.legend(['средняя конверсия', 'листья']);"
],
"execution_count": 14,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1UQNCzigegoG"
},
"source": [
"Восстановим правила, соответствующие каждому ветвлению, и заодно запишем номер родительского для каждого узла. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Dr5Xz-4VeYRF",
"outputId": "304ed095-57b8-432d-ef6f-f5200ac9a28c"
},
"source": [
"t = clf.tree_\n",
"n = len(t.threshold)\n",
"named_feaures = [dt.columns_[f_id] for f_id in t.feature]\n",
"parents = [-1] * n\n",
"\n",
"rules = [[]] * n\n",
"for i in range(0, n):\n",
" if t.children_left[i] < 0:\n",
" continue\n",
" rules[t.children_left[i]] = [named_feaures[i], '<=', t.threshold[i]]\n",
" rules[t.children_right[i]] = [named_feaures[i], '>', t.threshold[i]]\n",
" parents[t.children_left[i]] = i\n",
" parents[t.children_right[i]] = i\n",
"\n",
"print(rules[1:5])"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"[['duration', '<=', 473.5], ['poutcome_success', '<=', 0.5], ['age', '<=', 60.5], ['month_mar', '<=', 0.5]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E0GOO6BVdhfK"
},
"source": [
"Напишем функцию, составляющую путь до узла с заданным номером"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "g4wVp58rjOl8",
"outputId": "10d03e4c-374f-4035-b63d-b924f2f1cf31"
},
"source": [
"def get_path(i):\n",
" parent = parents[i]\n",
" if parent > 0:\n",
" result = get_path(parent)\n",
" else:\n",
" result = []\n",
" result.append(i)\n",
" return result\n",
"\n",
"print(get_path(127))"
],
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": [
"[1, 2, 3, 4, 5, 115, 116, 117, 118, 120, 121, 122, 123, 127]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "spTnO5utjzoj"
},
"source": [
"Теперь мы можем описать, в каких листьях конверсия самая большая..."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dZeJiNDq4-dm",
"outputId": "4f4bc592-abd0-4de2-8c02-f876dcc58b19"
},
"source": [
"leaf_indices"
],
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([ 12, 13, 15, 16, 19, 21, 22, 28, 29, 30, 34, 35, 38,\n",
" 39, 40, 42, 45, 47, 49, 51, 52, 54, 55, 56, 61, 63,\n",
" 65, 66, 68, 69, 71, 72, 74, 76, 78, 79, 81, 86, 87,\n",
" 89, 90, 92, 93, 95, 96, 100, 101, 103, 104, 105, 108, 111,\n",
" 112, 113, 114, 119, 125, 126, 127, 130, 131, 132, 136, 137, 138,\n",
" 139, 141, 142, 145, 146, 147, 151, 152, 154, 155, 157, 158, 163,\n",
" 164, 166, 172, 173, 175, 176, 178, 179, 182, 184, 185, 187, 188,\n",
" 190, 192, 193, 195, 196, 197, 199, 200, 201, 203, 206, 207, 208,\n",
" 210, 212, 214, 215, 222, 224, 225, 228, 229, 230, 234, 235, 237,\n",
" 238, 239, 242, 244, 245, 246, 247, 250, 251, 254, 255, 257, 259,\n",
" 260])"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3XpbffyPimzN",
"outputId": "9d36bfb4-2b21-4af3-f486-3d5cac670d54"
},
"source": [
"ranks = np.argsort(leaf_conversion)\n",
"for id in ranks[:-4:-1]:\n",
" print(leaf_conversion[id])\n",
" print([rules[i] for i in get_path(leaf_indices[id])])"
],
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"text": [
"0.8175438596491228\n",
"[['duration', '<=', 473.5], ['poutcome_success', '>', 0.5], ['duration', '>', 132.5], ['housing_no', '>', 0.5], ['duration', '>', 254.0]]\n",
"0.8012422360248447\n",
"[['duration', '>', 473.5], ['duration', '<=', 827.5], ['poutcome_success', '>', 0.5]]\n",
"0.7316017316017316\n",
"[['duration', '>', 473.5], ['duration', '>', 827.5], ['contact_cellular', '>', 0.5], ['day', '<=', 15.5], ['day', '>', 8.5]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "otEK1Ha9jp7Z"
},
"source": [
"И самая маленькая"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6xwIIS4bkOQM",
"outputId": "41e0d59c-4048-4551-eb54-e78c6682cebb"
},
"source": [
"ranks = np.argsort(leaf_conversion)\n",
"for id in ranks[:3]:\n",
" print(leaf_conversion[id])\n",
" print([rules[i] for i in get_path(leaf_indices[id])])"
],
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"text": [
"0.0\n",
"[['duration', '<=', 473.5], ['poutcome_success', '<=', 0.5], ['age', '<=', 60.5], ['month_mar', '<=', 0.5], ['month_oct', '<=', 0.5], ['duration', '<=', 204.5], ['month_apr', '<=', 0.5], ['month_feb', '<=', 0.5], ['pdays', '<=', 9.0], ['age', '>', 28.5], ['day', '>', 3.5], ['duration', '<=', 149.5], ['month_jan', '<=', 0.5], ['balance', '>', 150.5], ['age', '>', 34.5], ['balance', '>', 196.5], ['job_admin.', '>', 0.5], ['day', '>', 17.5]]\n",
"0.0\n",
"[['duration', '<=', 473.5], ['poutcome_success', '<=', 0.5], ['age', '<=', 60.5], ['month_mar', '<=', 0.5], ['month_oct', '<=', 0.5], ['duration', '>', 204.5], ['housing_no', '>', 0.5], ['month_apr', '<=', 0.5], ['pdays', '<=', 20.0], ['age', '>', 29.5], ['day', '>', 3.5], ['contact_unknown', '<=', 0.5], ['duration', '<=', 313.5], ['balance', '<=', 1027.0], ['day', '>', 19.5], ['balance', '<=', 195.5]]\n",
"0.0\n",
"[['duration', '<=', 473.5], ['poutcome_success', '<=', 0.5], ['age', '<=', 60.5], ['month_mar', '<=', 0.5], ['month_oct', '<=', 0.5], ['duration', '<=', 204.5], ['month_apr', '<=', 0.5], ['month_feb', '<=', 0.5], ['pdays', '>', 9.0], ['pdays', '>', 100.5], ['housing_no', '<=', 0.5], ['duration', '<=', 155.5], ['day', '>', 12.5], ['duration', '>', 38.5]]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VFTavBfU4yAu"
},
"source": [
"Проблема в том, что из-за высокой глубины дерева большинство листьев содержит довольно много правил (половина листьев - больше 12 правил). \n",
"\n",
"Более того, часть этих правил - избыточные, например, если разговор был дольше 827 секунд, то он автоматически дольше 473 секунд. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 432
},
"id": "uJhyU7KK45LY",
"outputId": "b115f1de-01b8-42c0-edfc-5425d785f2e4"
},
"source": [
"plt.hist([len(get_path(i)) for i in leaf_indices])\n",
"plt.title('распределение глубин листов');\n",
"pd.Series([len(get_path(i)) for i in leaf_indices]).describe()"
],
"execution_count": 21,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"count 131.000000\n",
"mean 11.778626\n",
"std 4.354114\n",
"min 3.000000\n",
"25% 8.000000\n",
"50% 12.000000\n",
"75% 15.000000\n",
"max 21.000000\n",
"dtype: float64"
]
},
"metadata": {
"tags": []
},
"execution_count": 21
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAEICAYAAABGaK+TAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAVBElEQVR4nO3de9QkdX3n8fcHGJKIiDMywQEZJlHUkGwCnBE1GpcEgwhRMEdZWRfBy07YlRM4B9dFzRrOxuSAGzWbuGuCghBFlCyoKBhB1Jh4wQXkKhIEh3AZGJCr6xX47h9Vj+lpup/79PP89P06p09XV/2q6tv11PPp6l9Vd6eqkCS1Z5ulLkCSND8GuCQ1ygCXpEYZ4JLUKANckhplgEvLQJIVS12D2mOAS0sgybokH0xyc5L7gP+51DWpPQa4NGFJVgJfAq4Bfq2qVlbVf17istSg+EEeabKSnATsUlX/aalrUds8Al8mkmxM8uYk30hyX5IPJPn5ftrKJJ9Kcnc/7VNJnjIw76q+/R399I/34/dP8miS7w7cHklydD/96CRfSvKeJA8k+WaSAwaWu1OS05JsSnJ7krcn2XZg+tOS1NCyXz8w/TlJvpzk/iRXJdl/6DmfkeRH/bzfT3LbwLRnJrk4yb1Jbkhy+NB8bx+uY+DxF6bqSLJNkmuGlr1rknP77fntJH84zd9lsMap2079tHWjnn+SJyf5XpInDSxn3359K4D9gCf13Sf3Jjk/ya4DbSvJ0/rhtf22+dDQOrcbaP+h/kVh3HMY3g8eTfLCMdvywsHlT7Nv3d8v6wf9855a9qv66S9Ncl3f7gtJfmVgHRv75/Tdfr86dlztmp4Bvry8CngR8FTg6cAf9eO3AT4A7AGsBb4PvGdgvg8CjwN+FfhF4N0D0+6oqsdP3YCvDK3z2cBNwM7AHwPnJVnVTzsDeBh4GrAPcCDw+oF5AzCw7H/8yYRkN+AC4O3AKuCNwLlJVg/Mvw1wSj/viwfm3QG4GPhw/3xeCfzvJHuN2mgzOApYObDsbYBPAlcBuwEHAMcnedE0y3jH4DasqgcG6gfYafD5V9WdwBeAwweWcSTwkar6Md3f6neAlwNrgFuAj4xZ958A35nlcx1nG+D2gb/Tv4xqlOS3gV8fGj1y36qqJ/bLOgb4ysC2OSvJ04GzgeOB1cCFwCeTbD+w3Jf08/974C+TPGGBz/FnkgG+vLynqm6tqnuBPwWOAKiq71TVuVX1vap6qJ/2bwGSrKELv2Oq6r6q+nFV/cMc1rkZ+It+vo8CNwCHJNkFOBg4vqr+X1VtpvvnfeXAvL8A/GjMcv8DcGFVXVhVj1bVxcBl/TKnbD9m/t8DNlbVB6rq4ar6OnAu8Io5PC/SvYN5G10ITnkWsLqq/ntV/aiqbgbeN/S8Zmt74NGqemTEtDPptgH9u5Yj6MJwyulVdUVV/RB4M/DcJOuG6v914Ln9shZi3HYeXFeAd9Btr6lx8923/h1wQVVd3L9g/TndvvKbI9puBzw4U30abbuZm2iCbh0YvgXYFSDJ4+jC8yD+9Whyxz4Ydgfurar75rnO22vLEyFT690DWAFs6v63ge4Ff7DGJwN3j1nuHsArkrxkYNwK4PMDj1cBo+reA3h2kvsHxm3HlgH4xoG33uMORI4D/p7uRWlw2bsOLXtbBt49zMG4+gE+Afx1kl8CngE8UFVf66f9kG47A1BV303yHbp3BBsHlnEK8N+AX+Gx7hn4uzwO+LN51jnlcOAe4HMD4+a7b+3Kls/v0SS30j2/KR9P8iiwA/DmqvrBHNchDPDlZveB4bXAHf3wCXQh8OyqujPJ3sDX6bowbgVWJXliVd3P3O2WJAMhvhY4v1/uD4Gdq+rhMfPuQ9cVMcqtwAer6j9Os+6nA/88Zt5/qKrfnWbeP6+qP4KuDxy4cWj6KuBYuiPuZw4t+9tVtec0y56tcfVTVT9Icg7dUf
gz2fLF51/oXkiAn3QZPQm4faDN7/TjzqHr2hr2k7/LVP/4fOrsraB7l/LyofHz3bfuAP7N1IP+6H53tnx+h1XVZ/uutq8l+ceqGu7e0wzsQlle3pDkKX0f9FuBj/bjd6Tr976/n/aTf+iq2gR8mq6PeGWSFUleMId1/iLwh/18r6A72ruwX+5FwDuTPKE/GfjUJFNdN08Ajqbr6xzlQ8BLkrwoybZJfr4/mfaUJNslOQbYot98wKeApyc5sq9rRZJnDZ4Im4XjgdP6/uhBXwMeSvJfk/xCX9uvJXnWHJZNkt3pjvA/Pk2zv6XbRi9lywA/G3hNkr2T/Bzd0fOlVbVxoM1JwJuG3h3NWX/e4LUz1Hkk8OWqunpw5AL2rXPouuEOSHfS9gS6g4Evj2g71f20esQ0zcAAX14+TBeaN9OdWJy6OuAv6PoQ7wG+StctMOhI4MfAN+n6tI+fwzovBfbsl/2nwMurauqk2avp+k+/QfcW/P/QnXSDrj/7mcDfTF2BAPwW8J4ka6vqVuBQ4C103Sy3Av+Fbp97HfAa4NCq+v5wQX0//4F0/dJ3AHfSdSf83Bye17Z0fa/Dy36Ero99b+Db/fN+P7DTHJYN8Bm6E5XvHtegqr4EPApcUVWDXQqfo+trPhfYRHfSergP/utV9YU51rSF/sj+IuBvquqcaZqupOuqGWXO+1ZV3UD3zuOv6LbvS+hOWg72c3+y32euBs6jO+GtOfI68GUiyUbg9VX12Qmu8+h+nc+fx7wbq2rdiPHvB94+dDT5MyvJ54APV9X7l7oW/fSxD1zztWnM+HvpLj38mdd3y+xL905EWnQGuOalqp47ZvybJl3LcpTkTOAw4Li+S0hadHahSFKjPIkpSY2aaBfKzjvvXOvWrZvkKiWpeZdffvk9VfWYSy0nGuDr1q3jsssum+QqJal5SW4ZNd4uFElqlAEuSY0ywCWpUQa4JDXKAJekRhngktSoGQM8ye5JPp/utxqvS3JcP/6k/vfsruxvB8+0LEnS4pnNdeAPAydU1RVJdgQuT3JxP+3dVfWYr+yUJG19MwZ4/6Xum/rhh5Jcz5Y/jSRJWgJz+iRm/6Or+9D9CMDzgGOTvJruy/1PGPXbeUk2ABsA1q5du8BypZ8+605cmt8y2HjyIUuyXi2eWZ/ETPJ4ul8QOb6qHgTeS/dLInvTHaG/c9R8VXVqVa2vqvWrV/urSZK0WGYV4P3v2p0LnFVV5wFU1V1V9UhVPQq8D9hv65UpSRo2m6tQApwGXF9V7xoYv2ag2cuAaxe/PEnSOLPpA38e3Q+bXpPkyn7cW4AjkuwNFLAR+IOtUqEkaaTZXIXyT0BGTLpw8cuRJM2Wn8SUpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNWrGAE+ye5LPJ/lGkuuSHNePX5Xk4iQ39vcrt365kqQpszkCfxg4oar2Ap4DvCHJXsCJwCVVtSdwSf9YkjQhMwZ4VW2qqiv64YeA64HdgEOBM/tmZwKHba0iJUmPNac+8CTrgH2AS4FdqmpTP+lOYJcx82xIclmSy+6+++4FlCpJGjTrAE/yeOBc4PiqenBwWlUVUKPmq6pTq2p9Va1fvXr1goqVJP2rWQV4khV04X1WVZ3Xj74ryZp++hpg89YpUZI0ymyuQglwGnB9Vb1rYNL5wFH98FHAJxa/PEnSONvNos3zgCOBa5Jc2Y97C3AycE6S1wG3AIdvnRIlSaPMGOBV9U9Axkw+YHHLkSTNlp/ElKRGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGjWbL7OS9FNo3YkXLN
m6N558yJKt+6eJR+CS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1yu8Dl1ja78aW5ssjcElqlAEuSY0ywCWpUQa4JDVqxgBPcnqSzUmuHRh3UpLbk1zZ3w7eumVKkobN5gj8DOCgEePfXVV797cLF7csSdJMZgzwqvoicO8EapEkzcFC+sCPTXJ138WyctEqkiTNynw/yPNe4E+A6u/fCbx2VMMkG4ANAGvXrp3n6pbWUn3IY+PJhyzJepeSH6iRZm9eR+BVdVdVPVJVjwLvA/abpu2pVbW+qtavXr16vnVKkobMK8CTrBl4+DLg2nFtJUlbx4xdKEnOBvYHdk5yG/DHwP5J9qbrQtkI/MFWrFGSNMKMAV5VR4wYfdpWqEWSNAd+ElOSGmWAS1KjDHBJapQ/6LCMef25pOl4BC5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1KgZAzzJ6Uk2J7l2YNyqJBcnubG/X7l1y5QkDZvNEfgZwEFD404ELqmqPYFL+seSpAmaMcCr6ovAvUOjDwXO7IfPBA5b5LokSTPYbp7z7VJVm/rhO4FdxjVMsgHYALB27dp5rk6TtO7EC5a6BEmzsOCTmFVVQE0z/dSqWl9V61evXr3Q1UmSevMN8LuSrAHo7zcvXkmSpNmYb4CfDxzVDx8FfGJxypEkzdZsLiM8G/gK8IwktyV5HXAy8LtJbgRe2D+WJE3QjCcxq+qIMZMOWORaJElz4CcxJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRs33Bx0kad6W6kdDNp58yJKsd2vxCFySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEY1cx34Ul03KknLlUfgktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWrUgr7MKslG4CHgEeDhqlq/GEVJkma2GN9G+NtVdc8iLEeSNAd2oUhSoxYa4AVclOTyJBtGNUiyIcllSS67++67F7g6SdKUhQb486tqX+DFwBuSvGC4QVWdWlXrq2r96tWrF7g6SdKUBQV4Vd3e328GPgbstxhFSZJmNu8AT7JDkh2nhoEDgWsXqzBJ0vQWchXKLsDHkkwt58NV9feLUpUkaUbzDvCquhn4jUWsRZI0B15GKEmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNWoxflJNkpqw7sQLlmzdG08+ZNGX6RG4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMWFOBJDkpyQ5JvJTlxsYqSJM1s3gGeZFvgfwEvBvYCjkiy12IVJkma3kKOwPcDvlVVN1fVj4CPAIcuTlmSpJlst4B5dwNuHXh8G/Ds4UZJNgAb+offTXLDmOXtDNyzgHomxToXXyu1WufiaqVOWIRac8qC1r/HqJELCfBZqapTgVNnapfksqpav7XrWSjrXHyt1Gqdi6uVOmH51rqQLpTbgd0HHj+lHydJmoCFBPj/BfZM8ktJtgdeCZy/OGVJkmYy7y6Uqno4ybHAZ4BtgdOr6roF1DJjN8syYZ2Lr5VarXNxtVInLNNaU1VLXYMkaR78JKYkNcoAl6RGTSzAk+ye5PNJvpHkuiTHjWizf5IHklzZ3942qfpG1LIxyT
V9HZeNmJ4kf9l/jcDVSfZdghqfMbCtrkzyYJLjh9os2TZNcnqSzUmuHRi3KsnFSW7s71eOmfeovs2NSY5agjr/R5Jv9n/bjyV54ph5p91PJlDnSUluH/j7Hjxm3ol97cWYOj86UOPGJFeOmXeS23NkJi3HfXSsqprIDVgD7NsP7wj8M7DXUJv9gU9NqqYZ6t0I7DzN9IOBTwMBngNcusT1bgvcCeyxXLYp8AJgX+DagXHvAE7sh08EThkx3yrg5v5+ZT+8csJ1Hghs1w+fMqrO2ewnE6jzJOCNs9g3bgJ+GdgeuGr4f29r1zk0/Z3A25bB9hyZSctxHx13m9gReFVtqqor+uGHgOvpPs3ZqkOBv63OV4EnJlmzhPUcANxUVbcsYQ1bqKovAvcOjT4UOLMfPhM4bMSsLwIurqp7q+o+4GLgoEnWWVUXVdXD/cOv0n3OYUmN2Z6zMdGvvZiuziQBDgfO3lrrn61pMmnZ7aPjLEkfeJJ1wD7ApSMmPzfJVUk+neRXJ1rYlgq4KMnl/dcBDBv1VQJL+YL0Ssb/UyyXbQqwS1Vt6ofvBHYZ0Wa5bdvX0r3bGmWm/WQSju27ek4f83Z/OW3P3wLuqqobx0xfku05lEnN7KMTD/AkjwfOBY6vqgeHJl9B1wXwG8BfAR+fdH0Dnl9V+9J92+IbkrxgCWuZVv9BqpcCfzdi8nLapluo7r3osr6ONclbgYeBs8Y0Wer95L3AU4G9gU103RPL2RFMf/Q98e05XSYt9310ogGeZAXdhjqrqs4bnl5VD1bVd/vhC4EVSXaeZI0Dtdze328GPkb3NnTQcvoqgRcDV1TVXcMTltM27d011dXU328e0WZZbNskRwO/B7yq/0d+jFnsJ1tVVd1VVY9U1aPA+8asf7lsz+2A3wc+Oq7NpLfnmExqZh+d5FUoAU4Drq+qd41p8+S+HUn26+v7zqRqHKhjhyQ7Tg3TndC6dqjZ+cCr+6tRngM8MPC2a9LGHtUsl2064Hxg6oz9UcAnRrT5DHBgkpV9l8CB/biJSXIQ8CbgpVX1vTFtZrOfbFVD511eNmb9y+VrL14IfLOqbhs1cdLbc5pMamIfBSZ6Fcrz6d6KXA1c2d8OBo4BjunbHAtcR3eW/KvAb076rG5fxy/3NVzV1/PWfvxgraH7QYubgGuA9UtU6w50gbzTwLhlsU3pXlQ2AT+m6yN8HfAk4BLgRuCzwKq+7Xrg/QPzvhb4Vn97zRLU+S26Ps6pffWv+7a7AhdOt59MuM4P9vvf1XTBs2a4zv7xwXRXWdy0FHX248+Y2i8H2i7l9hyXSctuHx1386P0ktQoP4kpSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1Kj/j+AyuLBSwDg1QAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cAxhwXjTkPln"
},
"source": [
"# Упрощение микросегментов"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mcPLG3a-kV8v"
},
"source": [
"Мы получили больше сотни сегментов, каждый из которых описывается довольно большой последовательностью фильтров. Попытаемся упростить наши сегменты, действуя примерно по следующим принципам:\n",
"\n",
"* Фильтры, которые влияют на конверсию сегмента не слишком сильно, можно выкинуть. \n",
"* Можно выкидывать такие фильтры, пока разница конверсии микросегмента со средней по больнице не сократится на a%"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7tj2v6lHlu8o"
},
"source": [
"a = 0.10 # maximum drift of a leaf's conversion toward the average: 10% of the gap between them"
],
"execution_count": 24,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ToGcJLkUmDJu"
},
"source": [
"Напишем функцию, которая применяет фильтры к датасету и возвращает булеву маску наблюдений, относительно которых все фильтры сработали. \n",
"\n",
"Протестируем её на первом же листе - всё нормально, у отфильтрованных наблюдений конверсия такая же, как у соответствующего листа. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "okDiZxXMmMR7",
"outputId": "cecbe6ad-7965-4fce-a294-1f2bbedf5843"
},
"source": [
"def apply_filters(filters, data):\n",
"    \"\"\"Return a boolean mask of the rows of `data` that pass every filter.\n",
"\n",
"    Args:\n",
"        filters: iterable of (feature, side, threshold) triples; `side` is\n",
"            '<=' (keep rows with data[feature] <= threshold) or '>' (keep\n",
"            the complementary rows).\n",
"        data: table indexable by feature name with rows along axis 0,\n",
"            e.g. a pandas DataFrame.\n",
"\n",
"    Returns:\n",
"        1-D boolean numpy array with one entry per row of `data`\n",
"        (all True when `filters` is empty).\n",
"\n",
"    Raises:\n",
"        ValueError: if `side` is neither '<=' nor '>'.\n",
"    \"\"\"\n",
"    where = np.ones(data.shape[0], dtype=bool)\n",
"    for feature, side, threshold in filters:\n",
"        # np.asarray keeps the result a plain ndarray even when data[feature] is a Series\n",
"        mask = np.asarray(data[feature] <= threshold, dtype=bool)\n",
"        if side == '>':\n",
"            # '>' is taken as the logical complement of '<='\n",
"            mask = ~mask\n",
"        elif side != '<=':\n",
"            raise ValueError('unknown filter side: {!r}'.format(side))\n",
"        where &= mask\n",
"    return where\n",
"\n",
"# Sanity check on the first leaf: print the size and conversion of the rows selected\n",
"# by the leaf's rule path next to the conversion stored for that leaf.\n",
"j = 0\n",
"segment_rules = list(map(lambda rule_id: rules[rule_id], get_path(leaf_indices[j])))\n",
"\n",
"ff = apply_filters(segment_rules, X_train2)\n",
"n_selected = sum(ff)\n",
"selected_conversion = np.mean(y_train[ff])\n",
"print(j, n_selected, selected_conversion, leaf_conversion[j])"
],
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"text": [
"0 219 0.0182648401826484 0.0182648401826484\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DDd1N68NuezJ"
},
"source": [
"Теперь напишем функцию, которая упрощает лист. Оказалось, что 12 правил из этого листа можно упростить всего до двух, и микросегмент от этого станет только контрастнее. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AZuVI27Kp1Ys",
"outputId": "6695c333-aaa2-4994-e005-2204f4af58cf"
},
"source": [
"def trim_rules(rules, threshold, above, data, labels):\n",
"    \"\"\"Greedily drop filters from `rules` while the segment conversion stays acceptable.\n",
"\n",
"    On every pass each remaining rule is tentatively removed and the mean of\n",
"    `labels` over the rows selected by the other rules is computed. Among the\n",
"    removals that keep this conversion on the required side of `threshold`,\n",
"    the one giving the most extreme conversion (highest if `above`, lowest\n",
"    otherwise) is applied. The loop stops when no removal is acceptable.\n",
"\n",
"    Args:\n",
"        rules: sequence of (feature, side, threshold) filters; not modified.\n",
"        threshold: conversion bound that the trimmed segment must respect.\n",
"        above: if True require conversion >= threshold, otherwise <= threshold.\n",
"        data: table passed on to `apply_filters`.\n",
"        labels: per-row target values, indexable by a boolean mask.\n",
"\n",
"    Returns:\n",
"        A new list with the rules that were kept, in their original order.\n",
"    \"\"\"\n",
"    def is_ok(conv):\n",
"        return conv >= threshold if above else conv <= threshold\n",
"\n",
"    def is_better(conv, best):\n",
"        return conv > best if above else conv < best\n",
"\n",
"    # work on a private list so the caller's sequence (list or tuple) is never touched\n",
"    rules = list(rules)\n",
"    while True:\n",
"        best_id = -1\n",
"        best_conv = None\n",
"        for i in range(len(rules)):\n",
"            candidate = rules[:i] + rules[i + 1:]\n",
"            ff = apply_filters(candidate, data)\n",
"            if not np.any(ff):\n",
"                # an empty selection has no conversion; skip it instead of averaging nothing\n",
"                continue\n",
"            new_conv = np.mean(labels[ff])\n",
"            if not is_ok(new_conv):\n",
"                continue\n",
"            if best_conv is None or is_better(new_conv, best_conv):\n",
"                best_conv = new_conv\n",
"                best_id = i\n",
"\n",
"        if best_id < 0:\n",
"            break\n",
"        del rules[best_id]\n",
"    return rules\n",
"\n",
"# The trimmed segment may move at most a share `a` of the way from the leaf conversion\n",
"# toward the root conversion; `above` records on which side of the root the leaf lies.\n",
"threshold = (1 - a) * leaf_conversion[j] + a * root_conversion\n",
"above = root_conversion < threshold\n",
"\n",
"new_rules = trim_rules(segment_rules, threshold, above, X_train2, y_train)\n",
"\n",
"# Rule count, segment size and conversion: first for the full path, then after trimming\n",
"for current_rules in (segment_rules, new_rules):\n",
"    ff = apply_filters(current_rules, X_train2)\n",
"    print(len(current_rules), sum(ff), np.mean(y_train[ff]))"
],
"execution_count": 26,
"outputs": [
{
"output_type": "stream",
"text": [
"12 219 0.0182648401826484\n",
"2 11226 0.015766969535008016\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ColEiwxywBlw"
},
"source": [
"Применим эту функцию ко всем листьям. Оказывается, ни для одного из результирующих микросегментов на самом деле не нужно больше 6 правил!"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"d5704653ebc8401181381c3a1d770285",
"509388a136054f9e9e4f880f6e1c08c5",
"6727b6dbc88a48e986b220cbd4c09c47",
"3d486747308a4264b75131b5c8106554",
"c4f04b9d42994e25a4f33bdc418e8db0",
"ea70bcd7a2504ce295ce001201b2e67c",
"03db337abd384a07a426b5171947a9df",
"62f89de11c794b859bbb58cf1764b1ac"
]
},
"id": "rN_dbd5cpgsZ",
"outputId": "c51cdf96-e063-4b2d-9b6c-6136ac44d04c"
},
"source": [
"# Trim the rule path of every leaf and log, per tree node, the rule count and the\n",
"# training conversion before and after trimming.\n",
"report_template = '{:3d}: {:5d} -> {:5d} {:2.4f} -> {:2.4f}'\n",
"simplified_rules = []\n",
"for leaf_id, node_id in enumerate(tqdm(leaf_indices)):\n",
"    segment_rules = [rules[rule_id] for rule_id in get_path(node_id)]\n",
"    # allow the conversion to move at most a share `a` of the way toward the root conversion\n",
"    threshold = (1 - a) * leaf_conversion[leaf_id] + a * root_conversion\n",
"    above = root_conversion < threshold\n",
"    conversion_before = np.mean(y_train[apply_filters(segment_rules, X_train2)])\n",
"    new_rules = trim_rules(segment_rules, threshold, above, X_train2, y_train)\n",
"    simplified_rules.append(new_rules)\n",
"    conversion_after = np.mean(y_train[apply_filters(new_rules, X_train2)])\n",
"    print(report_template.format(\n",
"        node_id, len(segment_rules), len(new_rules), conversion_before, conversion_after\n",
"    ))"
],
"execution_count": 27,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d5704653ebc8401181381c3a1d770285",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=131.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
" 12: 12 -> 2 0.0183 -> 0.0158\n",
" 13: 12 -> 4 0.0000 -> 0.0068\n",
" 15: 12 -> 1 0.0332 -> 0.0135\n",
" 16: 12 -> 1 0.1502 -> 0.1859\n",
" 19: 12 -> 1 0.1000 -> 0.0357\n",
" 21: 13 -> 2 0.0253 -> 0.0158\n",
" 22: 13 -> 3 0.0000 -> 0.0091\n",
" 28: 16 -> 2 0.0060 -> 0.0128\n",
" 29: 16 -> 3 0.0000 -> 0.0098\n",
" 30: 15 -> 3 0.0000 -> 0.0100\n",
" 34: 17 -> 4 0.0000 -> 0.0115\n",
" 35: 17 -> 1 0.0245 -> 0.0225\n",
" 38: 18 -> 1 0.0067 -> 0.0164\n",
" 39: 18 -> 2 0.0000 -> 0.0104\n",
" 40: 17 -> 1 0.0132 -> 0.0225\n",
" 42: 16 -> 1 0.0134 -> 0.0225\n",
" 45: 18 -> 4 0.0000 -> 0.0100\n",
" 47: 19 -> 2 0.0000 -> 0.0114\n",
" 49: 20 -> 4 0.0000 -> 0.0100\n",
" 51: 21 -> 1 0.0061 -> 0.0073\n",
" 52: 21 -> 1 0.0131 -> 0.0225\n",
" 54: 18 -> 1 0.0152 -> 0.0225\n",
" 55: 18 -> 4 0.0000 -> 0.0089\n",
" 56: 13 -> 1 0.0128 -> 0.0225\n",
" 61: 16 -> 1 0.0979 -> 0.0357\n",
" 63: 17 -> 1 0.0394 -> 0.0357\n",
" 65: 18 -> 2 0.0065 -> 0.0064\n",
" 66: 18 -> 2 0.0067 -> 0.0064\n",
" 68: 16 -> 3 0.0000 -> 0.0102\n",
" 69: 16 -> 1 0.0287 -> 0.0357\n",
" 71: 15 -> 2 0.0067 -> 0.0089\n",
" 72: 15 -> 2 0.0000 -> 0.0089\n",
" 74: 14 -> 2 0.0197 -> 0.0049\n",
" 76: 15 -> 2 0.0000 -> 0.0049\n",
" 78: 16 -> 2 0.0116 -> 0.0049\n",
" 79: 16 -> 2 0.0000 -> 0.0049\n",
" 81: 10 -> 1 0.1588 -> 0.2313\n",
" 86: 14 -> 2 0.0045 -> 0.0123\n",
" 87: 14 -> 1 0.0200 -> 0.0245\n",
" 89: 14 -> 2 0.0065 -> 0.0123\n",
" 90: 14 -> 3 0.0000 -> 0.0097\n",
" 92: 13 -> 1 0.0596 -> 0.0357\n",
" 93: 13 -> 2 0.0134 -> 0.0158\n",
" 95: 12 -> 1 0.0359 -> 0.0158\n",
" 96: 12 -> 1 0.1316 -> 0.2313\n",
"100: 11 -> 3 0.0000 -> 0.0059\n",
"101: 11 -> 1 0.0133 -> 0.0155\n",
"103: 11 -> 2 0.0164 -> 0.0158\n",
"104: 11 -> 1 0.0721 -> 0.0357\n",
"105: 9 -> 2 0.1931 -> 0.4028\n",
"108: 9 -> 1 0.0437 -> 0.0357\n",
"111: 11 -> 1 0.0000 -> 0.0038\n",
"112: 11 -> 1 0.0130 -> 0.0154\n",
"113: 10 -> 2 0.0258 -> 0.0158\n",
"114: 8 -> 2 0.2412 -> 0.4209\n",
"119: 10 -> 1 0.1306 -> 0.2211\n",
"125: 15 -> 1 0.1097 -> 0.0650\n",
"126: 15 -> 1 0.1647 -> 0.2211\n",
"127: 14 -> 2 0.0259 -> 0.0282\n",
"130: 15 -> 2 0.0269 -> 0.0282\n",
"131: 15 -> 4 0.0068 -> 0.0136\n",
"132: 14 -> 2 0.0468 -> 0.0282\n",
"136: 15 -> 2 0.0000 -> 0.0103\n",
"137: 15 -> 2 0.0067 -> 0.0103\n",
"138: 14 -> 2 0.0198 -> 0.0126\n",
"139: 13 -> 2 0.0521 -> 0.0172\n",
"141: 12 -> 1 0.1792 -> 0.2211\n",
"142: 12 -> 2 0.0252 -> 0.0282\n",
"145: 11 -> 1 0.1828 -> 0.3509\n",
"146: 11 -> 1 0.1261 -> 0.3509\n",
"147: 10 -> 2 0.0485 -> 0.0280\n",
"151: 11 -> 2 0.0040 -> 0.0063\n",
"152: 11 -> 2 0.0267 -> 0.0063\n",
"154: 11 -> 2 0.0000 -> 0.0063\n",
"155: 11 -> 2 0.0063 -> 0.0063\n",
"157: 10 -> 1 0.0559 -> 0.0420\n",
"158: 10 -> 2 0.0202 -> 0.0079\n",
"163: 11 -> 1 0.1404 -> 0.2211\n",
"164: 11 -> 3 0.3307 -> 0.4286\n",
"166: 11 -> 2 0.2905 -> 0.3050\n",
"172: 16 -> 3 0.0166 -> 0.0172\n",
"173: 16 -> 1 0.0876 -> 0.0563\n",
"175: 16 -> 5 0.0000 -> 0.0075\n",
"176: 16 -> 5 0.0067 -> 0.0147\n",
"178: 15 -> 1 0.1465 -> 0.2211\n",
"179: 15 -> 1 0.0581 -> 0.0563\n",
"182: 15 -> 1 0.0693 -> 0.0709\n",
"184: 16 -> 1 0.1625 -> 0.2951\n",
"185: 16 -> 1 0.1015 -> 0.0726\n",
"187: 15 -> 1 0.2564 -> 0.2951\n",
"188: 15 -> 1 0.1509 -> 0.2951\n",
"190: 13 -> 2 0.0268 -> 0.0079\n",
"192: 14 -> 2 0.0000 -> 0.0061\n",
"193: 14 -> 2 0.0200 -> 0.0079\n",
"195: 10 -> 2 0.2917 -> 0.4130\n",
"196: 10 -> 2 0.4051 -> 0.4130\n",
"197: 8 -> 3 0.4865 -> 0.5853\n",
"199: 6 -> 1 0.2316 -> 0.4541\n",
"200: 6 -> 2 0.5513 -> 0.6006\n",
"201: 4 -> 1 0.4764 -> 0.5282\n",
"203: 4 -> 1 0.0647 -> 0.0163\n",
"206: 6 -> 1 0.3136 -> 0.4277\n",
"207: 6 -> 1 0.4013 -> 0.4277\n",
"208: 5 -> 2 0.5629 -> 0.6114\n",
"210: 3 -> 1 0.2065 -> 0.6466\n",
"212: 4 -> 1 0.5758 -> 0.6466\n",
"214: 5 -> 1 0.6940 -> 0.6466\n",
"215: 5 -> 2 0.8175 -> 0.7814\n",
"222: 7 -> 2 0.4700 -> 0.4731\n",
"224: 8 -> 1 0.2634 -> 0.4136\n",
"225: 8 -> 1 0.3827 -> 0.4136\n",
"228: 8 -> 1 0.2284 -> 0.4136\n",
"229: 8 -> 1 0.1680 -> 0.4136\n",
"230: 7 -> 1 0.3438 -> 0.4136\n",
"234: 8 -> 1 0.4208 -> 0.4952\n",
"235: 8 -> 1 0.5325 -> 0.4952\n",
"237: 8 -> 1 0.2955 -> 0.4952\n",
"238: 8 -> 1 0.3873 -> 0.5323\n",
"239: 6 -> 2 0.5756 -> 0.5432\n",
"242: 6 -> 1 0.2108 -> 0.4136\n",
"244: 7 -> 1 0.0655 -> 0.0420\n",
"245: 7 -> 1 0.1467 -> 0.4496\n",
"246: 5 -> 1 0.2872 -> 0.5214\n",
"247: 3 -> 2 0.8012 -> 0.8102\n",
"250: 4 -> 1 0.4415 -> 0.5850\n",
"251: 4 -> 1 0.5971 -> 0.5850\n",
"254: 5 -> 1 0.6144 -> 0.5850\n",
"255: 5 -> 3 0.7316 -> 0.6724\n",
"257: 5 -> 1 0.5155 -> 0.5850\n",
"259: 6 -> 1 0.5455 -> 0.6042\n",
"260: 6 -> 2 0.6625 -> 0.6417\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SJsi11jDvKbB"
},
"source": [
"Упрощённых правил по-прежнему 131, но уникальных из них - только 77. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Un98ypziwIDu",
"outputId": "8d78490b-c97b-4590-f9bc-ab4f0b1f6323"
},
"source": [
"print(len(simplified_rules))\n",
"# De-duplicate the simplified rule lists. dict.fromkeys keeps first-seen order, so `sr`\n",
"# (and every index into it) is the same on each run; iterating a set of tuples that\n",
"# contain strings can yield a different order after a kernel restart.\n",
"sr = list(dict.fromkeys(\n",
"    tuple(tuple(item) for item in rule)\n",
"    for rule in simplified_rules\n",
"))\n",
"print(len(sr))"
],
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"text": [
"131\n",
"77\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fe-Du6MC2SAi"
},
"source": [
"Теперь самые \"горячие\" и \"холодные\" микросегменты выглядят значительно компактнее, а значит, интерпретируемее!\n",
"\n",
"Не все из них, конечно, логичные (в частности, зависимость от дат выглядит очень подозрительно). Но такие закономерности действительно есть в данных, так что вот так вот. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "6ZKL2z-P1QfF"
},
"source": [
"# Training-set conversion of every unique simplified segment\n",
"final_conversions = []\n",
"for segment in sr:\n",
"    segment_mask = apply_filters(segment, X_train2)\n",
"    final_conversions.append(np.mean(y_train[segment_mask]))\n",
"# indices into `sr`, ordered from the lowest conversion to the highest\n",
"ranks = np.argsort(final_conversions)"
],
"execution_count": 29,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eTkahwQj1-Y_",
"outputId": "ac15e24a-edb2-43f4-fa1c-19280b901c34"
},
"source": [
"# The five segments with the highest conversion, best first.\n",
"# The loop variable is `idx` because `id` would shadow the builtin for the rest of the session.\n",
"for idx in ranks[:-6:-1]:\n",
"    print(final_conversions[idx])\n",
"    print(list(sr[idx]))"
],
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"text": [
"0.8101851851851852\n",
"[('duration', '>', 473.5), ('poutcome_success', '>', 0.5)]\n",
"0.7814029363784666\n",
"[('poutcome_success', '>', 0.5), ('duration', '>', 254.0)]\n",
"0.6723768736616702\n",
"[('duration', '>', 827.5), ('contact_cellular', '>', 0.5), ('day', '<=', 15.5)]\n",
"0.6466165413533834\n",
"[('poutcome_success', '>', 0.5)]\n",
"0.64171974522293\n",
"[('contact_cellular', '>', 0.5), ('duration', '>', 956.5)]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cuMg9zkQ2Geo",
"outputId": "172ee9b6-4d7f-414d-b08a-b1e3bf141f3a"
},
"source": [
"# The six segments with the lowest conversion, lowest first.\n",
"# The loop variable is `idx` because `id` would shadow the builtin for the rest of the session.\n",
"for idx in ranks[:6]:\n",
"    print(final_conversions[idx])\n",
"    print(list(sr[idx]))"
],
"execution_count": 31,
"outputs": [
{
"output_type": "stream",
"text": [
"0.0038335158817086527\n",
"[('duration', '<=', 75.5)]\n",
"0.004943743607228094\n",
"[('duration', '<=', 204.5), ('contact_unknown', '>', 0.5)]\n",
"0.005905511811023622\n",
"[('month_feb', '>', 0.5), ('day', '<=', 7.0), ('duration', '<=', 124.5)]\n",
"0.006050674398083953\n",
"[('contact_unknown', '>', 0.5), ('duration', '<=', 324.5)]\n",
"0.0062871114215857496\n",
"[('contact_unknown', '>', 0.5), ('duration', '<=', 393.5)]\n",
"0.006416751097602162\n",
"[('duration', '<=', 204.5), ('month_may', '>', 0.5)]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d3LzRc0E2PIQ"
},
"source": [
"Обладая такими \"приметами\" интересных сегментов, судить о закономерностях в данных довольно легко. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oHTOFDM86v0G"
},
"source": [
"Остался последний вопрос: а насколько наши правила сохранили свою предсказательную силу после упрощения? Не стало ли дерево решений бесполезным?\n",
"\n",
"Ответить на этот вопрос можно с помощью очень простой предсказательной модели: для каждого наблюдения предсказываем конверсию как среднее арифметическое конверсий по всем микросегментам, в которые оно попало. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "AJM3yPSm8dBC"
},
"source": [
"def predict(rules, data, scores, default=np.nan):\n",
"    \"\"\"Score every row of `data` as the mean score of the rule sets it matches.\n",
"\n",
"    Args:\n",
"        rules: sequence of rule sets, each a sequence of filters for `apply_filters`.\n",
"        data: table passed on to `apply_filters`.\n",
"        scores: one score per rule set, in the same order as `rules`.\n",
"        default: value returned for rows that match no rule set. The NaN default\n",
"            equals what the unguarded 0/0 division produced, minus the RuntimeWarning.\n",
"\n",
"    Returns:\n",
"        1-D float array with one prediction per row of `data`.\n",
"    \"\"\"\n",
"    # obs2rules[i, k] == 1 when row i passes every filter of rule set k\n",
"    obs2rules = np.stack([\n",
"        apply_filters(r, data)\n",
"        for r in rules\n",
"    ]).astype(int).T\n",
"    score_sum = (obs2rules * np.array(scores)).sum(axis=1)\n",
"    n_matched = obs2rules.sum(axis=1)\n",
"    # divide only where at least one rule set matched; other rows keep `default`\n",
"    result = np.full(n_matched.shape, default, dtype=float)\n",
"    np.divide(score_sum, n_matched, out=result, where=n_matched > 0)\n",
"    return result"
],
"execution_count": 45,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "UWWdDJYl9prO"
},
"source": [
"Применив такую максимально упрощённую модель, мы получаем 88.8% ROC AUC на обучающей выборке, и 87.9% на тестовой. Причём на самом деле непонятно, сколько мы потеряли за счёт упрощения правил, а сколько - за счёт того, что мы очень наивно агрегируем предсказания в случае, когда правил сработало несколько. "
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "36ogVUf6880a",
"outputId": "0a0dc8b5-b0df-497e-96e2-6c5b07d846e5"
},
"source": [
"# ROC AUC of the averaged-segment predictions: training split first, then test split\n",
"for targets, features in ((y_train, X_train2), (y_test, X_test2)):\n",
"    print(roc_auc_score(targets, predict(sr, features, final_conversions)))"
],
"execution_count": 49,
"outputs": [
{
"output_type": "stream",
"text": [
"0.8878048311233366\n",
"0.8788157494950939\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "4rFlLci3-T0d"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment