Code for blog post: https://welivein.space/machine-learning/math/python/2017/10/16/decision_titanic.html
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# The imports we are using\n", | |
"import pandas as pd\n", | |
"import graphviz\n", | |
"from sklearn import tree" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Get the data. Which is in our case very easy, because it's from kaggle: https://www.kaggle.com/c/titanic/data\n", | |
"data_raw = pd.read_csv('train.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>PassengerId</th>\n", | |
" <th>Survived</th>\n", | |
" <th>Pclass</th>\n", | |
" <th>Name</th>\n", | |
" <th>Sex</th>\n", | |
" <th>Age</th>\n", | |
" <th>SibSp</th>\n", | |
" <th>Parch</th>\n", | |
" <th>Ticket</th>\n", | |
" <th>Fare</th>\n", | |
" <th>Cabin</th>\n", | |
" <th>Embarked</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>Braund, Mr. Owen Harris</td>\n", | |
" <td>male</td>\n", | |
" <td>22.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>A/5 21171</td>\n", | |
" <td>7.2500</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>2</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n", | |
" <td>female</td>\n", | |
" <td>38.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>PC 17599</td>\n", | |
" <td>71.2833</td>\n", | |
" <td>C85</td>\n", | |
" <td>C</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>3</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>Heikkinen, Miss. Laina</td>\n", | |
" <td>female</td>\n", | |
" <td>26.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>STON/O2. 3101282</td>\n", | |
" <td>7.9250</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>4</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n", | |
" <td>female</td>\n", | |
" <td>35.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>113803</td>\n", | |
" <td>53.1000</td>\n", | |
" <td>C123</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>5</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>Allen, Mr. William Henry</td>\n", | |
" <td>male</td>\n", | |
" <td>35.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>373450</td>\n", | |
" <td>8.0500</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>6</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>Moran, Mr. James</td>\n", | |
" <td>male</td>\n", | |
" <td>NaN</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>330877</td>\n", | |
" <td>8.4583</td>\n", | |
" <td>NaN</td>\n", | |
" <td>Q</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>7</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>McCarthy, Mr. Timothy J</td>\n", | |
" <td>male</td>\n", | |
" <td>54.0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>17463</td>\n", | |
" <td>51.8625</td>\n", | |
" <td>E46</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>8</td>\n", | |
" <td>0</td>\n", | |
" <td>3</td>\n", | |
" <td>Palsson, Master. Gosta Leonard</td>\n", | |
" <td>male</td>\n", | |
" <td>2.0</td>\n", | |
" <td>3</td>\n", | |
" <td>1</td>\n", | |
" <td>349909</td>\n", | |
" <td>21.0750</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>9</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)</td>\n", | |
" <td>female</td>\n", | |
" <td>27.0</td>\n", | |
" <td>0</td>\n", | |
" <td>2</td>\n", | |
" <td>347742</td>\n", | |
" <td>11.1333</td>\n", | |
" <td>NaN</td>\n", | |
" <td>S</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>10</td>\n", | |
" <td>1</td>\n", | |
" <td>2</td>\n", | |
" <td>Nasser, Mrs. Nicholas (Adele Achem)</td>\n", | |
" <td>female</td>\n", | |
" <td>14.0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>237736</td>\n", | |
" <td>30.0708</td>\n", | |
" <td>NaN</td>\n", | |
" <td>C</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" PassengerId Survived Pclass \\\n", | |
"0 1 0 3 \n", | |
"1 2 1 1 \n", | |
"2 3 1 3 \n", | |
"3 4 1 1 \n", | |
"4 5 0 3 \n", | |
"5 6 0 3 \n", | |
"6 7 0 1 \n", | |
"7 8 0 3 \n", | |
"8 9 1 3 \n", | |
"9 10 1 2 \n", | |
"\n", | |
" Name Sex Age SibSp \\\n", | |
"0 Braund, Mr. Owen Harris male 22.0 1 \n", | |
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", | |
"2 Heikkinen, Miss. Laina female 26.0 0 \n", | |
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", | |
"4 Allen, Mr. William Henry male 35.0 0 \n", | |
"5 Moran, Mr. James male NaN 0 \n", | |
"6 McCarthy, Mr. Timothy J male 54.0 0 \n", | |
"7 Palsson, Master. Gosta Leonard male 2.0 3 \n", | |
"8 Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) female 27.0 0 \n", | |
"9 Nasser, Mrs. Nicholas (Adele Achem) female 14.0 1 \n", | |
"\n", | |
" Parch Ticket Fare Cabin Embarked \n", | |
"0 0 A/5 21171 7.2500 NaN S \n", | |
"1 0 PC 17599 71.2833 C85 C \n", | |
"2 0 STON/O2. 3101282 7.9250 NaN S \n", | |
"3 0 113803 53.1000 C123 S \n", | |
"4 0 373450 8.0500 NaN S \n", | |
"5 0 330877 8.4583 NaN Q \n", | |
"6 0 17463 51.8625 E46 S \n", | |
"7 1 349909 21.0750 NaN S \n", | |
"8 2 347742 11.1333 NaN S \n", | |
"9 0 237736 30.0708 NaN C " | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Lets have a look at it\n", | |
"# We use the iloc method so see the first 10 entries\n", | |
"data_raw.iloc[:10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Next we need to convert the 'sex' column to binary values.\n", | |
"# For that we use an if/then clause within list comprehension\n", | |
"data_raw['SexNumerical'] = [1 if x=='male' else 0 for x in data_raw['Sex']]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PassengerId 0.005007\n", | |
"SibSp 0.035322\n", | |
"Age 0.077221\n", | |
"Parch 0.081629\n", | |
"Fare 0.257307\n", | |
"Pclass 0.338481\n", | |
"SexNumerical 0.543351\n", | |
"Survived 1.000000\n", | |
"Name: Survived, dtype: float64" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# a good idea for predicting survival is look at the correlations\n", | |
"# with the 'Survived' feature. The stronger the correlation the more\n", | |
"# the corresponding feature explains the survival.\n", | |
"data_raw.corr()['Survived'].abs().sort_values()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"233 women and 109 men survived\n" | |
] | |
} | |
], | |
"source": [ | |
"print('{:.0f} women and {:.0f} men survived'.format(data_raw.where(data_raw['Sex']=='female')['Survived'].sum(),data_raw.where(data_raw['Sex']=='male')['Survived'].sum()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PassengerId 0.008790\n", | |
"Age 0.116109\n", | |
"Fare 0.218466\n", | |
"Parch 0.223644\n", | |
"SibSp 0.263284\n", | |
"Pclass 0.477114\n", | |
"Survived 1.000000\n", | |
"SexNumerical NaN\n", | |
"Name: Survived, dtype: float64" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Since the sex has such a strong correlation we split the dataset\n", | |
"# and look at all correlations for female and male passengers.\n", | |
"data_raw.where(data_raw['Sex']=='female').corr()['Survived'].abs().sort_values()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"SibSp 0.020238\n", | |
"PassengerId 0.040477\n", | |
"Parch 0.096318\n", | |
"Age 0.119618\n", | |
"Fare 0.171288\n", | |
"Pclass 0.220618\n", | |
"Survived 1.000000\n", | |
"SexNumerical NaN\n", | |
"Name: Survived, dtype: float64" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"data_raw.where(data_raw['Sex']=='male').corr()['Survived'].abs().sort_values()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# We want to use only four classes: 'sex', 'age', 'Pclass' and 'survived'.\n", | |
"# Since some of them may contain NaNs, we need to filter these rows out\n", | |
"data_raw = data_raw.dropna(axis=0, how='any', subset = ['Sex', 'Age', 'Pclass','Survived'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Now we split the data in a training set and an evaluation set.\n", | |
"# we use 90% of all data for the training and 10% for the evaluation\n", | |
"idx = int(data_raw.shape[0]*.9)\n", | |
"train_data = data_raw.iloc[:idx].copy()\n", | |
"eval_data = data_raw.iloc[idx:].copy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# We now prepare our three feature columns for the decision tree routine\n", | |
"# by stacking them all together\n", | |
"samples = pd.concat([train_data['SexNumerical'], train_data['Pclass'], train_data['Age']],axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Now to the fitting part. Define the classifier and fit it to the data\n", | |
"clf = tree.DecisionTreeClassifier(criterion='gini', max_depth=3, max_leaf_nodes=7)\n", | |
"clf = clf.fit(samples, train_data['Survived'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/svg+xml": [ | |
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
"<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n", | |
" -->\n", | |
"<!-- Title: Tree Pages: 1 -->\n", | |
"<svg width=\"756pt\" height=\"552pt\"\n", | |
" viewBox=\"0.00 0.00 756.00 552.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 548)\">\n", | |
"<title>Tree</title>\n", | |
"<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-548 752,-548 752,4 -4,4\"/>\n", | |
"<!-- 0 -->\n", | |
"<g id=\"node1\" class=\"node\">\n", | |
"<title>0</title>\n", | |
"<path fill=\"#e58139\" fill-opacity=\"0.305882\" stroke=\"#000000\" d=\"M357.5,-544C357.5,-544 232.5,-544 232.5,-544 226.5,-544 220.5,-538 220.5,-532 220.5,-532 220.5,-473 220.5,-473 220.5,-467 226.5,-461 232.5,-461 232.5,-461 357.5,-461 357.5,-461 363.5,-461 369.5,-467 369.5,-473 369.5,-473 369.5,-532 369.5,-532 369.5,-538 363.5,-544 357.5,-544\"/>\n", | |
"<text text-anchor=\"start\" x=\"260\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Sex ≤ 0.5</text>\n", | |
"<text text-anchor=\"start\" x=\"251\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.484</text>\n", | |
"<text text-anchor=\"start\" x=\"242\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 642</text>\n", | |
"<text text-anchor=\"start\" x=\"228.5\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [379, 263]</text>\n", | |
"<text text-anchor=\"start\" x=\"251\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n", | |
"</g>\n", | |
"<!-- 1 -->\n", | |
"<g id=\"node2\" class=\"node\">\n", | |
"<title>1</title>\n", | |
"<path fill=\"#399de5\" fill-opacity=\"0.678431\" stroke=\"#000000\" d=\"M274,-425C274,-425 158,-425 158,-425 152,-425 146,-419 146,-413 146,-413 146,-354 146,-354 146,-348 152,-342 158,-342 158,-342 274,-342 274,-342 280,-342 286,-348 286,-354 286,-354 286,-413 286,-413 286,-419 280,-425 274,-425\"/>\n", | |
"<text text-anchor=\"start\" x=\"175.5\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Class ≤ 2.5</text>\n", | |
"<text text-anchor=\"start\" x=\"172\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.369</text>\n", | |
"<text text-anchor=\"start\" x=\"163\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 234</text>\n", | |
"<text text-anchor=\"start\" x=\"154\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [57, 177]</text>\n", | |
"<text text-anchor=\"start\" x=\"157\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n", | |
"</g>\n", | |
"<!-- 0->1 -->\n", | |
"<g id=\"edge1\" class=\"edge\">\n", | |
"<title>0->1</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M267.3696,-460.8796C261.57,-452.1434 255.3941,-442.8404 249.4092,-433.8253\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"252.1951,-431.6935 243.7483,-425.2981 246.3632,-435.5652 252.1951,-431.6935\"/>\n", | |
"<text text-anchor=\"middle\" x=\"238.7978\" y=\"-446.103\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n", | |
"</g>\n", | |
"<!-- 2 -->\n", | |
"<g id=\"node5\" class=\"node\">\n", | |
"<title>2</title>\n", | |
"<path fill=\"#e58139\" fill-opacity=\"0.733333\" stroke=\"#000000\" d=\"M432,-425C432,-425 316,-425 316,-425 310,-425 304,-419 304,-413 304,-413 304,-354 304,-354 304,-348 310,-342 316,-342 316,-342 432,-342 432,-342 438,-342 444,-348 444,-354 444,-354 444,-413 444,-413 444,-419 438,-425 432,-425\"/>\n", | |
"<text text-anchor=\"start\" x=\"334\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Age ≤ 13.0</text>\n", | |
"<text text-anchor=\"start\" x=\"330\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.333</text>\n", | |
"<text text-anchor=\"start\" x=\"321\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 408</text>\n", | |
"<text text-anchor=\"start\" x=\"312\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [322, 86]</text>\n", | |
"<text text-anchor=\"start\" x=\"330\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n", | |
"</g>\n", | |
"<!-- 0->2 -->\n", | |
"<g id=\"edge4\" class=\"edge\">\n", | |
"<title>0->2</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M322.6304,-460.8796C328.43,-452.1434 334.6059,-442.8404 340.5908,-433.8253\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"343.6368,-435.5652 346.2517,-425.2981 337.8049,-431.6935 343.6368,-435.5652\"/>\n", | |
"<text text-anchor=\"middle\" x=\"351.2022\" y=\"-446.103\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">False</text>\n", | |
"</g>\n", | |
"<!-- 3 -->\n", | |
"<g id=\"node3\" class=\"node\">\n", | |
"<title>3</title>\n", | |
"<path fill=\"#399de5\" fill-opacity=\"0.941176\" stroke=\"#000000\" d=\"M122,-298.5C122,-298.5 12,-298.5 12,-298.5 6,-298.5 0,-292.5 0,-286.5 0,-286.5 0,-242.5 0,-242.5 0,-236.5 6,-230.5 12,-230.5 12,-230.5 122,-230.5 122,-230.5 128,-230.5 134,-236.5 134,-242.5 134,-242.5 134,-286.5 134,-286.5 134,-292.5 128,-298.5 122,-298.5\"/>\n", | |
"<text text-anchor=\"start\" x=\"23\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.106</text>\n", | |
"<text text-anchor=\"start\" x=\"14\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 143</text>\n", | |
"<text text-anchor=\"start\" x=\"9.5\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [8, 135]</text>\n", | |
"<text text-anchor=\"start\" x=\"8\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n", | |
"</g>\n", | |
"<!-- 1->3 -->\n", | |
"<g id=\"edge2\" class=\"edge\">\n", | |
"<title>1->3</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M163.887,-341.8796C149.0174,-330.0038 132.836,-317.0804 118.0317,-305.2568\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"119.9638,-302.3207 109.9658,-298.8149 115.5954,-307.7904 119.9638,-302.3207\"/>\n", | |
"</g>\n", | |
"<!-- 4 -->\n", | |
"<g id=\"node4\" class=\"node\">\n", | |
"<title>4</title>\n", | |
"<path fill=\"#e58139\" fill-opacity=\"0.141176\" stroke=\"#000000\" d=\"M271.5,-298.5C271.5,-298.5 164.5,-298.5 164.5,-298.5 158.5,-298.5 152.5,-292.5 152.5,-286.5 152.5,-286.5 152.5,-242.5 152.5,-242.5 152.5,-236.5 158.5,-230.5 164.5,-230.5 164.5,-230.5 271.5,-230.5 271.5,-230.5 277.5,-230.5 283.5,-236.5 283.5,-242.5 283.5,-242.5 283.5,-286.5 283.5,-286.5 283.5,-292.5 277.5,-298.5 271.5,-298.5\"/>\n", | |
"<text text-anchor=\"start\" x=\"174\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.497</text>\n", | |
"<text text-anchor=\"start\" x=\"169.5\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 91</text>\n", | |
"<text text-anchor=\"start\" x=\"160.5\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [49, 42]</text>\n", | |
"<text text-anchor=\"start\" x=\"174\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n", | |
"</g>\n", | |
"<!-- 1->4 -->\n", | |
"<g id=\"edge3\" class=\"edge\">\n", | |
"<title>1->4</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M216.6995,-341.8796C216.8788,-331.2134 217.0722,-319.7021 217.2538,-308.9015\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"220.7546,-308.8724 217.4233,-298.8149 213.7556,-308.7547 220.7546,-308.8724\"/>\n", | |
"</g>\n", | |
"<!-- 5 -->\n", | |
"<g id=\"node6\" class=\"node\">\n", | |
"<title>5</title>\n", | |
"<path fill=\"#399de5\" fill-opacity=\"0.278431\" stroke=\"#000000\" d=\"M427,-306C427,-306 317,-306 317,-306 311,-306 305,-300 305,-294 305,-294 305,-235 305,-235 305,-229 311,-223 317,-223 317,-223 427,-223 427,-223 433,-223 439,-229 439,-235 439,-235 439,-294 439,-294 439,-300 433,-306 427,-306\"/>\n", | |
"<text text-anchor=\"start\" x=\"331.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Class ≤ 2.5</text>\n", | |
"<text text-anchor=\"start\" x=\"328\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.487</text>\n", | |
"<text text-anchor=\"start\" x=\"323.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 31</text>\n", | |
"<text text-anchor=\"start\" x=\"314.5\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [13, 18]</text>\n", | |
"<text text-anchor=\"start\" x=\"313\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n", | |
"</g>\n", | |
"<!-- 2->5 -->\n", | |
"<g id=\"edge5\" class=\"edge\">\n", | |
"<title>2->5</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M373.3005,-341.8796C373.1628,-333.6838 373.0166,-324.9891 372.874,-316.5013\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"376.3701,-316.2378 372.7025,-306.2981 369.3711,-316.3555 376.3701,-316.2378\"/>\n", | |
"</g>\n", | |
"<!-- 6 -->\n", | |
"<g id=\"node9\" class=\"node\">\n", | |
"<title>6</title>\n", | |
"<path fill=\"#e58139\" fill-opacity=\"0.780392\" stroke=\"#000000\" d=\"M585,-306C585,-306 469,-306 469,-306 463,-306 457,-300 457,-294 457,-294 457,-235 457,-235 457,-229 463,-223 469,-223 469,-223 585,-223 585,-223 591,-223 597,-229 597,-235 597,-235 597,-294 597,-294 597,-300 591,-306 585,-306\"/>\n", | |
"<text text-anchor=\"start\" x=\"486.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Class ≤ 1.5</text>\n", | |
"<text text-anchor=\"start\" x=\"483\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.296</text>\n", | |
"<text text-anchor=\"start\" x=\"474\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 377</text>\n", | |
"<text text-anchor=\"start\" x=\"465\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [309, 68]</text>\n", | |
"<text text-anchor=\"start\" x=\"483\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n", | |
"</g>\n", | |
"<!-- 2->6 -->\n", | |
"<g id=\"edge8\" class=\"edge\">\n", | |
"<title>2->6</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M427.512,-341.8796C439.7412,-332.368 452.8344,-322.1843 465.3732,-312.432\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"467.7735,-314.9991 473.5182,-306.0969 463.4759,-309.4736 467.7735,-314.9991\"/>\n", | |
"</g>\n", | |
"<!-- 9 -->\n", | |
"<g id=\"node7\" class=\"node\">\n", | |
"<title>9</title>\n", | |
"<path fill=\"#399de5\" stroke=\"#000000\" d=\"M283,-179.5C283,-179.5 173,-179.5 173,-179.5 167,-179.5 161,-173.5 161,-167.5 161,-167.5 161,-123.5 161,-123.5 161,-117.5 167,-111.5 173,-111.5 173,-111.5 283,-111.5 283,-111.5 289,-111.5 295,-117.5 295,-123.5 295,-123.5 295,-167.5 295,-167.5 295,-173.5 289,-179.5 283,-179.5\"/>\n", | |
"<text text-anchor=\"start\" x=\"193\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n", | |
"<text text-anchor=\"start\" x=\"179.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 10</text>\n", | |
"<text text-anchor=\"start\" x=\"175\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 10]</text>\n", | |
"<text text-anchor=\"start\" x=\"169\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n", | |
"</g>\n", | |
"<!-- 5->9 -->\n", | |
"<g id=\"edge6\" class=\"edge\">\n", | |
"<title>5->9</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M321.6358,-222.8796C307.2651,-211.0038 291.6267,-198.0804 277.3192,-186.2568\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"279.462,-183.4872 269.524,-179.8149 275.0029,-188.8831 279.462,-183.4872\"/>\n", | |
"</g>\n", | |
"<!-- 10 -->\n", | |
"<g id=\"node8\" class=\"node\">\n", | |
"<title>10</title>\n", | |
"<path fill=\"#e58139\" fill-opacity=\"0.384314\" stroke=\"#000000\" d=\"M423,-179.5C423,-179.5 325,-179.5 325,-179.5 319,-179.5 313,-173.5 313,-167.5 313,-167.5 313,-123.5 313,-123.5 313,-117.5 319,-111.5 325,-111.5 325,-111.5 423,-111.5 423,-111.5 429,-111.5 435,-117.5 435,-123.5 435,-123.5 435,-167.5 435,-167.5 435,-173.5 429,-179.5 423,-179.5\"/>\n", | |
"<text text-anchor=\"start\" x=\"330\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.472</text>\n", | |
"<text text-anchor=\"start\" x=\"325.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 21</text>\n", | |
"<text text-anchor=\"start\" x=\"321\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [13, 8]</text>\n", | |
"<text text-anchor=\"start\" x=\"330\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n", | |
"</g>\n", | |
"<!-- 5->10 -->\n", | |
"<g id=\"edge7\" class=\"edge\">\n", | |
"<title>5->10</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M372.6995,-222.8796C372.8788,-212.2134 373.0722,-200.7021 373.2538,-189.9015\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"376.7546,-189.8724 373.4233,-179.8149 369.7556,-189.7547 376.7546,-189.8724\"/>\n", | |
"</g>\n", | |
"<!-- 7 -->\n", | |
"<g id=\"node10\" class=\"node\">\n", | |
"<title>7</title>\n", | |
"<path fill=\"#e58139\" fill-opacity=\"0.384314\" stroke=\"#000000\" d=\"M577.5,-187C577.5,-187 470.5,-187 470.5,-187 464.5,-187 458.5,-181 458.5,-175 458.5,-175 458.5,-116 458.5,-116 458.5,-110 464.5,-104 470.5,-104 470.5,-104 577.5,-104 577.5,-104 583.5,-104 589.5,-110 589.5,-116 589.5,-116 589.5,-175 589.5,-175 589.5,-181 583.5,-187 577.5,-187\"/>\n", | |
"<text text-anchor=\"start\" x=\"484\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Age ≤ 43.0</text>\n", | |
"<text text-anchor=\"start\" x=\"480\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.471</text>\n", | |
"<text text-anchor=\"start\" x=\"475.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 92</text>\n", | |
"<text text-anchor=\"start\" x=\"466.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [57, 35]</text>\n", | |
"<text text-anchor=\"start\" x=\"480\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n", | |
"</g>\n", | |
"<!-- 6->7 -->\n", | |
"<g id=\"edge9\" class=\"edge\">\n", | |
"<title>6->7</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M525.9507,-222.8796C525.7441,-214.6838 525.5249,-205.9891 525.311,-197.5013\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"528.8047,-197.2067 525.0537,-187.2981 521.807,-197.3831 528.8047,-197.2067\"/>\n", | |
"</g>\n", | |
"<!-- 8 -->\n", | |
"<g id=\"node13\" class=\"node\">\n", | |
"<title>8</title>\n", | |
"<path fill=\"#e58139\" fill-opacity=\"0.870588\" stroke=\"#000000\" d=\"M736,-179.5C736,-179.5 620,-179.5 620,-179.5 614,-179.5 608,-173.5 608,-167.5 608,-167.5 608,-123.5 608,-123.5 608,-117.5 614,-111.5 620,-111.5 620,-111.5 736,-111.5 736,-111.5 742,-111.5 748,-117.5 748,-123.5 748,-123.5 748,-167.5 748,-167.5 748,-173.5 742,-179.5 736,-179.5\"/>\n", | |
"<text text-anchor=\"start\" x=\"634\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.205</text>\n", | |
"<text text-anchor=\"start\" x=\"625\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 285</text>\n", | |
"<text text-anchor=\"start\" x=\"616\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [252, 33]</text>\n", | |
"<text text-anchor=\"start\" x=\"634\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n", | |
"</g>\n", | |
"<!-- 6->8 -->\n", | |
"<g id=\"edge12\" class=\"edge\">\n", | |
"<title>6->8</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M579.8125,-222.8796C594.8817,-211.0038 611.2803,-198.0804 626.2834,-186.2568\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"628.7698,-188.7536 634.4575,-179.8149 624.4369,-183.2557 628.7698,-188.7536\"/>\n", | |
"</g>\n", | |
"<!-- 11 -->\n", | |
"<g id=\"node11\" class=\"node\">\n", | |
"<title>11</title>\n", | |
"<path fill=\"#399de5\" fill-opacity=\"0.082353\" stroke=\"#000000\" d=\"M503,-68C503,-68 393,-68 393,-68 387,-68 381,-62 381,-56 381,-56 381,-12 381,-12 381,-6 387,0 393,0 393,0 503,0 503,0 509,0 515,-6 515,-12 515,-12 515,-56 515,-56 515,-62 509,-68 503,-68\"/>\n", | |
"<text text-anchor=\"start\" x=\"404\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.499</text>\n", | |
"<text text-anchor=\"start\" x=\"399.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 46</text>\n", | |
"<text text-anchor=\"start\" x=\"390.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [22, 24]</text>\n", | |
"<text text-anchor=\"start\" x=\"389\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n", | |
"</g>\n", | |
"<!-- 7->11 -->\n", | |
"<g id=\"edge10\" class=\"edge\">\n", | |
"<title>7->11</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M495.7004,-103.9815C489.6238,-95.0666 483.1926,-85.6313 477.0868,-76.6734\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"479.8789,-74.5555 471.3546,-68.2637 474.0948,-78.498 479.8789,-74.5555\"/>\n", | |
"</g>\n", | |
"<!-- 12 -->\n", | |
"<g id=\"node12\" class=\"node\">\n", | |
"<title>12</title>\n", | |
"<path fill=\"#e58139\" fill-opacity=\"0.686275\" stroke=\"#000000\" d=\"M652.5,-68C652.5,-68 545.5,-68 545.5,-68 539.5,-68 533.5,-62 533.5,-56 533.5,-56 533.5,-12 533.5,-12 533.5,-6 539.5,0 545.5,0 545.5,0 652.5,0 652.5,0 658.5,0 664.5,-6 664.5,-12 664.5,-12 664.5,-56 664.5,-56 664.5,-62 658.5,-68 652.5,-68\"/>\n", | |
"<text text-anchor=\"start\" x=\"555\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.364</text>\n", | |
"<text text-anchor=\"start\" x=\"550.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 46</text>\n", | |
"<text text-anchor=\"start\" x=\"541.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [35, 11]</text>\n", | |
"<text text-anchor=\"start\" x=\"555\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n", | |
"</g>\n", | |
"<!-- 7->12 -->\n", | |
"<g id=\"edge11\" class=\"edge\">\n", | |
"<title>7->12</title>\n", | |
"<path fill=\"none\" stroke=\"#000000\" d=\"M551.9272,-103.9815C557.9238,-95.0666 564.2704,-85.6313 570.2959,-76.6734\"/>\n", | |
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"573.2755,-78.5147 575.9527,-68.2637 567.4672,-74.6078 573.2755,-78.5147\"/>\n", | |
"</g>\n", | |
"</g>\n", | |
"</svg>\n" | |
], | |
"text/plain": [ | |
"<graphviz.files.Source at 0x7f4623126e80>" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# To visualize our graph we use graphviz:\n", | |
"dot_data = tree.export_graphviz(clf, out_file=None, \n", | |
" feature_names=['Sex','Class','Age'], \n", | |
" class_names=['Died','Survived'], \n", | |
" filled=True, rounded=True, \n", | |
" special_characters=True,\n", | |
" leaves_parallel=False) \n", | |
"graph = graphviz.Source(dot_data)\n", | |
"graph" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"11.58 %\n" | |
] | |
} | |
], | |
"source": [ | |
"# We'll use the \"predict_proba\" method to get predictions\n", | |
"prob = clf.predict_proba([[1,2,33]])[0,1]\n", | |
"print('{:.02f} %'.format(100*prob))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"For me, a 33-year-old man who doesn't have the money to travel first class, the odds of surviving were **11.58%** ... yikes \n", | |
"\n", | |
"**But wait**! How good is our model anyway? We need to test it. *So let's write an evaluation routine.*" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# we define a lambda that returns the model's prediction (clf.predict) for a single sample\n", | |
"probs = lambda x: clf.predict([x])\n", | |
"# now use the eval data !\n", | |
"eval_samples = pd.concat([eval_data['SexNumerical'], eval_data['Pclass'], eval_data['Age']],axis=1)\n", | |
"eval_data['Predicted_Probablity'] = [probs(idx) for idx in eval_samples.values]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def f(x): \n", | |
" return 0 if x['Predicted_Probablity'] != x['Survived'] else 1\n", | |
"\n", | |
"eval_data['prediction_was_correct'] = eval_data.apply(f, axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"80.56 % was correctly predicted\n" | |
] | |
} | |
], | |
"source": [ | |
"precision = eval_data['prediction_was_correct'].sum(axis=0)/eval_data.shape[0]\n", | |
"# 'precision' here is the fraction of correct predictions (i.e. the accuracy); print it out nicely\n", | |
"print('{:.02f} % was correctly predicted'.format(100*precision))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"######################################################################################\n", | |
"###### This is the second part of the article. It is only for tutorial purposes ######\n", | |
"######################################################################################\n", | |
"\n", | |
"# we need linspace and isscalar from numpy\n", | |
"from numpy import linspace, isscalar\n", | |
"\n", | |
"#### we calculate everything for the root node as an example\n", | |
"\n", | |
"# The number of samples in the root node\n", | |
"n_all = len(train_data)\n", | |
"\n", | |
"# the proportional probabilities\n", | |
"p_mk = lambda n_m, indicator: (1/n_m)*indicator\n", | |
"# gini metric\n", | |
"gini = lambda pmk_0, pmk_1: pmk_0*(1-pmk_0)+pmk_1*(1-pmk_1)\n", | |
"# cost function\n", | |
"cost_fun = lambda n_0, n_1, gini_0, gini_1: (1/n_all)*(n_0*gini_0+n_1*gini_1)\n", | |
"# all the subsets\n", | |
"data_provider = lambda cond: train_data.where(cond).dropna(axis=0, how='any', \n", | |
" subset = ['Sex', 'Age', 'Pclass','Survived'])\n", | |
"\n", | |
"def get_subset(t_m, feat):\n", | |
" # smaller than or equal to t_m\n", | |
" subsets = [data_provider((train_data['Survived']==cat[0]) & (train_data[feat]<=cat[1])) \n", | |
" for cat in [[0,t_m],[1,t_m]]]\n", | |
" # larger than or equal to t_m (rows equal to t_m fall into both halves)\n", | |
" subsets.extend([data_provider((train_data['Survived']==cat[0]) & (train_data[feat]>=cat[1])) \n", | |
" for cat in [[0,t_m],[1,t_m]]])\n", | |
" return subsets\n", | |
"\n", | |
"def calculate_cost(data_subsets):\n", | |
" # the indicator function is just the length of each subset\n", | |
" n_k = [len(subset) for subset in data_subsets]\n", | |
" # splitting into smaller/larger\n", | |
" n_1 = sum(n_k[:2])\n", | |
" n_2 = sum(n_k[2:])\n", | |
"\n", | |
" # intermediate probabilities\n", | |
" pmk_1 = p_mk(n_1, n_k[0])\n", | |
" pmk_2 = p_mk(n_1, n_k[1])\n", | |
" pmk_3 = p_mk(n_2, n_k[2])\n", | |
" pmk_4 = p_mk(n_2, n_k[3])\n", | |
"\n", | |
" # gini for binary choices\n", | |
" gini_1 = gini(pmk_1, pmk_2)\n", | |
" gini_2 = gini(pmk_3, pmk_4)\n", | |
"\n", | |
" # impurity cost function\n", | |
" return cost_fun(n_1, n_2, gini_1, gini_2),gini_1 , gini_2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The best results for each feature are:\n", | |
"'SexNumerical' -> Cost: 0.346, Threshold: 0.50, Gini_Left: 0.369, Gini_Right: 0.333\n", | |
"The best results for each feature are:\n", | |
"'Age' -> Cost: 0.470, Threshold: 6.05, Gini_Left: 0.393, Gini_Right: 0.475\n", | |
"The best results for each feature are:\n", | |
"'Pclass' -> Cost: 0.572, Threshold: 2.00, Gini_Left: 0.490, Gini_Right: 0.439\n" | |
] | |
} | |
], | |
"source": [ | |
"# Tuple of (threshold candidates, feature name) pairs; we'll use 100 evenly spaced datapoints to rasterize Age\n", | |
"it_tuple = ((.5, 'SexNumerical'),((linspace(train_data['Age'].min(),train_data['Age'].max(),100)),'Age'), ((1,2,3),'Pclass'))\n", | |
"\n", | |
"# we store the results in a dict\n", | |
"results = {}\n", | |
"# helper variable\n", | |
"cost_start = 1\n", | |
"# iterate over all values in it_tuple\n", | |
"for item in it_tuple:\n", | |
" # if the first item isn't a scalar, iterate over its candidate thresholds\n", | |
" if not isscalar(item[0]):\n", | |
" for thresholds in item[0]:\n", | |
" # get all four subsets\n", | |
" data_subsets = get_subset(thresholds, item[1])\n", | |
" # calculate the cost function, return it as well as both gini scores\n", | |
" cost, gini_imp_1 , gini_imp_2 = calculate_cost(data_subsets)\n", | |
" # here we need the helper: only write to the dict if the\n", | |
" # cost is smaller than everything before it\n", | |
" if cost < cost_start:\n", | |
" results.update({item[1]: (thresholds, cost, gini_imp_1 , gini_imp_2)})\n", | |
" cost_start = cost\n", | |
" # reset cost_start\n", | |
" cost_start = 1\n", | |
" else:\n", | |
" # the same for scalars\n", | |
" data_subsets = get_subset(item[0], item[1])\n", | |
" cost, gini_imp_1 , gini_imp_2 = calculate_cost(data_subsets)\n", | |
" results.update({item[1]: (item[0], cost, gini_imp_1 , gini_imp_2)})\n", | |
"\n", | |
"# print out results\n", | |
"for feature, vals in results.items():\n", | |
" print('The best results for each feature are:')\n", | |
" print('\\'{}\\' -> Cost: {:.03f}, Threshold: {:.02f}, Gini_Left: {:.03f}, Gini_Right: {:.03f}'. format(feature, vals[1], vals[0], vals[2], vals[3]))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"gist_id": "45b94505b330378d12765e65b2814f6b", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment