Skip to content

Instantly share code, notes, and snippets.

@bentrevett
Created February 8, 2024 21:21
Show Gist options
  • Save bentrevett/60bedce84237aeadabf9e344ef538e7d to your computer and use it in GitHub Desktop.
Save bentrevett/60bedce84237aeadabf9e344ef538e7d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "2c2dd28e-ade8-435f-a670-f354958a9b15",
"metadata": {},
"outputs": [],
"source": [
"# Sources\n",
"# https://anderfernandez.com/en/blog/code-decision-tree-python-from-scratch/"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ff7525de-162f-40ce-bb31-c95af1383ca1",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import itertools\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a90b7f7f-f295-49c3-8724-3ee81f49baaf",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Gender</th>\n",
" <th>Height</th>\n",
" <th>Weight</th>\n",
" <th>Index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Male</td>\n",
" <td>174</td>\n",
" <td>96</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Male</td>\n",
" <td>189</td>\n",
" <td>87</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Female</td>\n",
" <td>185</td>\n",
" <td>110</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Female</td>\n",
" <td>195</td>\n",
" <td>104</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Male</td>\n",
" <td>149</td>\n",
" <td>61</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>495</th>\n",
" <td>Female</td>\n",
" <td>150</td>\n",
" <td>153</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>496</th>\n",
" <td>Female</td>\n",
" <td>184</td>\n",
" <td>121</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>497</th>\n",
" <td>Female</td>\n",
" <td>141</td>\n",
" <td>136</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>498</th>\n",
" <td>Male</td>\n",
" <td>150</td>\n",
" <td>95</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>499</th>\n",
" <td>Male</td>\n",
" <td>173</td>\n",
" <td>131</td>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>500 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" Gender Height Weight Index\n",
"0 Male 174 96 4\n",
"1 Male 189 87 2\n",
"2 Female 185 110 4\n",
"3 Female 195 104 3\n",
"4 Male 149 61 3\n",
".. ... ... ... ...\n",
"495 Female 150 153 5\n",
"496 Female 184 121 4\n",
"497 Female 141 136 5\n",
"498 Male 150 95 5\n",
"499 Male 173 131 5\n",
"\n",
"[500 rows x 4 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv(\"bmi-data.csv\")\n",
"\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b6824114-10c3-46ef-b9e5-0ad808e08b90",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Gender</th>\n",
" <th>Height</th>\n",
" <th>Weight</th>\n",
" <th>Obese</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Male</td>\n",
" <td>174</td>\n",
" <td>96</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Male</td>\n",
" <td>189</td>\n",
" <td>87</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Female</td>\n",
" <td>185</td>\n",
" <td>110</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Female</td>\n",
" <td>195</td>\n",
" <td>104</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Male</td>\n",
" <td>149</td>\n",
" <td>61</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>495</th>\n",
" <td>Female</td>\n",
" <td>150</td>\n",
" <td>153</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>496</th>\n",
" <td>Female</td>\n",
" <td>184</td>\n",
" <td>121</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>497</th>\n",
" <td>Female</td>\n",
" <td>141</td>\n",
" <td>136</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>498</th>\n",
" <td>Male</td>\n",
" <td>150</td>\n",
" <td>95</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>499</th>\n",
" <td>Male</td>\n",
" <td>173</td>\n",
" <td>131</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>500 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" Gender Height Weight Obese\n",
"0 Male 174 96 1\n",
"1 Male 189 87 0\n",
"2 Female 185 110 1\n",
"3 Female 195 104 0\n",
"4 Male 149 61 0\n",
".. ... ... ... ...\n",
"495 Female 150 153 1\n",
"496 Female 184 121 1\n",
"497 Female 141 136 1\n",
"498 Male 150 95 1\n",
"499 Male 173 131 1\n",
"\n",
"[500 rows x 4 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[\"Obese\"] = (data[\"Index\"] >= 4).astype(\"int\")\n",
"data = data.drop(\"Index\", axis=1)\n",
"\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9b329465-7e6b-4067-91e3-0f30e82dc3d2",
"metadata": {},
"outputs": [],
"source": [
"# Want to predict if someone is Obese\n",
"# A decision tree tells us different rules, e.g. if weight >= 100kg, then Obese\n",
"# Not all splits are precise, e.g. not everyone who is >= 100kg is Obese\n",
"\n",
"# Decision trees create new branches (splits) that refine predictions\n",
"# Create a node at each split\n",
"# Keep going until we get a node that doesn't split, this is a leaf node\n",
"\n",
"# A decision tree uses a cost function, typically Gini index or entropy\n",
"# Both are based on measuring \"impurity\"\n",
"\n",
"# Impurity = when we make a split, how likely is the target value to be classified incorrectly?\n",
"# If we made the split at 100kg weight, what's the impurity? What about 80kg?"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "89b5373b-e34c-4d54-bf92-1a71e1078f45",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(18, 63)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_100kg = data[(data[\"Weight\"] >= 100) & (data[\"Obese\"] == 0)]\n",
"df_80kg = data[(data[\"Weight\"] >= 80) & (data[\"Obese\"] == 0)]\n",
"\n",
"len(df_100kg), len(df_80kg)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fb41f186-e638-4fc1-b6f5-9a3cf3b79a7b",
"metadata": {},
"outputs": [],
"source": [
"# Impurity at 100kg is 18 (number of incorrect classifications), and impurity at 80kg is 63\n",
"# 18 is less impure, therefore is a \"better\" split\n",
"# Cost functions in a decision tree seek to find splits that minimize impurity\n",
"\n",
"# Gini index is most widely used cost function in decision trees\n",
"# Calculates the probability a characteristic will be classified incorrectly when randomly selected\n",
"# 0 = pure cut, no impurity\n",
"# 0.5 = divides the data equally (when there is two classes)\n",
"\n",
"# Calculated as: G = 1 - sum^n_{i=1}(P_i)^2\n",
"# P_i is the probability of being that class"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f0f1b220-9bf2-4d5d-b135-4012602aaed5",
"metadata": {},
"outputs": [],
"source": [
"def gini_index(series):\n",
" assert isinstance(series, pd.Series)\n",
" p = series.value_counts() / len(series)\n",
" gini = 1 - np.sum(p**2)\n",
" return gini"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2f9a56a7-52d9-4a98-9071-b32a1a4260cb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.4998"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gini_index(data[\"Gender\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9d8e86f0-dbbd-412b-8b4d-d74be1827ad8",
"metadata": {},
"outputs": [],
"source": [
"# Impurity of almost 0.5 for gender because we have almost equal amount of male and female\n",
"\n",
"# How do we calculate impurity with Entropy?\n",
"# Entropy measures randomness in data points\n",
"# Defined by: E = - sum^n_{i=1}p_i \\log_2 p_i\n",
"\n",
"# For both Gini index and entropy: higher values are more impure"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b8216e84-14f4-4204-9360-a69f74a650f5",
"metadata": {},
"outputs": [],
"source": [
"def entropy(series):\n",
" assert isinstance(series, pd.Series)\n",
" p = series.value_counts() / len(series)\n",
" entropy = - np.sum(p * np.log2(p + 1e-10))\n",
" return entropy"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ea97512a-6b36-4db0-984e-ac149eff68e5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9997114414642708"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"entropy(data[\"Gender\"])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9d8f1dc1-a2a8-458e-bcb5-119f089d842d",
"metadata": {},
"outputs": [],
"source": [
"# Entropy is between 0 and 1, and reaches maximum when there's an equal amount of each class.\n",
"# The entropy for the Gender split is high, therefore high impurity and a bad split.\n",
"\n",
"# OK. We know how to calculate impurity (to tell if a split is good or not)\n",
"# but how do we decide which splits to do?\n",
"\n",
"# Splits are compared by their impurity\n",
"# For this we use \"Information Gain\"\n",
"\n",
"# IG measures the improvement when performing splits\n",
"# We can do this with entropy\n",
"# Can also with with Gini Index, but if we do then it's not information gain (entropy is related to information)\n",
"\n",
"# IG_classification = E(d) - \\sum |s|/|d| * E(s)\n",
"# IG_regression = E(d) - \\sum |s|/|d| * Var(s)\n",
"\n",
"# d is across all values, s is across values n the split\n",
"\n",
"def get_information_gain(series, mask):\n",
" n_true_split = sum(mask)\n",
" n_false_split = len(mask) - n_true_split\n",
" if n_true_split == 0 or n_false_split == 0:\n",
" return 0\n",
" original_entropy = entropy(series)\n",
" true_split_entropy = entropy(series[mask])\n",
" false_split_entropy = entropy(series[~mask])\n",
" weighted_average_entropy = n_true_split / len(mask) * true_split_entropy + n_false_split / len(mask) * false_split_entropy\n",
" information_gain = original_entropy - weighted_average_entropy\n",
" return information_gain"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "06eafaf1-7150-440a-bcc3-9ca0307b46f7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0005506911187600494"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_information_gain(data[\"Obese\"], data[\"Gender\"] == \"Male\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8a5b5d9c-fd48-4324-a51c-d4adc8f8ed62",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('Blue',),\n",
" ('Red',),\n",
" ('Green',),\n",
" ('Blue', 'Red'),\n",
" ('Blue', 'Green'),\n",
" ('Red', 'Green'),\n",
" ('Blue', 'Red', 'Green')]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# A high information gain indicates the feature reduces uncertainty in predicting the target variable\n",
"# this makes it a suitable candidate for splitting\n",
"# If the IG = 0, the features providers no additional inforation\n",
"\n",
"# To make our decision tree\n",
"# 1. Calculate information gain for all variables\n",
"# 2. Choose split that generates the highest information gain\n",
"# 3. Repeat until some stopping criterea\n",
"\n",
"# How do we choose where to split with numerical variables? What about with >2 categorical variables?\n",
"\n",
"# For splitting numeric variables:\n",
"# Get all values the variable is taking\n",
"# For each value, calculate IG when filtering all values less than that value\n",
"# If we sort values acending, first split means dropping that entire column\n",
"\n",
"# For categorical variables:\n",
"# Calculate IG for all possible combinations of that variable\n",
"# (exclude the one that includes all options as that would be doing no split)\n",
"# Combinatorial explosion when there's lots of classes so usually set a limit\n",
"\n",
"def get_categorical_options(series):\n",
" assert isinstance(series, pd.Series)\n",
" series = set(series)\n",
" options = []\n",
" for i, _ in enumerate(series):\n",
" subset = itertools.combinations(series, i+1)\n",
" options.extend(subset)\n",
" return options\n",
"\n",
"get_categorical_options(pd.Series([\"Red\", \"Red\", \"Blue\", \"Blue\", \"Green\"]))"
]
},
{
"cell_type": "code",
"execution_count": 156,
"id": "5c9e4114-27a0-4166-a0fe-3f3e2a07ed2f",
"metadata": {},
"outputs": [],
"source": [
"def get_max_information_gain(x_series, y_series):\n",
" is_numeric = x_series.dtype != \"object\"\n",
" if is_numeric:\n",
" split_values = x_series.sort_values().unique()[1:].tolist()\n",
" else:\n",
" split_values = get_categorical_options(x_series)\n",
" results = []\n",
" if not split_values:\n",
" # Handle the case when all values are the same!\n",
" return {\"split_value\": None, \"information_gain\": 0, \"is_numeric\": is_numeric}\n",
" for split_value in split_values:\n",
" mask = x_series < split_value if is_numeric else x_series.isin(split_value)\n",
" split_information_gain = get_information_gain(y_series, mask)\n",
" results.append({\n",
" \"value\": split_value,\n",
" \"information_gain\": split_information_gain,\n",
" })\n",
" results = sorted(results, key=lambda x: x[\"information_gain\"], reverse=True)\n",
" best_split_value = results[0][\"value\"]\n",
" best_information_gain = results[0][\"information_gain\"]\n",
" return {\"split_value\": best_split_value, \"information_gain\": best_information_gain, \"is_numeric\": is_numeric}"
]
},
{
"cell_type": "code",
"execution_count": 157,
"id": "b60045fd-90e2-4978-968f-5352ed62ec73",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'split_value': 103,\n",
" 'information_gain': 0.3824541370911896,\n",
" 'is_numeric': True}"
]
},
"execution_count": 157,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_max_information_gain(data[\"Weight\"], data[\"Obese\"])"
]
},
{
"cell_type": "code",
"execution_count": 158,
"id": "8d36f64a-0ac8-4556-a9a1-93b0150dca28",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Information Gain')"
]
},
"execution_count": 158,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"values = data[\"Weight\"].sort_values().unique()[1:]\n",
"igs = [get_information_gain(data[\"Weight\"], data[\"Weight\"] < value) for value in values]\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(values, igs)\n",
"ax.set_xlabel(\"Weight\")\n",
"ax.set_ylabel(\"Information Gain\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18037ae6-3a21-4fd6-b2d8-8a7fd089b118",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 159,
"id": "08c32adb-8615-4cda-b2a8-e355ea7af9e9",
"metadata": {},
"outputs": [],
"source": [
"# First row is the value which obtained the best information gain\n",
"# Second row is the information gain value\n",
"\n",
"# We can see the best thing to do in the above (in terms of maximizing information gain) is to split weight at 103\n",
"# This split would generate two DataFrames\n",
"# We would keep applying this recursively and create an entire decision tree\n",
"\n",
"# How do we decide when to stop splitting? Usually three metrics:\n",
"# max_depth: Maximum depth of the tree, if we leave it to None, then it will grow until all leaves are pure\n",
"# min_samples_split: The minimum number of observations in a split to keep creating new nodes\n",
"# min_information_gain: The minimum amount of IG for the tree to keep growing\n",
"\n",
"# Steps:\n",
"# Make sure min_samples_split and max_depth are met\n",
"# Make a split\n",
"# Make sure that min_information_gain is met\n",
"# Save the split and repeat"
]
},
{
"cell_type": "code",
"execution_count": 161,
"id": "bab774ed-4221-47c6-b40d-9d36727dc275",
"metadata": {},
"outputs": [],
"source": [
"def get_best_split(df, y):\n",
" column_max_information_gains = [\n",
" {\"column_name\": column, **get_max_information_gain(df[column], df[y])} for column in df.columns if column != y\n",
" ]\n",
" best_split_info = sorted(column_max_information_gains, key=lambda x: x[\"information_gain\"])[-1]\n",
" return best_split_info\n",
"\n",
"def make_split(df, split_info):\n",
" column_name = split_info[\"column_name\"]\n",
" split_value = split_info[\"split_value\"]\n",
" is_numeric = split_info[\"is_numeric\"]\n",
" assert is_numeric == df[column_name].dtype != \"object\"\n",
" if is_numeric:\n",
" mask = df[column_name] < split_value\n",
" else:\n",
" mask = df[column_name].isin(split_value)\n",
" df_left = df[mask]\n",
" df_right = df[~mask]\n",
" return df_left, df_right"
]
},
{
"cell_type": "code",
"execution_count": 162,
"id": "95c37e9f-4bcf-4952-adf2-7596ca996337",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'column_name': 'Weight',\n",
" 'split_value': 103,\n",
" 'information_gain': 0.3824541370911896,\n",
" 'is_numeric': True}"
]
},
"execution_count": 162,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"split_info = get_best_split(data, \"Obese\")\n",
"\n",
"split_info"
]
},
{
"cell_type": "code",
"execution_count": 163,
"id": "ab56bf34-954c-4945-ba4b-011e02f3df3d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(500, 229, 271)"
]
},
"execution_count": 163,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"left, right = make_split(data, split_info)\n",
"\n",
"len(data), len(left), len(right)"
]
},
{
"cell_type": "code",
"execution_count": 164,
"id": "b7c5a98d-1e93-4c65-814b-310918998588",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'column_name': 'Height',\n",
" 'split_value': 178,\n",
" 'information_gain': 0.28026630900174687,\n",
" 'is_numeric': True}"
]
},
"execution_count": 164,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"split_info = get_best_split(left, \"Obese\")\n",
"\n",
"split_info"
]
},
{
"cell_type": "code",
"execution_count": 165,
"id": "93f9aa35-eac3-4a21-8cc7-361b01dcd678",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'column_name': 'Weight',\n",
" 'split_value': 116,\n",
" 'information_gain': 0.09289094500737183,\n",
" 'is_numeric': True}"
]
},
"execution_count": 165,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"split_info = get_best_split(right, \"Obese\")\n",
"\n",
"split_info"
]
},
{
"cell_type": "code",
"execution_count": 232,
"id": "2efe3f4d-891c-463b-899c-80597cc15f9d",
"metadata": {},
"outputs": [],
"source": [
"class DecisionNode:\n",
" def __init__(self, column, split_value, is_numeric, left, right, prediction=None):\n",
" self.column = column\n",
" self.split_value = split_value\n",
" self.is_numeric = is_numeric\n",
" self.left = left\n",
" self.right = right\n",
" self.prediction = prediction\n",
"\n",
" def __repr__(self, depth=0, indent=\" \"):\n",
" prefix = depth * indent\n",
" # Representation for leaf node\n",
" if self.is_leaf():\n",
" return f\"{prefix}Leaf(prediction={self.prediction})\"\n",
" description = f\"{prefix}DecisionNode(column={self.column}, split_value={self.split_value}, is_numeric={self.is_numeric})\\n\"\n",
" if self.left:\n",
" description += f\"{prefix}left:\\n{self.left.__repr__(depth + 1, indent)}\\n\"\n",
" if self.right:\n",
" description += f\"{prefix}right:\\n{self.right.__repr__(depth + 1, indent)}\"\n",
" return description\n",
"\n",
" def is_leaf(self):\n",
" return self.left is None and self.right is None\n",
"\n",
"def build_tree(df, y, depth, max_depth, min_samples_split, min_information_gain, is_classification):\n",
"\n",
" print(\"depth:\", depth)\n",
" \n",
" if depth >= max_depth:\n",
" print(\"hit max depth\", depth)\n",
" return create_leaf_node(df, y, is_classification)\n",
"\n",
" if len(df) < min_samples_split:\n",
" print(\"hit min samples\", len(df))\n",
" return create_leaf_node(df, y, is_classification)\n",
"\n",
" print(\"getting best split\", len(df))\n",
" split_info = get_best_split(df, y)\n",
"\n",
" if split_info[\"information_gain\"] < min_information_gain:\n",
" print(\"hit min info gain\", split_info[\"information_gain\"])\n",
" return create_leaf_node(df, y, is_classification)\n",
"\n",
" print(f\"Size before split: {len(df)}\")\n",
" \n",
" print(\"splitting on:\", split_info)\n",
" df_left, df_right = make_split(df, split_info)\n",
"\n",
" print(f\"left size: {len(df_left)}, right size: {len(df_right)}\")\n",
" \n",
" subtree_left = build_tree(df_left, y, depth+1, max_depth, min_samples_split, min_information_gain, is_classification)\n",
" subtree_right = build_tree(df_right, y, depth+1, max_depth, min_samples_split, min_information_gain, is_classification)\n",
"\n",
" return DecisionNode(\n",
" column=split_info[\"column_name\"],\n",
" split_value=split_info[\"split_value\"],\n",
" is_numeric=split_info[\"is_numeric\"],\n",
" left=subtree_left,\n",
" right=subtree_right,\n",
" prediction=None,\n",
" )\n",
"\n",
"def create_leaf_node(df, y, is_classification=True):\n",
" if is_classification:\n",
" prediction = df[y].mode()[0]\n",
" else:\n",
" prediction = df[y].mean()\n",
"\n",
" return DecisionNode(\n",
" column=None,\n",
" split_value=None,\n",
" is_numeric=None,\n",
" left=None,\n",
" right=None,\n",
" prediction=prediction,\n",
" )\n",
"\n",
"def train_decision_tree(df, y, max_depth=10, min_samples_split=2, min_information_gain=0.01, is_classification=True):\n",
" return build_tree(df, y, 0, max_depth, min_samples_split, min_information_gain, is_classification)"
]
},
{
"cell_type": "code",
"execution_count": 233,
"id": "2d65d461-a722-4be1-b949-57490c926f0d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"depth: 0\n",
"getting best split 500\n",
"Size before split: 500\n",
"splitting on: {'column_name': 'Weight', 'split_value': 103, 'information_gain': 0.3824541370911896, 'is_numeric': True}\n",
"left size: 229, right size: 271\n",
"depth: 1\n",
"getting best split 229\n",
"Size before split: 229\n",
"splitting on: {'column_name': 'Height', 'split_value': 178, 'information_gain': 0.28026630900174687, 'is_numeric': True}\n",
"left size: 138, right size: 91\n",
"depth: 2\n",
"getting best split 138\n",
"Size before split: 138\n",
"splitting on: {'column_name': 'Weight', 'split_value': 66, 'information_gain': 0.3905984031684069, 'is_numeric': True}\n",
"left size: 41, right size: 97\n",
"depth: 3\n",
"getting best split 41\n",
"hit min info gain 2.5849394142282115e-26\n",
"depth: 3\n",
"getting best split 97\n",
"Size before split: 97\n",
"splitting on: {'column_name': 'Height', 'split_value': 151, 'information_gain': 0.16796853890498697, 'is_numeric': True}\n",
"left size: 35, right size: 62\n",
"depth: 4\n",
"getting best split 35\n",
"Size before split: 35\n",
"splitting on: {'column_name': 'Weight', 'split_value': 67, 'information_gain': 0.13003339959432547, 'is_numeric': True}\n",
"left size: 2, right size: 33\n",
"depth: 5\n",
"getting best split 2\n",
"Size before split: 2\n",
"splitting on: {'column_name': 'Height', 'split_value': 149, 'information_gain': 0.9999999998557304, 'is_numeric': True}\n",
"left size: 1, right size: 1\n",
"depth: 6\n",
"hit min samples 1\n",
"depth: 6\n",
"hit min samples 1\n",
"depth: 5\n",
"getting best split 33\n",
"hit min info gain 0.0\n",
"depth: 4\n",
"getting best split 62\n",
"Size before split: 62\n",
"splitting on: {'column_name': 'Weight', 'split_value': 82, 'information_gain': 0.31434467087549045, 'is_numeric': True}\n",
"left size: 26, right size: 36\n",
"depth: 5\n",
"getting best split 26\n",
"Size before split: 26\n",
"splitting on: {'column_name': 'Height', 'split_value': 161, 'information_gain': 0.3216587044834349, 'is_numeric': True}\n",
"left size: 10, right size: 16\n",
"depth: 6\n",
"getting best split 10\n",
"Size before split: 10\n",
"splitting on: {'column_name': 'Weight', 'split_value': 74, 'information_gain': 0.6099865469532797, 'is_numeric': True}\n",
"left size: 4, right size: 6\n",
"depth: 7\n",
"getting best split 4\n",
"hit min info gain 0.0\n",
"depth: 7\n",
"getting best split 6\n",
"Size before split: 6\n",
"splitting on: {'column_name': 'Height', 'split_value': 154, 'information_gain': 0.3166890882188412, 'is_numeric': True}\n",
"left size: 2, right size: 4\n",
"depth: 8\n",
"getting best split 2\n",
"Size before split: 2\n",
"splitting on: {'column_name': 'Weight', 'split_value': 78, 'information_gain': 0.9999999998557304, 'is_numeric': True}\n",
"left size: 1, right size: 1\n",
"depth: 9\n",
"hit min samples 1\n",
"depth: 9\n",
"hit min samples 1\n",
"depth: 8\n",
"getting best split 4\n",
"hit min info gain 0.0\n",
"depth: 6\n",
"getting best split 16\n",
"hit min info gain 0.0\n",
"depth: 5\n",
"getting best split 36\n",
"Size before split: 36\n",
"splitting on: {'column_name': 'Height', 'split_value': 173, 'information_gain': 0.23084979872365263, 'is_numeric': True}\n",
"left size: 27, right size: 9\n",
"depth: 6\n",
"getting best split 27\n",
"hit min info gain 0.06748201253360614\n",
"depth: 6\n",
"getting best split 9\n",
"Size before split: 9\n",
"splitting on: {'column_name': 'Weight', 'split_value': 95, 'information_gain': 0.9910760596939526, 'is_numeric': True}\n",
"left size: 5, right size: 4\n",
"depth: 7\n",
"getting best split 5\n",
"hit min info gain 0.0\n",
"depth: 7\n",
"getting best split 4\n",
"hit min info gain 0.0\n",
"depth: 2\n",
"getting best split 91\n",
"hit min info gain 2.5849394142282115e-26\n",
"depth: 1\n",
"getting best split 271\n",
"hit min info gain 0.09289094500737183\n"
]
}
],
"source": [
"y = \"Obese\"\n",
"max_depth = 10\n",
"min_samples_split = 2\n",
"min_information_gain = 0.1\n",
"is_classification = True\n",
"\n",
"tree = train_decision_tree(data, y, max_depth, min_samples_split, min_information_gain, is_classification=True)"
]
},
{
"cell_type": "code",
"execution_count": 234,
"id": "47d36982-6371-4355-92e3-e4b70a5a97d1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DecisionNode(column=Weight, split_value=103, is_numeric=True)\n",
"left:\n",
" DecisionNode(column=Height, split_value=178, is_numeric=True)\n",
" left:\n",
" DecisionNode(column=Weight, split_value=66, is_numeric=True)\n",
" left:\n",
" Leaf(prediction=0)\n",
" right:\n",
" DecisionNode(column=Height, split_value=151, is_numeric=True)\n",
" left:\n",
" DecisionNode(column=Weight, split_value=67, is_numeric=True)\n",
" left:\n",
" DecisionNode(column=Height, split_value=149, is_numeric=True)\n",
" left:\n",
" Leaf(prediction=1)\n",
" right:\n",
" Leaf(prediction=0)\n",
" right:\n",
" Leaf(prediction=1)\n",
" right:\n",
" DecisionNode(column=Weight, split_value=82, is_numeric=True)\n",
" left:\n",
" DecisionNode(column=Height, split_value=161, is_numeric=True)\n",
" left:\n",
" DecisionNode(column=Weight, split_value=74, is_numeric=True)\n",
" left:\n",
" Leaf(prediction=0)\n",
" right:\n",
" DecisionNode(column=Height, split_value=154, is_numeric=True)\n",
" left:\n",
" DecisionNode(column=Weight, split_value=78, is_numeric=True)\n",
" left:\n",
" Leaf(prediction=1)\n",
" right:\n",
" Leaf(prediction=0)\n",
" right:\n",
" Leaf(prediction=1)\n",
" right:\n",
" Leaf(prediction=0)\n",
" right:\n",
" DecisionNode(column=Height, split_value=173, is_numeric=True)\n",
" left:\n",
" Leaf(prediction=1)\n",
" right:\n",
" DecisionNode(column=Weight, split_value=95, is_numeric=True)\n",
" left:\n",
" Leaf(prediction=0)\n",
" right:\n",
" Leaf(prediction=1)\n",
" right:\n",
" Leaf(prediction=0)\n",
"right:\n",
" Leaf(prediction=1)"
]
},
"execution_count": 234,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tree"
]
},
{
"cell_type": "code",
"execution_count": 235,
"id": "fdf0165f-8568-4241-bb2e-2a7f76f85114",
"metadata": {},
"outputs": [],
"source": [
"def _predict(tree, example):\n",
" if tree.is_leaf():\n",
" return tree.prediction\n",
" if tree.is_numeric:\n",
" if example[tree.column] < tree.split_value:\n",
" return _predict(tree.left, example)\n",
" else:\n",
" return _predict(tree.right, example)\n",
" else:\n",
" if example[tree.column] in tree.split_value:\n",
" return _predict(tree.left, example)\n",
" else:\n",
" return _predict(tree.right, example)\n",
"\n",
"def predict(tree, df):\n",
" return df.apply(lambda row: _predict(tree, row), axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 236,
"id": "9216abe0-da28-4218-9edf-1d47aef2455f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Gender Male\n",
"Height 174\n",
"Weight 96\n",
"Obese 1\n",
"Name: 0, dtype: object"
]
},
"execution_count": 236,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example = data.iloc[0]\n",
"\n",
"example"
]
},
{
"cell_type": "code",
"execution_count": 237,
"id": "aa2f778c-1610-4d84-84ec-9917feb9b128",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 237,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_predict(tree, example)"
]
},
{
"cell_type": "code",
"execution_count": 238,
"id": "8550a121-307f-4226-9e4a-b045c34394ee",
"metadata": {},
"outputs": [],
"source": [
"predictions = predict(tree, data)"
]
},
{
"cell_type": "code",
"execution_count": 239,
"id": "b453fabd-5f84-44b9-ac13-7b6f6394c032",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 1\n",
"1 0\n",
"2 1\n",
"3 1\n",
"4 0\n",
" ..\n",
"495 1\n",
"496 1\n",
"497 1\n",
"498 1\n",
"499 1\n",
"Length: 500, dtype: int64"
]
},
"execution_count": 239,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions"
]
},
{
"cell_type": "code",
"execution_count": 207,
"id": "d2926d7e-ae58-42a0-bb14-e6da06f7840f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 True\n",
"1 True\n",
"2 False\n",
"3 False\n",
"4 True\n",
" ... \n",
"495 False\n",
"496 False\n",
"497 False\n",
"498 True\n",
"499 False\n",
"Length: 500, dtype: bool"
]
},
"execution_count": 207,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictions == data[\"Obese\"]"
]
},
{
"cell_type": "code",
"execution_count": 240,
"id": "c8b540c3-7f17-42ff-b01a-30f8918d59d3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"depth: 0\n",
"getting best split 5\n",
"Size before split: 5\n",
"splitting on: {'column_name': 'name', 'split_value': ('d', 'b'), 'information_gain': 0.9709505943103991, 'is_numeric': False}\n",
"left size: 2, right size: 3\n",
"depth: 1\n",
"getting best split 2\n",
"hit min info gain 0.0\n",
"depth: 1\n",
"getting best split 3\n",
"hit min info gain 0\n"
]
}
],
"source": [
"new_tree = train_decision_tree(new_data, \"Obese\")"
]
},
{
"cell_type": "code",
"execution_count": 242,
"id": "379aa34a-566a-4b34-a3fa-108488b7b5ff",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 242,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_predict(new_tree, new_data.iloc[0])"
]
},
{
"cell_type": "code",
"execution_count": 243,
"id": "981dc982-88de-4f50-bbe9-5e06e9a60a90",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>Obese</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>a</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>b</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>c</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>d</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>e</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name Obese\n",
"0 a 0\n",
"1 b 1\n",
"2 c 0\n",
"3 d 1\n",
"4 e 0"
]
},
"execution_count": 243,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"new_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29b65fc1-5390-44de-91bd-bf7f675cbacd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment