@mathcass
Created May 1, 2018 17:36
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Digging into Scikit-Learn's DecisionTreeClassifier\n",
"\n",
"I was really curious about how Scikit-Learn's [DecisionTreeClassifier](http://scikit-learn.org/stable/modules/tree.html#tree) works.\n",
"\n",
"Scikit-Learn provides a really useful [base function](https://github.com/scikit-learn/scikit-learn/blob/4d9a12d175a38f2bcb720389ad2213f71a3d7697/sklearn/tree/export.py#L63) for visualizing a fitted tree using `graphviz`. But I wanted to know *how* the tree works. Below is some investigation into what's going on under the hood."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To keep things simple, it made sense to look at how the classifier works on the Iris dataset. First, we'll train a classifier on it, then look at the exported graph."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn import tree\n",
"\n",
"# Load the Iris data and fit a decision tree with default hyperparameters\n",
"iris = load_iris()\n",
"clf = tree.DecisionTreeClassifier()\n",
"clf = clf.fit(iris.data, iris.target)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['sepal length (cm)',\n",
" 'sepal width (cm)',\n",
" 'petal length (cm)',\n",
" 'petal width (cm)']"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"iris.feature_names"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n",
" -->\n",
"<!-- Title: Tree Pages: 1 -->\n",
"<svg width=\"883pt\" height=\"642pt\"\n",
" viewBox=\"0.00 0.00 882.83 642.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 638)\">\n",
"<title>Tree</title>\n",
"<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-638 878.8306,-638 878.8306,4 -4,4\"/>\n",
"<!-- 0 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>0</title>\n",
"<path fill=\"transparent\" stroke=\"#000000\" d=\"M522.6091,-634C522.6091,-634 397.1458,-634 397.1458,-634 391.1458,-634 385.1458,-628 385.1458,-622 385.1458,-622 385.1458,-568 385.1458,-568 385.1458,-562 391.1458,-556 397.1458,-556 397.1458,-556 522.6091,-556 522.6091,-556 528.6091,-556 534.6091,-562 534.6091,-568 534.6091,-568 534.6091,-622 534.6091,-622 534.6091,-628 528.6091,-634 522.6091,-634\"/>\n",
"<text text-anchor=\"start\" x=\"393.0117\" y=\"-618.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">petal width (cm) ≤ 0.8</text>\n",
"<text text-anchor=\"start\" x=\"423.4863\" y=\"-604.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.667</text>\n",
"<text text-anchor=\"start\" x=\"414.1553\" y=\"-590.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 150</text>\n",
"<text text-anchor=\"start\" x=\"400.1382\" y=\"-576.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [50, 50, 50]</text>\n",
"<text text-anchor=\"start\" x=\"415.3276\" y=\"-562.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = setosa</text>\n",
"</g>\n",
"<!-- 1 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>1</title>\n",
"<path fill=\"#e58139\" stroke=\"#000000\" d=\"M429.7837,-513C429.7837,-513 333.9712,-513 333.9712,-513 327.9712,-513 321.9712,-507 321.9712,-501 321.9712,-501 321.9712,-461 321.9712,-461 321.9712,-455 327.9712,-449 333.9712,-449 333.9712,-449 429.7837,-449 429.7837,-449 435.7837,-449 441.7837,-455 441.7837,-461 441.7837,-461 441.7837,-501 441.7837,-501 441.7837,-507 435.7837,-513 429.7837,-513\"/>\n",
"<text text-anchor=\"start\" x=\"353.2725\" y=\"-497.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"340.0483\" y=\"-483.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 50</text>\n",
"<text text-anchor=\"start\" x=\"329.9243\" y=\"-469.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [50, 0, 0]</text>\n",
"<text text-anchor=\"start\" x=\"337.3276\" y=\"-455.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = setosa</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;1 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>0&#45;&gt;1</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M433.0343,-555.7677C425.4491,-544.6817 417.1823,-532.5994 409.5494,-521.4436\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"412.3933,-519.4019 403.8578,-513.1252 406.6161,-523.3547 412.3933,-519.4019\"/>\n",
"<text text-anchor=\"middle\" x=\"399.2506\" y=\"-533.497\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n",
"</g>\n",
"<!-- 2 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>2</title>\n",
"<path fill=\"transparent\" stroke=\"#000000\" d=\"M605.3958,-520C605.3958,-520 472.3591,-520 472.3591,-520 466.3591,-520 460.3591,-514 460.3591,-508 460.3591,-508 460.3591,-454 460.3591,-454 460.3591,-448 466.3591,-442 472.3591,-442 472.3591,-442 605.3958,-442 605.3958,-442 611.3958,-442 617.3958,-448 617.3958,-454 617.3958,-454 617.3958,-508 617.3958,-508 617.3958,-514 611.3958,-520 605.3958,-520\"/>\n",
"<text text-anchor=\"start\" x=\"468.1187\" y=\"-504.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">petal width (cm) ≤ 1.75</text>\n",
"<text text-anchor=\"start\" x=\"510.2725\" y=\"-490.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.5</text>\n",
"<text text-anchor=\"start\" x=\"493.1553\" y=\"-476.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 100</text>\n",
"<text text-anchor=\"start\" x=\"483.0313\" y=\"-462.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 50, 50]</text>\n",
"<text text-anchor=\"start\" x=\"485\" y=\"-448.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = versicolor</text>\n",
"</g>\n",
"<!-- 0&#45;&gt;2 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>0&#45;&gt;2</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M487.0648,-555.7677C493.0853,-547.0798 499.5298,-537.7801 505.7571,-528.794\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"508.8209,-530.5176 511.64,-520.3046 503.0673,-526.5304 508.8209,-530.5176\"/>\n",
"<text text-anchor=\"middle\" x=\"516.1009\" y=\"-540.7034\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">False</text>\n",
"</g>\n",
"<!-- 3 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>3</title>\n",
"<path fill=\"#39e581\" fill-opacity=\"0.898039\" stroke=\"#000000\" d=\"M493.3569,-406C493.3569,-406 354.3979,-406 354.3979,-406 348.3979,-406 342.3979,-400 342.3979,-394 342.3979,-394 342.3979,-340 342.3979,-340 342.3979,-334 348.3979,-328 354.3979,-328 354.3979,-328 493.3569,-328 493.3569,-328 499.3569,-328 505.3569,-334 505.3569,-340 505.3569,-340 505.3569,-394 505.3569,-394 505.3569,-400 499.3569,-406 493.3569,-406\"/>\n",
"<text text-anchor=\"start\" x=\"350.3877\" y=\"-390.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">petal length (cm) ≤ 4.95</text>\n",
"<text text-anchor=\"start\" x=\"387.4863\" y=\"-376.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.168</text>\n",
"<text text-anchor=\"start\" x=\"382.0483\" y=\"-362.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 54</text>\n",
"<text text-anchor=\"start\" x=\"371.9243\" y=\"-348.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 49, 5]</text>\n",
"<text text-anchor=\"start\" x=\"370\" y=\"-334.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = versicolor</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;3 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>2&#45;&gt;3</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M499.301,-441.7677C490.0852,-432.632 480.1868,-422.8198 470.6915,-413.407\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"473.0928,-410.8591 463.5269,-406.3046 468.1647,-415.8305 473.0928,-410.8591\"/>\n",
"</g>\n",
"<!-- 12 -->\n",
"<g id=\"node13\" class=\"node\">\n",
"<title>12</title>\n",
"<path fill=\"#8139e5\" fill-opacity=\"0.976471\" stroke=\"#000000\" d=\"M723.3569,-406C723.3569,-406 584.3979,-406 584.3979,-406 578.3979,-406 572.3979,-400 572.3979,-394 572.3979,-394 572.3979,-340 572.3979,-340 572.3979,-334 578.3979,-328 584.3979,-328 584.3979,-328 723.3569,-328 723.3569,-328 729.3569,-328 735.3569,-334 735.3569,-340 735.3569,-340 735.3569,-394 735.3569,-394 735.3569,-400 729.3569,-406 723.3569,-406\"/>\n",
"<text text-anchor=\"start\" x=\"580.3877\" y=\"-390.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">petal length (cm) ≤ 4.85</text>\n",
"<text text-anchor=\"start\" x=\"617.4863\" y=\"-376.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.043</text>\n",
"<text text-anchor=\"start\" x=\"612.0483\" y=\"-362.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 46</text>\n",
"<text text-anchor=\"start\" x=\"601.9243\" y=\"-348.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 1, 45]</text>\n",
"<text text-anchor=\"start\" x=\"604.2759\" y=\"-334.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = virginica</text>\n",
"</g>\n",
"<!-- 2&#45;&gt;12 -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>2&#45;&gt;12</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M578.4539,-441.7677C587.6697,-432.632 597.568,-422.8198 607.0634,-413.407\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"609.5902,-415.8305 614.228,-406.3046 604.6621,-410.8591 609.5902,-415.8305\"/>\n",
"</g>\n",
"<!-- 4 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>4</title>\n",
"<path fill=\"#39e581\" fill-opacity=\"0.980392\" stroke=\"#000000\" d=\"M265.3958,-292C265.3958,-292 132.3591,-292 132.3591,-292 126.3591,-292 120.3591,-286 120.3591,-280 120.3591,-280 120.3591,-226 120.3591,-226 120.3591,-220 126.3591,-214 132.3591,-214 132.3591,-214 265.3958,-214 265.3958,-214 271.3958,-214 277.3958,-220 277.3958,-226 277.3958,-226 277.3958,-280 277.3958,-280 277.3958,-286 271.3958,-292 265.3958,-292\"/>\n",
"<text text-anchor=\"start\" x=\"128.1187\" y=\"-276.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">petal width (cm) ≤ 1.65</text>\n",
"<text text-anchor=\"start\" x=\"162.4863\" y=\"-262.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.041</text>\n",
"<text text-anchor=\"start\" x=\"157.0483\" y=\"-248.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 48</text>\n",
"<text text-anchor=\"start\" x=\"146.9243\" y=\"-234.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 47, 1]</text>\n",
"<text text-anchor=\"start\" x=\"145\" y=\"-220.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = versicolor</text>\n",
"</g>\n",
"<!-- 3&#45;&gt;4 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>3&#45;&gt;4</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M346.76,-327.9272C326.8134,-317.8209 305.2147,-306.8775 284.8464,-296.5576\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"286.3964,-293.4194 275.8941,-292.0218 283.2326,-299.6636 286.3964,-293.4194\"/>\n",
"</g>\n",
"<!-- 7 -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>7</title>\n",
"<path fill=\"#8139e5\" fill-opacity=\"0.498039\" stroke=\"#000000\" d=\"M490.3958,-292C490.3958,-292 357.3591,-292 357.3591,-292 351.3591,-292 345.3591,-286 345.3591,-280 345.3591,-280 345.3591,-226 345.3591,-226 345.3591,-220 351.3591,-214 357.3591,-214 357.3591,-214 490.3958,-214 490.3958,-214 496.3958,-214 502.3958,-220 502.3958,-226 502.3958,-226 502.3958,-280 502.3958,-280 502.3958,-286 496.3958,-292 490.3958,-292\"/>\n",
"<text text-anchor=\"start\" x=\"353.1187\" y=\"-276.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">petal width (cm) ≤ 1.55</text>\n",
"<text text-anchor=\"start\" x=\"387.4863\" y=\"-262.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.444</text>\n",
"<text text-anchor=\"start\" x=\"385.9414\" y=\"-248.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 6</text>\n",
"<text text-anchor=\"start\" x=\"375.8174\" y=\"-234.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 2, 4]</text>\n",
"<text text-anchor=\"start\" x=\"374.2759\" y=\"-220.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = virginica</text>\n",
"</g>\n",
"<!-- 3&#45;&gt;7 -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>3&#45;&gt;7</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M423.8774,-327.7677C423.8774,-319.6172 423.8774,-310.9283 423.8774,-302.4649\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"427.3775,-302.3046 423.8774,-292.3046 420.3775,-302.3047 427.3775,-302.3046\"/>\n",
"</g>\n",
"<!-- 5 -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>5</title>\n",
"<path fill=\"#39e581\" stroke=\"#000000\" d=\"M111.6326,-171C111.6326,-171 12.1223,-171 12.1223,-171 6.1223,-171 .1223,-165 .1223,-159 .1223,-159 .1223,-119 .1223,-119 .1223,-113 6.1223,-107 12.1223,-107 12.1223,-107 111.6326,-107 111.6326,-107 117.6326,-107 123.6326,-113 123.6326,-119 123.6326,-119 123.6326,-159 123.6326,-159 123.6326,-165 117.6326,-171 111.6326,-171\"/>\n",
"<text text-anchor=\"start\" x=\"33.2725\" y=\"-155.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"20.0483\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 47</text>\n",
"<text text-anchor=\"start\" x=\"9.9243\" y=\"-127.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 47, 0]</text>\n",
"<text text-anchor=\"start\" x=\"8\" y=\"-113.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = versicolor</text>\n",
"</g>\n",
"<!-- 4&#45;&gt;5 -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>4&#45;&gt;5</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M151.7298,-213.7677C137.7541,-202.1383 122.4609,-189.4125 108.5175,-177.81\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"110.4096,-174.8312 100.484,-171.1252 105.9321,-180.2119 110.4096,-174.8312\"/>\n",
"</g>\n",
"<!-- 6 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>6</title>\n",
"<path fill=\"#8139e5\" stroke=\"#000000\" d=\"M244.5807,-171C244.5807,-171 153.1741,-171 153.1741,-171 147.1741,-171 141.1741,-165 141.1741,-159 141.1741,-159 141.1741,-119 141.1741,-119 141.1741,-113 147.1741,-107 153.1741,-107 153.1741,-107 244.5807,-107 244.5807,-107 250.5807,-107 256.5807,-113 256.5807,-119 256.5807,-119 256.5807,-159 256.5807,-159 256.5807,-165 250.5807,-171 244.5807,-171\"/>\n",
"<text text-anchor=\"start\" x=\"170.2725\" y=\"-155.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"160.9414\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 1</text>\n",
"<text text-anchor=\"start\" x=\"150.8174\" y=\"-127.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 0, 1]</text>\n",
"<text text-anchor=\"start\" x=\"149.2759\" y=\"-113.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = virginica</text>\n",
"</g>\n",
"<!-- 4&#45;&gt;6 -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>4&#45;&gt;6</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M198.8774,-213.7677C198.8774,-203.3338 198.8774,-192.0174 198.8774,-181.4215\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"202.3775,-181.1252 198.8774,-171.1252 195.3775,-181.1252 202.3775,-181.1252\"/>\n",
"</g>\n",
"<!-- 8 -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>8</title>\n",
"<path fill=\"#8139e5\" stroke=\"#000000\" d=\"M377.5807,-171C377.5807,-171 286.1741,-171 286.1741,-171 280.1741,-171 274.1741,-165 274.1741,-159 274.1741,-159 274.1741,-119 274.1741,-119 274.1741,-113 280.1741,-107 286.1741,-107 286.1741,-107 377.5807,-107 377.5807,-107 383.5807,-107 389.5807,-113 389.5807,-119 389.5807,-119 389.5807,-159 389.5807,-159 389.5807,-165 383.5807,-171 377.5807,-171\"/>\n",
"<text text-anchor=\"start\" x=\"303.2725\" y=\"-155.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"293.9414\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 3</text>\n",
"<text text-anchor=\"start\" x=\"283.8174\" y=\"-127.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 0, 3]</text>\n",
"<text text-anchor=\"start\" x=\"282.2759\" y=\"-113.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = virginica</text>\n",
"</g>\n",
"<!-- 7&#45;&gt;8 -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>7&#45;&gt;8</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M392.2163,-213.7677C383.182,-202.573 373.3278,-190.3624 364.2516,-179.1158\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"366.807,-176.7091 357.803,-171.1252 361.3596,-181.1053 366.807,-176.7091\"/>\n",
"</g>\n",
"<!-- 9 -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>9</title>\n",
"<path fill=\"#39e581\" fill-opacity=\"0.498039\" stroke=\"#000000\" d=\"M561.9673,-178C561.9673,-178 419.7876,-178 419.7876,-178 413.7876,-178 407.7876,-172 407.7876,-166 407.7876,-166 407.7876,-112 407.7876,-112 407.7876,-106 413.7876,-100 419.7876,-100 419.7876,-100 561.9673,-100 561.9673,-100 567.9673,-100 573.9673,-106 573.9673,-112 573.9673,-112 573.9673,-166 573.9673,-166 573.9673,-172 567.9673,-178 561.9673,-178\"/>\n",
"<text text-anchor=\"start\" x=\"415.8325\" y=\"-162.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">sepal length (cm) ≤ 6.95</text>\n",
"<text text-anchor=\"start\" x=\"454.4863\" y=\"-148.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.444</text>\n",
"<text text-anchor=\"start\" x=\"452.9414\" y=\"-134.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 3</text>\n",
"<text text-anchor=\"start\" x=\"442.8174\" y=\"-120.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 2, 1]</text>\n",
"<text text-anchor=\"start\" x=\"437\" y=\"-106.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = versicolor</text>\n",
"</g>\n",
"<!-- 7&#45;&gt;9 -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>7&#45;&gt;9</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M446.935,-213.7677C451.9884,-205.1694 457.394,-195.9718 462.6246,-187.072\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"465.7279,-188.6994 467.7773,-178.3046 459.6929,-185.1525 465.7279,-188.6994\"/>\n",
"</g>\n",
"<!-- 10 -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>10</title>\n",
"<path fill=\"#39e581\" stroke=\"#000000\" d=\"M471.6326,-64C471.6326,-64 372.1223,-64 372.1223,-64 366.1223,-64 360.1223,-58 360.1223,-52 360.1223,-52 360.1223,-12 360.1223,-12 360.1223,-6 366.1223,0 372.1223,0 372.1223,0 471.6326,0 471.6326,0 477.6326,0 483.6326,-6 483.6326,-12 483.6326,-12 483.6326,-52 483.6326,-52 483.6326,-58 477.6326,-64 471.6326,-64\"/>\n",
"<text text-anchor=\"start\" x=\"393.2725\" y=\"-48.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"383.9414\" y=\"-34.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 2</text>\n",
"<text text-anchor=\"start\" x=\"373.8174\" y=\"-20.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 2, 0]</text>\n",
"<text text-anchor=\"start\" x=\"368\" y=\"-6.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = versicolor</text>\n",
"</g>\n",
"<!-- 9&#45;&gt;10 -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>9&#45;&gt;10</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M465.5762,-99.7647C459.9223,-90.9971 453.9128,-81.678 448.2179,-72.8469\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"451.0138,-70.7242 442.6528,-64.2169 445.1309,-74.5178 451.0138,-70.7242\"/>\n",
"</g>\n",
"<!-- 11 -->\n",
"<g id=\"node12\" class=\"node\">\n",
"<title>11</title>\n",
"<path fill=\"#8139e5\" stroke=\"#000000\" d=\"M604.5807,-64C604.5807,-64 513.1741,-64 513.1741,-64 507.1741,-64 501.1741,-58 501.1741,-52 501.1741,-52 501.1741,-12 501.1741,-12 501.1741,-6 507.1741,0 513.1741,0 513.1741,0 604.5807,0 604.5807,0 610.5807,0 616.5807,-6 616.5807,-12 616.5807,-12 616.5807,-52 616.5807,-52 616.5807,-58 610.5807,-64 604.5807,-64\"/>\n",
"<text text-anchor=\"start\" x=\"530.2725\" y=\"-48.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"520.9414\" y=\"-34.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 1</text>\n",
"<text text-anchor=\"start\" x=\"510.8174\" y=\"-20.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 0, 1]</text>\n",
"<text text-anchor=\"start\" x=\"509.2759\" y=\"-6.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = virginica</text>\n",
"</g>\n",
"<!-- 9&#45;&gt;11 -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>9&#45;&gt;11</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M515.812,-99.7647C521.384,-90.9971 527.3064,-81.678 532.9187,-72.8469\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"535.9934,-74.534 538.4032,-64.2169 530.0855,-70.7795 535.9934,-74.534\"/>\n",
"</g>\n",
"<!-- 13 -->\n",
"<g id=\"node14\" class=\"node\">\n",
"<title>13</title>\n",
"<path fill=\"#8139e5\" fill-opacity=\"0.498039\" stroke=\"#000000\" d=\"M724.9673,-292C724.9673,-292 582.7876,-292 582.7876,-292 576.7876,-292 570.7876,-286 570.7876,-280 570.7876,-280 570.7876,-226 570.7876,-226 570.7876,-220 576.7876,-214 582.7876,-214 582.7876,-214 724.9673,-214 724.9673,-214 730.9673,-214 736.9673,-220 736.9673,-226 736.9673,-226 736.9673,-280 736.9673,-280 736.9673,-286 730.9673,-292 724.9673,-292\"/>\n",
"<text text-anchor=\"start\" x=\"578.8325\" y=\"-276.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">sepal length (cm) ≤ 5.95</text>\n",
"<text text-anchor=\"start\" x=\"617.4863\" y=\"-262.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.444</text>\n",
"<text text-anchor=\"start\" x=\"615.9414\" y=\"-248.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 3</text>\n",
"<text text-anchor=\"start\" x=\"605.8174\" y=\"-234.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 1, 2]</text>\n",
"<text text-anchor=\"start\" x=\"604.2759\" y=\"-220.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = virginica</text>\n",
"</g>\n",
"<!-- 12&#45;&gt;13 -->\n",
"<g id=\"edge13\" class=\"edge\">\n",
"<title>12&#45;&gt;13</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M653.8774,-327.7677C653.8774,-319.6172 653.8774,-310.9283 653.8774,-302.4649\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"657.3775,-302.3046 653.8774,-292.3046 650.3775,-302.3047 657.3775,-302.3046\"/>\n",
"</g>\n",
"<!-- 16 -->\n",
"<g id=\"node17\" class=\"node\">\n",
"<title>16</title>\n",
"<path fill=\"#8139e5\" stroke=\"#000000\" d=\"M862.7837,-285C862.7837,-285 766.9712,-285 766.9712,-285 760.9712,-285 754.9712,-279 754.9712,-273 754.9712,-273 754.9712,-233 754.9712,-233 754.9712,-227 760.9712,-221 766.9712,-221 766.9712,-221 862.7837,-221 862.7837,-221 868.7837,-221 874.7837,-227 874.7837,-233 874.7837,-233 874.7837,-273 874.7837,-273 874.7837,-279 868.7837,-285 862.7837,-285\"/>\n",
"<text text-anchor=\"start\" x=\"786.2725\" y=\"-269.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"773.0483\" y=\"-255.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 43</text>\n",
"<text text-anchor=\"start\" x=\"762.9243\" y=\"-241.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 0, 43]</text>\n",
"<text text-anchor=\"start\" x=\"765.2759\" y=\"-227.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = virginica</text>\n",
"</g>\n",
"<!-- 12&#45;&gt;16 -->\n",
"<g id=\"edge16\" class=\"edge\">\n",
"<title>12&#45;&gt;16</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M709.2845,-327.7677C726.0155,-315.9209 744.3533,-302.9364 760.984,-291.1606\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"763.369,-293.7604 769.5076,-285.1252 759.3238,-288.0476 763.369,-293.7604\"/>\n",
"</g>\n",
"<!-- 14 -->\n",
"<g id=\"node15\" class=\"node\">\n",
"<title>14</title>\n",
"<path fill=\"#39e581\" stroke=\"#000000\" d=\"M703.6326,-171C703.6326,-171 604.1223,-171 604.1223,-171 598.1223,-171 592.1223,-165 592.1223,-159 592.1223,-159 592.1223,-119 592.1223,-119 592.1223,-113 598.1223,-107 604.1223,-107 604.1223,-107 703.6326,-107 703.6326,-107 709.6326,-107 715.6326,-113 715.6326,-119 715.6326,-119 715.6326,-159 715.6326,-159 715.6326,-165 709.6326,-171 703.6326,-171\"/>\n",
"<text text-anchor=\"start\" x=\"625.2725\" y=\"-155.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"615.9414\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 1</text>\n",
"<text text-anchor=\"start\" x=\"605.8174\" y=\"-127.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 1, 0]</text>\n",
"<text text-anchor=\"start\" x=\"600\" y=\"-113.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = versicolor</text>\n",
"</g>\n",
"<!-- 13&#45;&gt;14 -->\n",
"<g id=\"edge14\" class=\"edge\">\n",
"<title>13&#45;&gt;14</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M653.8774,-213.7677C653.8774,-203.3338 653.8774,-192.0174 653.8774,-181.4215\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"657.3775,-181.1252 653.8774,-171.1252 650.3775,-181.1252 657.3775,-181.1252\"/>\n",
"</g>\n",
"<!-- 15 -->\n",
"<g id=\"node16\" class=\"node\">\n",
"<title>15</title>\n",
"<path fill=\"#8139e5\" stroke=\"#000000\" d=\"M836.5807,-171C836.5807,-171 745.1741,-171 745.1741,-171 739.1741,-171 733.1741,-165 733.1741,-159 733.1741,-159 733.1741,-119 733.1741,-119 733.1741,-113 739.1741,-107 745.1741,-107 745.1741,-107 836.5807,-107 836.5807,-107 842.5807,-107 848.5807,-113 848.5807,-119 848.5807,-119 848.5807,-159 848.5807,-159 848.5807,-165 842.5807,-171 836.5807,-171\"/>\n",
"<text text-anchor=\"start\" x=\"762.2725\" y=\"-155.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n",
"<text text-anchor=\"start\" x=\"752.9414\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">samples = 2</text>\n",
"<text text-anchor=\"start\" x=\"742.8174\" y=\"-127.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 0, 2]</text>\n",
"<text text-anchor=\"start\" x=\"741.2759\" y=\"-113.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\" fill=\"#000000\">class = virginica</text>\n",
"</g>\n",
"<!-- 13&#45;&gt;15 -->\n",
"<g id=\"edge15\" class=\"edge\">\n",
"<title>13&#45;&gt;15</title>\n",
"<path fill=\"none\" stroke=\"#000000\" d=\"M701.0251,-213.7677C715.0008,-202.1383 730.294,-189.4125 744.2374,-177.81\"/>\n",
"<polygon fill=\"#000000\" stroke=\"#000000\" points=\"746.8227,-180.2119 752.2708,-171.1252 742.3453,-174.8312 746.8227,-180.2119\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.files.Source at 0x10d9bcfd0>"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
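],
"source": [
"import graphviz\n",
"\n",
"# Plain export: build the DOT source and render it to a file named \"iris\"\n",
"dot_data = tree.export_graphviz(clf, out_file=None)\n",
"graph = graphviz.Source(dot_data)\n",
"graph.render(\"iris\")\n",
"\n",
"# Styled export: label nodes with feature and class names, color by class\n",
"dot_data = tree.export_graphviz(\n",
"    clf, out_file=None,\n",
"    feature_names=iris.feature_names,\n",
"    class_names=iris.target_names,\n",
"    filled=True, rounded=True,\n",
"    special_characters=True,\n",
")\n",
"graph = graphviz.Source(dot_data)\n",
"graph  # displayed inline as the cell's result"
]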
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the exported graph does a good job of showing how the tree works: each node lists its split rule, Gini impurity, sample count, per-class sample counts, and majority class."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tree Structure Exploration\n",
"\n",
"This section pokes at the attributes of the fitted classifier and its underlying `tree_` object. Feel free to skip ahead."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__abstractmethods__',\n",
" '__class__',\n",
" '__delattr__',\n",
" '__dict__',\n",
" '__dir__',\n",
" '__doc__',\n",
" '__eq__',\n",
" '__format__',\n",
" '__ge__',\n",
" '__getattribute__',\n",
" '__getstate__',\n",
" '__gt__',\n",
" '__hash__',\n",
" '__init__',\n",
" '__init_subclass__',\n",
" '__le__',\n",
" '__lt__',\n",
" '__module__',\n",
" '__ne__',\n",
" '__new__',\n",
" '__reduce__',\n",
" '__reduce_ex__',\n",
" '__repr__',\n",
" '__setattr__',\n",
" '__setstate__',\n",
" '__sizeof__',\n",
" '__str__',\n",
" '__subclasshook__',\n",
" '__weakref__',\n",
" '_abc_cache',\n",
" '_abc_negative_cache',\n",
" '_abc_negative_cache_version',\n",
" '_abc_registry',\n",
" '_estimator_type',\n",
" '_get_param_names',\n",
" '_validate_X_predict',\n",
" 'apply',\n",
" 'class_weight',\n",
" 'classes_',\n",
" 'criterion',\n",
" 'decision_path',\n",
" 'feature_importances_',\n",
" 'fit',\n",
" 'get_params',\n",
" 'max_depth',\n",
" 'max_features',\n",
" 'max_features_',\n",
" 'max_leaf_nodes',\n",
" 'min_impurity_decrease',\n",
" 'min_impurity_split',\n",
" 'min_samples_leaf',\n",
" 'min_samples_split',\n",
" 'min_weight_fraction_leaf',\n",
" 'n_classes_',\n",
" 'n_features_',\n",
" 'n_outputs_',\n",
" 'predict',\n",
" 'predict_log_proba',\n",
" 'predict_proba',\n",
" 'presort',\n",
" 'random_state',\n",
" 'score',\n",
" 'set_params',\n",
" 'splitter',\n",
" 'tree_']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(clf)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__class__',\n",
" '__delattr__',\n",
" '__dir__',\n",
" '__doc__',\n",
" '__eq__',\n",
" '__format__',\n",
" '__ge__',\n",
" '__getattribute__',\n",
" '__getstate__',\n",
" '__gt__',\n",
" '__hash__',\n",
" '__init__',\n",
" '__init_subclass__',\n",
" '__le__',\n",
" '__lt__',\n",
" '__ne__',\n",
" '__new__',\n",
" '__pyx_vtable__',\n",
" '__reduce__',\n",
" '__reduce_ex__',\n",
" '__repr__',\n",
" '__setattr__',\n",
" '__setstate__',\n",
" '__sizeof__',\n",
" '__str__',\n",
" '__subclasshook__',\n",
" 'apply',\n",
" 'capacity',\n",
" 'children_left',\n",
" 'children_right',\n",
" 'compute_feature_importances',\n",
" 'decision_path',\n",
" 'feature',\n",
" 'impurity',\n",
" 'max_depth',\n",
" 'max_n_classes',\n",
" 'n_classes',\n",
" 'n_features',\n",
" 'n_node_samples',\n",
" 'n_outputs',\n",
" 'node_count',\n",
" 'predict',\n",
" 'threshold',\n",
" 'value',\n",
" 'weighted_n_node_samples']"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(clf.tree_)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"17"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(clf.tree_.feature)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 3, -2, 3, 2, 3, -2, -2, 3, -2, 0, -2, -2, 2, 0, -2, -2, -2])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.tree_.feature"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"17"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(clf.tree_.n_node_samples)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([150, 50, 100, 54, 48, 47, 1, 6, 3, 3, 2, 1, 46,\n",
" 3, 1, 2, 43])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.tree_.n_node_samples"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"17"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(clf.tree_.children_left)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 1, -1, 3, 4, 5, -1, -1, 8, -1, 10, -1, -1, 13, 14, -1, -1, -1])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.tree_.children_left"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"17"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(clf.tree_.children_right)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 2, -1, 12, 7, 6, -1, -1, 9, -1, 11, -1, -1, 16, 15, -1, -1, -1])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.tree_.children_right"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.80000001, -2. , 1.75 , 4.94999981, 1.6500001 ,\n",
" -2. , -2. , 1.54999995, -2. , 6.94999981,\n",
" -2. , -2. , 4.85000038, 5.94999981, -2. ,\n",
" -2. , -2. ])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.tree_.threshold"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[50., 50., 50.]],\n",
"\n",
" [[50., 0., 0.]],\n",
"\n",
" [[ 0., 50., 50.]],\n",
"\n",
" [[ 0., 49., 5.]],\n",
"\n",
" [[ 0., 47., 1.]],\n",
"\n",
" [[ 0., 47., 0.]],\n",
"\n",
" [[ 0., 0., 1.]],\n",
"\n",
" [[ 0., 2., 4.]],\n",
"\n",
" [[ 0., 0., 3.]],\n",
"\n",
" [[ 0., 2., 1.]],\n",
"\n",
" [[ 0., 2., 0.]],\n",
"\n",
" [[ 0., 0., 1.]],\n",
"\n",
" [[ 0., 1., 45.]],\n",
"\n",
" [[ 0., 1., 2.]],\n",
"\n",
" [[ 0., 1., 0.]],\n",
"\n",
" [[ 0., 0., 2.]],\n",
"\n",
" [[ 0., 0., 43.]]])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.tree_.value"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Explanation of Tree Structure\n",
"\n",
"Below, we collect the tree's per-node arrays into a single table so that each row describes one node."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>feature</th>\n",
" <th>left</th>\n",
" <th>mapped_feature</th>\n",
" <th>n_node_samples</th>\n",
" <th>right</th>\n",
" <th>thresholds</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" <td>petal length (cm)</td>\n",
" <td>150</td>\n",
" <td>2</td>\n",
" <td>2.45</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>50</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>petal width (cm)</td>\n",
" <td>100</td>\n",
" <td>12</td>\n",
" <td>1.75</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>petal length (cm)</td>\n",
" <td>54</td>\n",
" <td>7</td>\n",
" <td>4.95</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3</td>\n",
" <td>5</td>\n",
" <td>petal width (cm)</td>\n",
" <td>48</td>\n",
" <td>6</td>\n",
" <td>1.65</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>47</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>3</td>\n",
" <td>8</td>\n",
" <td>petal width (cm)</td>\n",
" <td>6</td>\n",
" <td>9</td>\n",
" <td>1.55</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>3</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>2</td>\n",
" <td>10</td>\n",
" <td>petal length (cm)</td>\n",
" <td>3</td>\n",
" <td>11</td>\n",
" <td>5.45</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>2</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>2</td>\n",
" <td>13</td>\n",
" <td>petal length (cm)</td>\n",
" <td>46</td>\n",
" <td>16</td>\n",
" <td>4.85</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>0</td>\n",
" <td>14</td>\n",
" <td>sepal length (cm)</td>\n",
" <td>3</td>\n",
" <td>15</td>\n",
" <td>5.95</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>1</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>2</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>-2</td>\n",
" <td>-1</td>\n",
" <td>None</td>\n",
" <td>43</td>\n",
" <td>-1</td>\n",
" <td>-2.00</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" feature left mapped_feature n_node_samples right thresholds\n",
"0 2 1 petal length (cm) 150 2 2.45\n",
"1 -2 -1 None 50 -1 -2.00\n",
"2 3 3 petal width (cm) 100 12 1.75\n",
"3 2 4 petal length (cm) 54 7 4.95\n",
"4 3 5 petal width (cm) 48 6 1.65\n",
"5 -2 -1 None 47 -1 -2.00\n",
"6 -2 -1 None 1 -1 -2.00\n",
"7 3 8 petal width (cm) 6 9 1.55\n",
"8 -2 -1 None 3 -1 -2.00\n",
"9 2 10 petal length (cm) 3 11 5.45\n",
"10 -2 -1 None 2 -1 -2.00\n",
"11 -2 -1 None 1 -1 -2.00\n",
"12 2 13 petal length (cm) 46 16 4.85\n",
"13 0 14 sepal length (cm) 3 15 5.95\n",
"14 -2 -1 None 1 -1 -2.00\n",
"15 -2 -1 None 2 -1 -2.00\n",
"16 -2 -1 None 43 -1 -2.00"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import load_iris\n",
"from sklearn import tree\n",
"iris = load_iris()\n",
"clf = tree.DecisionTreeClassifier()\n",
"clf = clf.fit(iris.data, iris.target)\n",
"\n",
"pd.DataFrame(\n",
" data = {\n",
"# 'values': clf.tree_.value,\n",
" 'right': clf.tree_.children_right,\n",
" 'left': clf.tree_.children_left,\n",
" 'thresholds': clf.tree_.threshold,\n",
" 'n_node_samples': clf.tree_.n_node_samples,\n",
" 'feature': clf.tree_.feature,\n",
" 'mapped_feature': [iris.feature_names[i] if i >= 0 else None for i in clf.tree_.feature]\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each of the tree's array attributes has the same length, with one entry per node, so position *i* in every array describes node *i*. For each node, `children_left` and `children_right` hold the *index* of the corresponding child node.\n",
"\n",
"Above, the root node is index *0*. Its left child is index 1 and its right child is index 2. Looking at node 1, we see that it's a terminal node (a leaf) because both of its children have value *-1*.\n",
"\n",
"To determine the feature used at a particular node, look up the value of `feature` in the feature mapping (this just corresponds to the column order of the features in the training data).\n",
"\n",
"The threshold used is provided as well: samples whose feature value is less than or equal to the threshold go to the left child, and the rest go to the right. If both feature and threshold are *-2*, then the node is a leaf and no split is performed.\n",
"\n",
"Note that the classifier was refit in the cell above without a fixed `random_state`, which is probably why a few splits in this table (the root, for example) differ from the arrays printed in the exploration section."
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[50. 50. 50.]]\n",
"petal length (cm) <= 2.450000047683716\n",
"\t[[50. 0. 0.]]\n",
"\t[[ 0. 50. 50.]]\n",
"\tpetal width (cm) <= 1.75\n",
"\t\t[[ 0. 49. 5.]]\n",
"\t\tpetal length (cm) <= 4.949999809265137\n",
"\t\t\t[[ 0. 47. 1.]]\n",
"\t\t\tpetal width (cm) <= 1.6500000953674316\n",
"\t\t\t\t[[ 0. 47. 0.]]\n",
"\t\t\t\t[[0. 0. 1.]]\n",
"\t\t\t[[0. 2. 4.]]\n",
"\t\t\tpetal width (cm) <= 1.5499999523162842\n",
"\t\t\t\t[[0. 0. 3.]]\n",
"\t\t\t\t[[0. 2. 1.]]\n",
"\t\t\t\tpetal length (cm) <= 5.449999809265137\n",
"\t\t\t\t\t[[0. 2. 0.]]\n",
"\t\t\t\t\t[[0. 0. 1.]]\n",
"\t\t[[ 0. 1. 45.]]\n",
"\t\tpetal length (cm) <= 4.850000381469727\n",
"\t\t\t[[0. 1. 2.]]\n",
"\t\t\tsepal length (cm) <= 5.949999809265137\n",
"\t\t\t\t[[0. 1. 0.]]\n",
"\t\t\t\t[[0. 0. 2.]]\n",
"\t\t\t[[ 0. 0. 43.]]\n"
]
}
],
"source": [
"\n",
"\n",
"def recurse(t, node_id, depth=0, feature_names=None):\n",
"    \"\"\"Recursively print each node's value and, for split nodes, its split rule.\"\"\"\n",
"    left_id, right_id = t.children_left[node_id], t.children_right[node_id]\n",
"\n",
"    # Fall back to integer feature indices when no names are supplied\n",
"    if feature_names is None:\n",
"        feature_names = range(t.n_features)\n",
"\n",
"    tabs = '\\t' * depth\n",
"    print(f'{tabs}{t.value[node_id]}')\n",
"    # A child index of -1 marks a leaf, so only split nodes recurse\n",
"    if left_id != -1:\n",
"        print(f'{tabs}{feature_names[t.feature[node_id]]} <= {t.threshold[node_id]}')\n",
"\n",
"        recurse(t, left_id, depth + 1, feature_names)\n",
"        recurse(t, right_id, depth + 1, feature_names)\n",
"\n",
"recurse(clf.tree_, 0, feature_names=iris.feature_names)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
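
As a follow-up to the notebook's explanation of the node arrays, here is a minimal sketch of how a single sample could be routed through the tree by hand, using only `children_left`, `children_right`, `feature`, `threshold`, and `value`. The helper name `trace_sample` is hypothetical (it is not part of scikit-learn), and `random_state=0` is an assumption added only so that repeated runs build the same tree.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# Assumption of this sketch: fix random_state so the fitted tree is reproducible.
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)


def trace_sample(t, x):
    """Hypothetical helper: walk one sample from the root to a leaf.

    Returns the list of visited node indices and the index of the
    majority class stored at the leaf.
    """
    # The tree casts its inputs to float32 internally, so do the same here.
    x = np.asarray(x, dtype=np.float32)
    node_id = 0  # the root is always node 0
    path = [node_id]
    # A node is a leaf when its left child index is -1.
    while t.children_left[node_id] != -1:
        if x[t.feature[node_id]] <= t.threshold[node_id]:
            node_id = t.children_left[node_id]
        else:
            node_id = t.children_right[node_id]
        path.append(int(node_id))
    # value[node_id] has shape (n_outputs, n_classes); this tree has one output.
    return path, int(t.value[node_id][0].argmax())


path, class_index = trace_sample(clf.tree_, iris.data[0])
print(path, iris.target_names[class_index])
```

For this single-output classifier the class chosen at the leaf should agree with `clf.predict`, which is a handy sanity check that the arrays are being read the way the explanation describes.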