Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# The imports we are using\n",
"import pandas as pd\n",
"import graphviz\n",
"from sklearn import tree"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Load the Titanic training set published on Kaggle:\n",
"# https://www.kaggle.com/c/titanic/data\n",
"# The path is relative, so 'train.csv' has to be in the working directory.\n",
"data_raw = pd.read_csv(filepath_or_buffer='train.csv')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>PassengerId</th>\n",
" <th>Survived</th>\n",
" <th>Pclass</th>\n",
" <th>Name</th>\n",
" <th>Sex</th>\n",
" <th>Age</th>\n",
" <th>SibSp</th>\n",
" <th>Parch</th>\n",
" <th>Ticket</th>\n",
" <th>Fare</th>\n",
" <th>Cabin</th>\n",
" <th>Embarked</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Braund, Mr. Owen Harris</td>\n",
" <td>male</td>\n",
" <td>22.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>A/5 21171</td>\n",
" <td>7.2500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
" <td>female</td>\n",
" <td>38.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>PC 17599</td>\n",
" <td>71.2833</td>\n",
" <td>C85</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Heikkinen, Miss. Laina</td>\n",
" <td>female</td>\n",
" <td>26.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>STON/O2. 3101282</td>\n",
" <td>7.9250</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
" <td>female</td>\n",
" <td>35.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>113803</td>\n",
" <td>53.1000</td>\n",
" <td>C123</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Allen, Mr. William Henry</td>\n",
" <td>male</td>\n",
" <td>35.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>373450</td>\n",
" <td>8.0500</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>6</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Moran, Mr. James</td>\n",
" <td>male</td>\n",
" <td>NaN</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>330877</td>\n",
" <td>8.4583</td>\n",
" <td>NaN</td>\n",
" <td>Q</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>7</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>McCarthy, Mr. Timothy J</td>\n",
" <td>male</td>\n",
" <td>54.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>17463</td>\n",
" <td>51.8625</td>\n",
" <td>E46</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>8</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>Palsson, Master. Gosta Leonard</td>\n",
" <td>male</td>\n",
" <td>2.0</td>\n",
" <td>3</td>\n",
" <td>1</td>\n",
" <td>349909</td>\n",
" <td>21.0750</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>9</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)</td>\n",
" <td>female</td>\n",
" <td>27.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>347742</td>\n",
" <td>11.1333</td>\n",
" <td>NaN</td>\n",
" <td>S</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>10</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>Nasser, Mrs. Nicholas (Adele Achem)</td>\n",
" <td>female</td>\n",
" <td>14.0</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>237736</td>\n",
" <td>30.0708</td>\n",
" <td>NaN</td>\n",
" <td>C</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" PassengerId Survived Pclass \\\n",
"0 1 0 3 \n",
"1 2 1 1 \n",
"2 3 1 3 \n",
"3 4 1 1 \n",
"4 5 0 3 \n",
"5 6 0 3 \n",
"6 7 0 1 \n",
"7 8 0 3 \n",
"8 9 1 3 \n",
"9 10 1 2 \n",
"\n",
" Name Sex Age SibSp \\\n",
"0 Braund, Mr. Owen Harris male 22.0 1 \n",
"1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n",
"2 Heikkinen, Miss. Laina female 26.0 0 \n",
"3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n",
"4 Allen, Mr. William Henry male 35.0 0 \n",
"5 Moran, Mr. James male NaN 0 \n",
"6 McCarthy, Mr. Timothy J male 54.0 0 \n",
"7 Palsson, Master. Gosta Leonard male 2.0 3 \n",
"8 Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) female 27.0 0 \n",
"9 Nasser, Mrs. Nicholas (Adele Achem) female 14.0 1 \n",
"\n",
" Parch Ticket Fare Cabin Embarked \n",
"0 0 A/5 21171 7.2500 NaN S \n",
"1 0 PC 17599 71.2833 C85 C \n",
"2 0 STON/O2. 3101282 7.9250 NaN S \n",
"3 0 113803 53.1000 C123 S \n",
"4 0 373450 8.0500 NaN S \n",
"5 0 330877 8.4583 NaN Q \n",
"6 0 17463 51.8625 E46 S \n",
"7 1 349909 21.0750 NaN S \n",
"8 2 347742 11.1333 NaN S \n",
"9 0 237736 30.0708 NaN C "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Let's have a look at it:\n",
"# head(10) returns the first 10 rows of the frame\n",
"data_raw.head(10)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Next we need to convert the 'Sex' column to binary values: 'male' -> 1, anything else -> 0.\n",
"# The vectorised comparison yields booleans, which are then cast to integers.\n",
"data_raw['SexNumerical'] = (data_raw['Sex'] == 'male').astype('int64')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PassengerId 0.005007\n",
"SibSp 0.035322\n",
"Age 0.077221\n",
"Parch 0.081629\n",
"Fare 0.257307\n",
"Pclass 0.338481\n",
"SexNumerical 0.543351\n",
"Survived 1.000000\n",
"Name: Survived, dtype: float64"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# A good first step for predicting survival is to look at the correlations\n",
"# with the 'Survived' column: the larger the absolute correlation, the more\n",
"# the corresponding feature tells us about survival.\n",
"# Only numeric columns can be correlated. They are selected explicitly because\n",
"# corr() raises on text columns in pandas >= 2.0 instead of dropping them silently.\n",
"data_raw.select_dtypes(include='number').corr()['Survived'].abs().sort_values()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"233 women and 109 men survived\n"
]
}
],
"source": [
"# Summing the 0/1 'Survived' flag over each group counts its survivors\n",
"women_survived = data_raw.loc[data_raw['Sex'] == 'female', 'Survived'].sum()\n",
"men_survived = data_raw.loc[data_raw['Sex'] == 'male', 'Survived'].sum()\n",
"print('{:.0f} women and {:.0f} men survived'.format(women_survived, men_survived))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PassengerId 0.008790\n",
"Age 0.116109\n",
"Fare 0.218466\n",
"Parch 0.223644\n",
"SibSp 0.263284\n",
"Pclass 0.477114\n",
"Survived 1.000000\n",
"SexNumerical NaN\n",
"Name: Survived, dtype: float64"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Since sex has such a strong correlation, we split the dataset and look at\n",
"# the correlations for female and male passengers separately. Women first.\n",
"# 'SexNumerical' is constant within one sex, so its correlation is undefined (NaN).\n",
"# Text columns are excluded explicitly because corr() raises on them in pandas >= 2.0.\n",
"female_data = data_raw[data_raw['Sex'] == 'female']\n",
"female_data.select_dtypes(include='number').corr()['Survived'].abs().sort_values()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"SibSp 0.020238\n",
"PassengerId 0.040477\n",
"Parch 0.096318\n",
"Age 0.119618\n",
"Fare 0.171288\n",
"Pclass 0.220618\n",
"Survived 1.000000\n",
"SexNumerical NaN\n",
"Name: Survived, dtype: float64"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Correlations with 'Survived' for the male passengers only.\n",
"# 'SexNumerical' is constant within one sex, so its correlation is undefined (NaN).\n",
"# Text columns are excluded explicitly because corr() raises on them in pandas >= 2.0.\n",
"male_data = data_raw[data_raw['Sex'] == 'male']\n",
"male_data.select_dtypes(include='number').corr()['Survived'].abs().sort_values()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# We only use four columns: 'Sex', 'Age', 'Pclass' and 'Survived'.\n",
"# Rows with a missing value (NaN) in any of them are removed; by default\n",
"# dropna() works on rows and drops a row as soon as one value is missing.\n",
"required_columns = ['Sex', 'Age', 'Pclass', 'Survived']\n",
"data_raw = data_raw.dropna(subset=required_columns)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Split the data into a training set and an evaluation set: the first 90% of\n",
"# the rows (in their current order, no shuffling) are used for training and\n",
"# the remaining 10% for the evaluation.\n",
"idx = int(0.9 * len(data_raw))\n",
"train_data, eval_data = data_raw.iloc[:idx].copy(), data_raw.iloc[idx:].copy()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Prepare the three feature columns for the decision tree routine.\n",
"# Selecting with a list of names gives one DataFrame with the columns in this order.\n",
"samples = train_data[['SexNumerical', 'Pclass', 'Age']]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Now to the fitting part: define the classifier and fit it to the data.\n",
"# The tree is limited to a depth of 3 and at most 7 leaf nodes.\n",
"# scikit-learn permutes the features at every split, so random_state is fixed\n",
"# to make the fitted tree reproducible between runs.\n",
"clf = tree.DecisionTreeClassifier(criterion='gini', max_depth=3, max_leaf_nodes=7,\n",
"                                  random_state=0)\n",
"clf = clf.fit(samples, train_data['Survived'])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
" -->\n",
"<!-- Title: Tree Pages: 1 -->\n",
"<svg width=\"756pt\" height=\"552pt\"\n",
" viewBox=\"0.00 0.00 756.00 552.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 548)\">\n",
"<title>Tree</title>\n",
"<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-548 752,-548 752,4 -4,4\"/>\n",
"<!-- 0 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>0</title>\n",
"<path fill=\"#e58139\" fill-opacity=\"0.305882\" stroke=\"#000000\" d=\"M357.5,-544C357.5,-544 232.5,-544 232.5,-544 226.5,-544 220.5,-538 220.5,-532 220.5,-532 220.5,-473 220.5,-473 220.5,-467 226.5,-461 232.5,-461 232.5,-461 357.5,-461 357.5,-461 363.5,-461 369.5,-467 369.5,-473 369.5,-473 369.5,-532 369.5,-532 369.5,-538 363.5,-544 357.5,-544\"/>\n",
"<text text-anchor=\"start\" x=\"260\" y=\"-528.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Sex ≤ 0.5</text>\n",
"<text text-anchor=\"start\" x=\"251\" y=\"-513.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.484</text>\n",
"<text text-anchor=\"start\" x=\"242\" y=\"-498.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 642</text>\n",
"<text text-anchor=\"start\" x=\"228.5\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [379, 263]</text>\n",
"<text text-anchor=\"start\" x=\"251\" y=\"-468.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n",
"</g>\n",
"<!-- 1 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>1</title>\n",
"<path fill=\"#399de5\" fill-opacity=\"0.678431\" stroke=\"#000000\" d=\"M274,-425C274,-425 158,-425 158,-425 152,-425 146,-419 146,-413 146,-413 146,-354 146,-354 146,-348 152,-342 158,-342 158,-342 274,-342 274,-342 280,-342 286,-348 286,-354 286,-354 286,-413 286,-413 286,-419 280,-425 274,-425\"/>\n",
"<text text-anchor=\"start\" x=\"175.5\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Class ≤ 2.5</text>\n",
"<text text-anchor=\"start\" x=\"172\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.369</text>\n",
"<text text-anchor=\"start\" x=\"163\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 234</text>\n",
"<text text-anchor=\"start\" x=\"154\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [57, 177]</text>\n",
"<text text-anchor=\"start\" x=\"157\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;1 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>0&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M267.3696,-460.8796C261.57,-452.1434 255.3941,-442.8404 249.4092,-433.8253\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"252.1951,-431.6935 243.7483,-425.2981 246.3632,-435.5652 252.1951,-431.6935\"/>\n",
"<text text-anchor=\"middle\" x=\"238.7978\" y=\"-446.103\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>2</title>\n",
"<path fill=\"#e58139\" fill-opacity=\"0.733333\" stroke=\"#000000\" d=\"M432,-425C432,-425 316,-425 316,-425 310,-425 304,-419 304,-413 304,-413 304,-354 304,-354 304,-348 310,-342 316,-342 316,-342 432,-342 432,-342 438,-342 444,-348 444,-354 444,-354 444,-413 444,-413 444,-419 438,-425 432,-425\"/>\n",
"<text text-anchor=\"start\" x=\"334\" y=\"-409.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Age ≤ 13.0</text>\n",
"<text text-anchor=\"start\" x=\"330\" y=\"-394.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.333</text>\n",
"<text text-anchor=\"start\" x=\"321\" y=\"-379.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 408</text>\n",
"<text text-anchor=\"start\" x=\"312\" y=\"-364.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [322, 86]</text>\n",
"<text text-anchor=\"start\" x=\"330\" y=\"-349.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;2 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>0&#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M322.6304,-460.8796C328.43,-452.1434 334.6059,-442.8404 340.5908,-433.8253\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"343.6368,-435.5652 346.2517,-425.2981 337.8049,-431.6935 343.6368,-435.5652\"/>\n",
"<text text-anchor=\"middle\" x=\"351.2022\" y=\"-446.103\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">False</text>\n",
"</g>\n",
"<!-- 3 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>3</title>\n",
"<path fill=\"#399de5\" fill-opacity=\"0.941176\" stroke=\"#000000\" d=\"M122,-298.5C122,-298.5 12,-298.5 12,-298.5 6,-298.5 0,-292.5 0,-286.5 0,-286.5 0,-242.5 0,-242.5 0,-236.5 6,-230.5 12,-230.5 12,-230.5 122,-230.5 122,-230.5 128,-230.5 134,-236.5 134,-242.5 134,-242.5 134,-286.5 134,-286.5 134,-292.5 128,-298.5 122,-298.5\"/>\n",
"<text text-anchor=\"start\" x=\"23\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.106</text>\n",
"<text text-anchor=\"start\" x=\"14\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 143</text>\n",
"<text text-anchor=\"start\" x=\"9.5\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [8, 135]</text>\n",
"<text text-anchor=\"start\" x=\"8\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;3 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>1&#45;&gt;3</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M163.887,-341.8796C149.0174,-330.0038 132.836,-317.0804 118.0317,-305.2568\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"119.9638,-302.3207 109.9658,-298.8149 115.5954,-307.7904 119.9638,-302.3207\"/>\n",
"</g>\n",
"<!-- 4 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>4</title>\n",
"<path fill=\"#e58139\" fill-opacity=\"0.141176\" stroke=\"#000000\" d=\"M271.5,-298.5C271.5,-298.5 164.5,-298.5 164.5,-298.5 158.5,-298.5 152.5,-292.5 152.5,-286.5 152.5,-286.5 152.5,-242.5 152.5,-242.5 152.5,-236.5 158.5,-230.5 164.5,-230.5 164.5,-230.5 271.5,-230.5 271.5,-230.5 277.5,-230.5 283.5,-236.5 283.5,-242.5 283.5,-242.5 283.5,-286.5 283.5,-286.5 283.5,-292.5 277.5,-298.5 271.5,-298.5\"/>\n",
"<text text-anchor=\"start\" x=\"174\" y=\"-283.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.497</text>\n",
"<text text-anchor=\"start\" x=\"169.5\" y=\"-268.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 91</text>\n",
"<text text-anchor=\"start\" x=\"160.5\" y=\"-253.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [49, 42]</text>\n",
"<text text-anchor=\"start\" x=\"174\" y=\"-238.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n",
"</g>\n",
"<!-- 1&#45;&gt;4 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>1&#45;&gt;4</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M216.6995,-341.8796C216.8788,-331.2134 217.0722,-319.7021 217.2538,-308.9015\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"220.7546,-308.8724 217.4233,-298.8149 213.7556,-308.7547 220.7546,-308.8724\"/>\n",
"</g>\n",
"<!-- 5 -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>5</title>\n",
"<path fill=\"#399de5\" fill-opacity=\"0.278431\" stroke=\"#000000\" d=\"M427,-306C427,-306 317,-306 317,-306 311,-306 305,-300 305,-294 305,-294 305,-235 305,-235 305,-229 311,-223 317,-223 317,-223 427,-223 427,-223 433,-223 439,-229 439,-235 439,-235 439,-294 439,-294 439,-300 433,-306 427,-306\"/>\n",
"<text text-anchor=\"start\" x=\"331.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Class ≤ 2.5</text>\n",
"<text text-anchor=\"start\" x=\"328\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.487</text>\n",
"<text text-anchor=\"start\" x=\"323.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 31</text>\n",
"<text text-anchor=\"start\" x=\"314.5\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [13, 18]</text>\n",
"<text text-anchor=\"start\" x=\"313\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;5 -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>2&#45;&gt;5</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M373.3005,-341.8796C373.1628,-333.6838 373.0166,-324.9891 372.874,-316.5013\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"376.3701,-316.2378 372.7025,-306.2981 369.3711,-316.3555 376.3701,-316.2378\"/>\n",
"</g>\n",
"<!-- 6 -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>6</title>\n",
"<path fill=\"#e58139\" fill-opacity=\"0.780392\" stroke=\"#000000\" d=\"M585,-306C585,-306 469,-306 469,-306 463,-306 457,-300 457,-294 457,-294 457,-235 457,-235 457,-229 463,-223 469,-223 469,-223 585,-223 585,-223 591,-223 597,-229 597,-235 597,-235 597,-294 597,-294 597,-300 591,-306 585,-306\"/>\n",
"<text text-anchor=\"start\" x=\"486.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Class ≤ 1.5</text>\n",
"<text text-anchor=\"start\" x=\"483\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.296</text>\n",
"<text text-anchor=\"start\" x=\"474\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 377</text>\n",
"<text text-anchor=\"start\" x=\"465\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [309, 68]</text>\n",
"<text text-anchor=\"start\" x=\"483\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;6 -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>2&#45;&gt;6</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M427.512,-341.8796C439.7412,-332.368 452.8344,-322.1843 465.3732,-312.432\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"467.7735,-314.9991 473.5182,-306.0969 463.4759,-309.4736 467.7735,-314.9991\"/>\n",
"</g>\n",
"<!-- 9 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>9</title>\n",
"<path fill=\"#399de5\" stroke=\"#000000\" d=\"M283,-179.5C283,-179.5 173,-179.5 173,-179.5 167,-179.5 161,-173.5 161,-167.5 161,-167.5 161,-123.5 161,-123.5 161,-117.5 167,-111.5 173,-111.5 173,-111.5 283,-111.5 283,-111.5 289,-111.5 295,-117.5 295,-123.5 295,-123.5 295,-167.5 295,-167.5 295,-173.5 289,-179.5 283,-179.5\"/>\n",
"<text text-anchor=\"start\" x=\"193\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"179.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 10</text>\n",
"<text text-anchor=\"start\" x=\"175\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 10]</text>\n",
"<text text-anchor=\"start\" x=\"169\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n",
"</g>\n",
"<!-- 5&#45;&gt;9 -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>5&#45;&gt;9</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M321.6358,-222.8796C307.2651,-211.0038 291.6267,-198.0804 277.3192,-186.2568\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"279.462,-183.4872 269.524,-179.8149 275.0029,-188.8831 279.462,-183.4872\"/>\n",
"</g>\n",
"<!-- 10 -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>10</title>\n",
"<path fill=\"#e58139\" fill-opacity=\"0.384314\" stroke=\"#000000\" d=\"M423,-179.5C423,-179.5 325,-179.5 325,-179.5 319,-179.5 313,-173.5 313,-167.5 313,-167.5 313,-123.5 313,-123.5 313,-117.5 319,-111.5 325,-111.5 325,-111.5 423,-111.5 423,-111.5 429,-111.5 435,-117.5 435,-123.5 435,-123.5 435,-167.5 435,-167.5 435,-173.5 429,-179.5 423,-179.5\"/>\n",
"<text text-anchor=\"start\" x=\"330\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.472</text>\n",
"<text text-anchor=\"start\" x=\"325.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 21</text>\n",
"<text text-anchor=\"start\" x=\"321\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [13, 8]</text>\n",
"<text text-anchor=\"start\" x=\"330\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n",
"</g>\n",
"<!-- 5&#45;&gt;10 -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>5&#45;&gt;10</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M372.6995,-222.8796C372.8788,-212.2134 373.0722,-200.7021 373.2538,-189.9015\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"376.7546,-189.8724 373.4233,-179.8149 369.7556,-189.7547 376.7546,-189.8724\"/>\n",
"</g>\n",
"<!-- 7 -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>7</title>\n",
"<path fill=\"#e58139\" fill-opacity=\"0.384314\" stroke=\"#000000\" d=\"M577.5,-187C577.5,-187 470.5,-187 470.5,-187 464.5,-187 458.5,-181 458.5,-175 458.5,-175 458.5,-116 458.5,-116 458.5,-110 464.5,-104 470.5,-104 470.5,-104 577.5,-104 577.5,-104 583.5,-104 589.5,-110 589.5,-116 589.5,-116 589.5,-175 589.5,-175 589.5,-181 583.5,-187 577.5,-187\"/>\n",
"<text text-anchor=\"start\" x=\"484\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">Age ≤ 43.0</text>\n",
"<text text-anchor=\"start\" x=\"480\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.471</text>\n",
"<text text-anchor=\"start\" x=\"475.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 92</text>\n",
"<text text-anchor=\"start\" x=\"466.5\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [57, 35]</text>\n",
"<text text-anchor=\"start\" x=\"480\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n",
"</g>\n",
"<!-- 6&#45;&gt;7 -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>6&#45;&gt;7</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M525.9507,-222.8796C525.7441,-214.6838 525.5249,-205.9891 525.311,-197.5013\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"528.8047,-197.2067 525.0537,-187.2981 521.807,-197.3831 528.8047,-197.2067\"/>\n",
"</g>\n",
"<!-- 8 -->\n",
"<g id=\"node13\" class=\"node\">\n",
"<title>8</title>\n",
"<path fill=\"#e58139\" fill-opacity=\"0.870588\" stroke=\"#000000\" d=\"M736,-179.5C736,-179.5 620,-179.5 620,-179.5 614,-179.5 608,-173.5 608,-167.5 608,-167.5 608,-123.5 608,-123.5 608,-117.5 614,-111.5 620,-111.5 620,-111.5 736,-111.5 736,-111.5 742,-111.5 748,-117.5 748,-123.5 748,-123.5 748,-167.5 748,-167.5 748,-173.5 742,-179.5 736,-179.5\"/>\n",
"<text text-anchor=\"start\" x=\"634\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.205</text>\n",
"<text text-anchor=\"start\" x=\"625\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 285</text>\n",
"<text text-anchor=\"start\" x=\"616\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [252, 33]</text>\n",
"<text text-anchor=\"start\" x=\"634\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n",
"</g>\n",
"<!-- 6&#45;&gt;8 -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>6&#45;&gt;8</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M579.8125,-222.8796C594.8817,-211.0038 611.2803,-198.0804 626.2834,-186.2568\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"628.7698,-188.7536 634.4575,-179.8149 624.4369,-183.2557 628.7698,-188.7536\"/>\n",
"</g>\n",
"<!-- 11 -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>11</title>\n",
"<path fill=\"#399de5\" fill-opacity=\"0.082353\" stroke=\"#000000\" d=\"M503,-68C503,-68 393,-68 393,-68 387,-68 381,-62 381,-56 381,-56 381,-12 381,-12 381,-6 387,0 393,0 393,0 503,0 503,0 509,0 515,-6 515,-12 515,-12 515,-56 515,-56 515,-62 509,-68 503,-68\"/>\n",
"<text text-anchor=\"start\" x=\"404\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.499</text>\n",
"<text text-anchor=\"start\" x=\"399.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 46</text>\n",
"<text text-anchor=\"start\" x=\"390.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [22, 24]</text>\n",
"<text text-anchor=\"start\" x=\"389\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Survived</text>\n",
"</g>\n",
"<!-- 7&#45;&gt;11 -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>7&#45;&gt;11</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M495.7004,-103.9815C489.6238,-95.0666 483.1926,-85.6313 477.0868,-76.6734\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"479.8789,-74.5555 471.3546,-68.2637 474.0948,-78.498 479.8789,-74.5555\"/>\n",
"</g>\n",
"<!-- 12 -->\n",
"<g id=\"node12\" class=\"node\">\n",
"<title>12</title>\n",
"<path fill=\"#e58139\" fill-opacity=\"0.686275\" stroke=\"#000000\" d=\"M652.5,-68C652.5,-68 545.5,-68 545.5,-68 539.5,-68 533.5,-62 533.5,-56 533.5,-56 533.5,-12 533.5,-12 533.5,-6 539.5,0 545.5,0 545.5,0 652.5,0 652.5,0 658.5,0 664.5,-6 664.5,-12 664.5,-12 664.5,-56 664.5,-56 664.5,-62 658.5,-68 652.5,-68\"/>\n",
"<text text-anchor=\"start\" x=\"555\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.364</text>\n",
"<text text-anchor=\"start\" x=\"550.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 46</text>\n",
"<text text-anchor=\"start\" x=\"541.5\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [35, 11]</text>\n",
"<text text-anchor=\"start\" x=\"555\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = Died</text>\n",
"</g>\n",
"<!-- 7&#45;&gt;12 -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>7&#45;&gt;12</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M551.9272,-103.9815C557.9238,-95.0666 564.2704,-85.6313 570.2959,-76.6734\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"573.2755,-78.5147 575.9527,-68.2637 567.4672,-74.6078 573.2755,-78.5147\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.files.Source at 0x7f4623126e80>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# To visualize the fitted tree we export it in DOT format and render it with graphviz.\n",
"# feature_names are display labels given in the column order of the training samples.\n",
"export_options = dict(out_file=None,\n",
"                      feature_names=['Sex', 'Class', 'Age'],\n",
"                      class_names=['Died', 'Survived'],\n",
"                      filled=True, rounded=True,\n",
"                      special_characters=True,\n",
"                      leaves_parallel=False)\n",
"dot_data = tree.export_graphviz(clf, **export_options)\n",
"graph = graphviz.Source(dot_data)\n",
"graph"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"11.58 %\n"
]
}
],
"source": [
"# predict_proba returns one row per sample with one probability per class;\n",
"# column 1 belongs to class 1 ('Survived').\n",
"# The sample is [SexNumerical, Pclass, Age]: a 33 year old man travelling second class.\n",
"passenger = [1, 2, 33]\n",
"prob = clf.predict_proba([passenger])[0, 1]\n",
"print('{:.02f} %'.format(100 * prob))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For me, a 33-year-old man who doesn't have the money to travel first class, the probability of surviving would have been: **11.58%** ... yikes \n",
"\n",
"**But wait!** How good is our model anyway? We need to test it. *So let's write an evaluation routine.*"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"# Helper that predicts the class (0 = died, 1 = survived) for a single feature row\n",
"probs = lambda x: clf.predict([x])\n",
"# Now use the eval data!\n",
"eval_samples = eval_data[['SexNumerical', 'Pclass', 'Age']]\n",
"# Predict all rows with a single call. The column then holds plain class labels\n",
"# instead of one-element arrays, and the classifier receives a DataFrame with\n",
"# the same column names it was fitted on.\n",
"# Note: despite its name the column holds predicted classes, not probabilities.\n",
"eval_data['Predicted_Probablity'] = clf.predict(eval_samples)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def f(x):\n",
"    \"\"\"Return 1 if the predicted class of row x equals its 'Survived' value, else 0.\"\"\"\n",
"    if x['Predicted_Probablity'] != x['Survived']:\n",
"        return 0\n",
"    return 1\n",
"\n",
"# axis=1 hands every row of the evaluation data to f\n",
"eval_data['prediction_was_correct'] = eval_data.apply(f, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"80.56 % was correctly predicted\n"
]
}
],
"source": [
"# Fraction of correctly predicted passengers in the evaluation set: the mean of\n",
"# the 0/1 flags. (Strictly speaking this is the accuracy, not the precision.)\n",
"precision = eval_data['prediction_was_correct'].mean()\n",
"# Print it out nicely as a percentage\n",
"print('{:.02f} % was correctly predicted'.format(100 * precision))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"######################################################################################\n",
"###### This is the second part of the article. It is only for tutorial purposes ######\n",
"######################################################################################\n",
"\n",
"# we need linspace and isscalar from numpy\n",
"from numpy import linspace, isscalar\n",
"\n",
"#### we calculate everything for the root node as an example\n",
"\n",
"# Size of the root node: the length of the training frame\n",
"n_all = len(train_data)\n",
"\n",
"def p_mk(n_m, indicator):\n",
"    \"\"\"Class proportion in a node: indicator (the class count) divided by the node size n_m.\"\"\"\n",
"    return (1/n_m)*indicator\n",
"\n",
"def gini(pmk_0, pmk_1):\n",
"    \"\"\"Gini impurity of a two-class node, given both class proportions.\"\"\"\n",
"    return pmk_0*(1-pmk_0)+pmk_1*(1-pmk_1)\n",
"\n",
"def cost_fun(n_0, n_1, gini_0, gini_1):\n",
"    \"\"\"Sum of both child impurities weighted by child size, scaled by 1/n_all.\"\"\"\n",
"    return (1/n_all)*(n_0*gini_0+n_1*gini_1)\n",
"\n",
"def data_provider(cond):\n",
"    \"\"\"Return the rows of train_data where cond holds.\n",
"\n",
"    Rows failing cond are blanked by where() and then removed by dropna(),\n",
"    together with any row that has NaN in 'Sex', 'Age', 'Pclass' or 'Survived'.\n",
"    \"\"\"\n",
"    masked = train_data.where(cond)\n",
"    return masked.dropna(axis=0, how='any', subset=['Sex', 'Age', 'Pclass', 'Survived'])\n",
"\n",
"def get_subset(t_m, feat):\n",
"    \"\"\"Split train_data at threshold t_m on column feat, separately per 'Survived' value.\n",
"\n",
"    Returns a list of four DataFrames in this order:\n",
"    [Survived==0 & feat<=t_m, Survived==1 & feat<=t_m,\n",
"     Survived==0 & feat>t_m,  Survived==1 & feat>t_m]\n",
"    \"\"\"\n",
"    survived = train_data['Survived']\n",
"    left = train_data[feat] <= t_m\n",
"    # The right side must be strictly greater than t_m. With '>=' every row equal\n",
"    # to the threshold would fall on both sides and be counted twice.\n",
"    right = train_data[feat] > t_m\n",
"    return [data_provider((survived == label) & side)\n",
"            for side in (left, right)\n",
"            for label in (0, 1)]\n",
"\n",
"def calculate_cost(data_subsets):\n",
"    \"\"\"Compute the split cost and both child Gini impurities.\n",
"\n",
"    data_subsets is the four-element list produced by get_subset; only the\n",
"    length of each subset is used.\n",
"    Returns a tuple (cost, gini_left, gini_right).\n",
"    \"\"\"\n",
"    # the indicator function is just the length of each subset\n",
"    n_k = [len(subset) for subset in data_subsets]\n",
"    # sample counts of the left and the right child\n",
"    n_1 = sum(n_k[:2])\n",
"    n_2 = sum(n_k[2:])\n",
"\n",
"    def side_gini(n_side, n_class_0, n_class_1):\n",
"        # An empty child has no impurity; returning early also avoids the\n",
"        # division by zero that p_mk would raise for n_side == 0.\n",
"        if n_side == 0:\n",
"            return 0.0\n",
"        return gini(p_mk(n_side, n_class_0), p_mk(n_side, n_class_1))\n",
"\n",
"    gini_1 = side_gini(n_1, n_k[0], n_k[1])\n",
"    gini_2 = side_gini(n_2, n_k[2], n_k[3])\n",
"\n",
"    # impurity cost function\n",
"    return cost_fun(n_1, n_2, gini_1, gini_2), gini_1, gini_2"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The best results for each feature are:\n",
"'SexNumerical' -> Cost: 0.346, Threshold: 0.50, Gini_Left: 0.369, Gini_Right: 0.333\n",
"The best results for each feature are:\n",
"'Age' -> Cost: 0.470, Threshold: 6.05, Gini_Left: 0.393, Gini_Right: 0.475\n",
"The best results for each feature are:\n",
"'Pclass' -> Cost: 0.572, Threshold: 2.00, Gini_Left: 0.490, Gini_Right: 0.439\n"
]
}
],
"source": [
"# Candidate thresholds per feature: one cut for 'SexNumerical', 100 evenly spaced cuts\n",
"# across the observed Age range, and each of the Pclass values 1, 2, 3\n",
"it_tuple = ((.5, 'SexNumerical'),\n",
"            (linspace(train_data['Age'].min(), train_data['Age'].max(), 100), 'Age'),\n",
"            ((1, 2, 3), 'Pclass'))\n",
"\n",
"# feature name -> (threshold, cost, gini_left, gini_right) of its cheapest split\n",
"results = {}\n",
"for candidates, feature in it_tuple:\n",
"    # wrap a lone scalar threshold in a list so scalars and sequences share one loop\n",
"    thresholds = [candidates] if isscalar(candidates) else candidates\n",
"    # start from +inf so the first evaluated threshold is always recorded\n",
"    best_cost = float('inf')\n",
"    for threshold in thresholds:\n",
"        # get all four subsets for this threshold\n",
"        data_subsets = get_subset(threshold, feature)\n",
"        # cost of the split plus the Gini impurity of both children\n",
"        cost, gini_imp_1, gini_imp_2 = calculate_cost(data_subsets)\n",
"        # the strict '<' keeps the earliest threshold when several tie\n",
"        if cost < best_cost:\n",
"            results[feature] = (threshold, cost, gini_imp_1, gini_imp_2)\n",
"            best_cost = cost\n",
"\n",
"# print the header once, followed by one line per feature\n",
"print('The best results for each feature are:')\n",
"for feature, vals in results.items():\n",
"    print('\\'{}\\' -> Cost: {:.03f}, Threshold: {:.02f}, Gini_Left: {:.03f}, Gini_Right: {:.03f}'.format(feature, vals[1], vals[0], vals[2], vals[3]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"gist_id": "45b94505b330378d12765e65b2814f6b",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.