@cstrelioff
Last active June 4, 2018 16:31
decision trees: scikit-learn + pandas


This script provides an example of learning a decision tree with scikit-learn. Pandas is used to read data and custom functions are employed to investigate the decision tree after it is learned. Grab the code and try it out.

A blog post about this code is available here; check it out!

Requirements

  • Python -- developed with 2.7.6
  • scikit-learn -- using version 0.16.1
  • pandas -- using version 0.16.1
  • numpy -- using version 1.9.2

To create the graphic of the tree, you must also have graphviz (which provides the dot command) installed.
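Before running the script, it can help to confirm that graphviz's dot executable is actually on your PATH. The sketch below is a hypothetical helper (dot_available is not part of analyze_dt.py) and assumes Python 3.3 or newer for shutil.which.

```python
import shutil


def dot_available():
    """Return True if the graphviz `dot` executable can be found on PATH.

    Hypothetical helper, not part of analyze_dt.py; needs Python 3.3+.
    """
    return shutil.which("dot") is not None


if __name__ == "__main__":
    if dot_available():
        print("graphviz found: visualize_tree() should be able to write dt.png")
    else:
        print("graphviz not found: install it before calling visualize_tree()")
```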

Usage

1. Run script from command line

This runs an example that uses the available functions; see the if __name__ == '__main__': block at the end of analyze_dt.py for exactly what is done.

$ python analyze_dt.py

This:

  • Loads the data from a local iris.csv if one exists; otherwise downloads it with pandas.
  • Prints the head of the pandas DataFrame.
  • Fits the decision tree and prints pseudo-code for it.
  • Uses pandas to verify the first split, PetalLength <= 2.45: every example on that side is Iris-setosa.
  • Writes dt.dot and renders it to dt.png with graphviz.
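The pandas check of the first split can also be reproduced without iris.csv. This sketch assumes scikit-learn 0.23 or newer, where load_iris(as_frame=True) returns the bundled iris data as a pandas DataFrame; note that the bundled copy uses column names like "petal length (cm)" and integer targets (0 = setosa) rather than the PetalLength/Name columns of iris.csv.

```python
from sklearn.datasets import load_iris

# Load scikit-learn's bundled copy of the iris data as a DataFrame
# (as_frame=True requires scikit-learn >= 0.23).
iris = load_iris(as_frame=True)
df = iris.frame  # four feature columns plus an integer "target" column

# Rows that fall on the left side of the first split...
left = df[df["petal length (cm)"] <= 2.45]

# ...should contain a single class, which maps back to "setosa".
classes_on_left = left["target"].unique()
print(classes_on_left, [iris.target_names[c] for c in classes_on_left])
print(len(left), "rows")
```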

The resulting output is:

-- get data:
-- iris.csv found locally

-- df.head():
   SepalLength  SepalWidth  PetalLength  PetalWidth         Name
0          5.1         3.5          1.4         0.2  Iris-setosa
1          4.9         3.0          1.4         0.2  Iris-setosa
2          4.7         3.2          1.3         0.2  Iris-setosa
3          4.6         3.1          1.5         0.2  Iris-setosa
4          5.0         3.6          1.4         0.2  Iris-setosa


-- get_code:
if ( PetalLength <= 2.45000004768 ) {
    return Iris-setosa ( 50 examples )
}
else {
    if ( PetalWidth <= 1.75 ) {
        if ( PetalLength <= 4.94999980927 ) {
            if ( PetalWidth <= 1.65000009537 ) {
                return Iris-versicolor ( 47 examples )
            }
            else {
                return Iris-virginica ( 1 examples )
            }
        }
        else {
            return Iris-versicolor ( 2 examples )
            return Iris-virginica ( 4 examples )
        }
    }
    else {
        if ( PetalLength <= 4.85000038147 ) {
            return Iris-versicolor ( 1 examples )
            return Iris-virginica ( 2 examples )
        }
        else {
            return Iris-virginica ( 43 examples )
        }
    }
}

-- look back at original data using pandas
-- df[df['PetalLength'] <= 2.45]['Name'].unique():  ['Iris-setosa']
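Some leaves in the pseudo-code above contain two return lines. get_code prints every class with a nonzero entry in tree_.value at a leaf, so a leaf holding 2 Iris-versicolor and 4 Iris-virginica examples prints both; the tree's actual prediction there is the majority class. A minimal sketch, assuming a recent scikit-learn and the bundled iris data instead of iris.csv: it finds each sample's leaf with apply() and takes the arg-max of that leaf's tree_.value row. The arg-max gives the same answer whether tree_.value holds raw counts or normalized fractions (newer scikit-learn releases may store the latter).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
dt = DecisionTreeClassifier(min_samples_split=20, random_state=99).fit(X, y)

# Leaf index reached by each training sample.
leaves = dt.apply(X)

# tree_.value has shape (n_nodes, n_outputs, n_classes); with a single
# output, take output 0 and pick the largest class entry at each leaf.
majority = dt.classes_[np.argmax(dt.tree_.value[leaves, 0, :], axis=1)]

# The majority class at each leaf is what predict() returns.
print((majority == dt.predict(X)).all())

# List each leaf with its majority class (mixed leaves included).
for leaf in np.unique(leaves):
    counts = dt.tree_.value[leaf, 0, :]
    print(leaf, dt.classes_[np.argmax(counts)], counts)
```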

2. Use interactively with (i)python

This code can also be used interactively by importing the available functions. I import analyze_dt as adt and then call functions as adt.function(). Follow along:

>>> import analyze_dt as adt
>>> df = adt.get_iris_data()
-- iris.csv found locally
>>> df.head()
   SepalLength  SepalWidth  PetalLength  PetalWidth         Name
0          5.1         3.5          1.4         0.2  Iris-setosa
1          4.9         3.0          1.4         0.2  Iris-setosa
2          4.7         3.2          1.3         0.2  Iris-setosa
3          4.6         3.1          1.5         0.2  Iris-setosa
4          5.0         3.6          1.4         0.2  Iris-setosa
>>> df.columns
Index([u'SepalLength', u'SepalWidth', u'PetalLength', u'PetalWidth', u'Name'], dtype='object')
>>> features = list(df.columns[:4])
>>> features
['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
>>> df, targets = adt.encode_target(df, "Name")
>>> y = df["Target"]
>>> X = df[features]
>>> dt = adt.DecisionTreeClassifier(min_samples_split=20, random_state=99)
>>> dt.fit(X,y)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=20, min_weight_fraction_leaf=0.0,
            random_state=99, splitter='best')
>>> adt.get_code(dt, features, targets)
if ( PetalLength <= 2.45000004768 ) {
    return Iris-setosa ( 50 examples )
}
else {
    if ( PetalWidth <= 1.75 ) {
        if ( PetalLength <= 4.94999980927 ) {
            if ( PetalWidth <= 1.65000009537 ) {
                return Iris-versicolor ( 47 examples )
            }
            else {
                return Iris-virginica ( 1 examples )
            }
        }
        else {
            return Iris-versicolor ( 2 examples )
            return Iris-virginica ( 4 examples )
        }
    }
    else {
        if ( PetalLength <= 4.85000038147 ) {
            return Iris-versicolor ( 1 examples )
            return Iris-virginica ( 2 examples )
        }
        else {
            return Iris-virginica ( 43 examples )
        }
    }
}
>>> df[df['PetalLength'] <= 2.45]['Name'].unique()
array(['Iris-setosa'], dtype=object)
>>> adt.visualize_tree(dt, features)
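Newer scikit-learn releases (0.21 and later) include sklearn.tree.export_text, which prints a similar text view of a fitted tree without a custom helper like get_code. A minimal sketch, using the bundled iris data instead of iris.csv; the features list is written to match this gist's column names and is an assumption about how you want the rules labeled.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = load_iris(return_X_y=True)
# Same column order as iris.csv, labelled with this gist's names.
features = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]

dt = DecisionTreeClassifier(min_samples_split=20, random_state=99).fit(X, y)

# export_text returns the rules as a string, one line per node,
# with "|---" marking each level of depth.
report = export_text(dt, feature_names=features)
print(report)
```

Unlike get_code, export_text reports a single class per leaf rather than listing every class present.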

analyze_dt.py

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2015 Christopher C. Strelioff <chris.strelioff@gmail.com>
#
# Distributed under terms of the MIT license.
"""analyze_dt.py -- probe a decision tree found with scikit-learn."""
from __future__ import print_function
import os
import subprocess
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def get_code(tree, feature_names, target_names, spacer_base="    "):
    """Produce pseudo-code for decision tree.

    Args
    ----
    tree -- scikit-learn DecisionTreeClassifier.
    feature_names -- list of feature names.
    target_names -- list of target (class) names.
    spacer_base -- used for spacing code (default: "    ").

    Notes
    -----
    based on http://stackoverflow.com/a/30104792.
    """
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, depth):
        spacer = spacer_base * depth
        # scikit-learn marks leaf nodes with a threshold of -2
        if threshold[node] != -2:
            print(spacer + "if ( " + features[node] + " <= " +
                  str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left, right, threshold, features, left[node],
                        depth + 1)
            print(spacer + "}\n" + spacer + "else {")
            if right[node] != -1:
                recurse(left, right, threshold, features, right[node],
                        depth + 1)
            print(spacer + "}")
        else:
            # leaf: print every class with a nonzero entry at this node
            target = value[node]
            for i, v in zip(np.nonzero(target)[1], target[np.nonzero(target)]):
                target_name = target_names[i]
                target_count = int(v)
                print(spacer + "return " + str(target_name) + " ( " +
                      str(target_count) + " examples )")

    recurse(left, right, threshold, features, 0, 0)


def visualize_tree(tree, feature_names):
    """Create tree png using graphviz.

    Args
    ----
    tree -- scikit-learn DecisionTreeClassifier.
    feature_names -- list of feature names.
    """
    with open("dt.dot", 'w') as f:
        export_graphviz(tree, out_file=f,
                        feature_names=feature_names)

    command = ["dot", "-Tpng", "dt.dot", "-o", "dt.png"]
    try:
        subprocess.check_call(command)
    except (OSError, subprocess.CalledProcessError):
        # OSError: dot not found; CalledProcessError: dot failed
        raise SystemExit("Could not run dot, i.e. graphviz, to produce visualization")


def encode_target(df, target_column):
    """Add column to df with integers for the target.

    Args
    ----
    df -- pandas DataFrame.
    target_column -- column to map to int, producing new Target column.

    Returns
    -------
    df_mod -- modified DataFrame.
    targets -- array of target names.
    """
    df_mod = df.copy()
    targets = df_mod[target_column].unique()
    # first name found -> 0, next -> 1, and so on
    map_to_int = {name: n for n, name in enumerate(targets)}
    df_mod["Target"] = df_mod[target_column].map(map_to_int)

    return (df_mod, targets)


def get_iris_data():
    """Get the iris data, from local csv or pandas repo."""
    if os.path.exists("iris.csv"):
        print("-- iris.csv found locally")
        df = pd.read_csv("iris.csv")
    else:
        print("-- trying to download from github")
        fn = "https://raw.githubusercontent.com/pydata/pandas/master/pandas/tests/data/iris.csv"
        try:
            df = pd.read_csv(fn)
        except Exception:
            raise SystemExit("-- Unable to download iris.csv")

        with open("iris.csv", 'w') as f:
            print("-- writing to local iris.csv file")
            # index=False keeps the local copy in the same layout as the
            # downloaded file, so a manually downloaded iris.csv also works
            df.to_csv(f, index=False)

    return df


if __name__ == '__main__':
    print("\n-- get data:")
    df = get_iris_data()

    print("\n-- df.head():")
    print(df.head(), end="\n\n")

    features = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]
    df, targets = encode_target(df, "Name")
    y = df["Target"]
    X = df[features]

    dt = DecisionTreeClassifier(min_samples_split=20, random_state=99)
    dt.fit(X, y)

    print("\n-- get_code:")
    get_code(dt, features, targets)

    print("\n-- look back at original data using pandas")
    print("-- df[df['PetalLength'] <= 2.45]['Name'].unique(): ",
          df[df['PetalLength'] <= 2.45]['Name'].unique(), end="\n\n")

    visualize_tree(dt, features)
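visualize_tree writes an intermediate dt.dot file and then shells out to dot. As an alternative design, export_graphviz returns the DOT source as a string when out_file=None, which can be piped straight into dot on stdin. A sketch, assuming Python 3.5+ (for subprocess.run) and scikit-learn 0.18+; tree_to_png is a hypothetical name, not part of the script above, and the PNG is written to the system temp directory only if dot is installed.

```python
import os
import shutil
import subprocess
import tempfile

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def tree_to_png(tree, feature_names, png_path="dt.png"):
    """Render a fitted tree to PNG without writing a .dot file.

    Hypothetical helper. Returns the DOT source; the PNG is only written
    when the graphviz `dot` executable is found on PATH.
    """
    # With out_file=None, export_graphviz returns the DOT source as a string.
    dot_source = export_graphviz(tree, out_file=None,
                                 feature_names=feature_names)
    if shutil.which("dot") is not None:
        # Feed the DOT source to dot on stdin; check=True raises
        # CalledProcessError if dot exits with an error.
        subprocess.run(["dot", "-Tpng", "-o", png_path],
                       input=dot_source.encode("utf-8"), check=True)
    return dot_source


# Example usage with the bundled iris data.
X, y = load_iris(return_X_y=True)
features = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]
dt = DecisionTreeClassifier(min_samples_split=20, random_state=99).fit(X, y)
source = tree_to_png(dt, features,
                     png_path=os.path.join(tempfile.gettempdir(), "dt_sketch.png"))
print(source.splitlines()[0])
```

Skipping the temporary .dot file keeps the working directory clean; the trade-off is that you no longer have the .dot file to inspect or re-render later.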
@msusol

msusol commented Mar 7, 2016

First, what a great example you have provided 👍

After using wget with the link you provided to download the iris.csv directly first, I encountered an error using the code "as is":
KeyError: "['SepalLength'] not in index"

Removing "index_col=0" from read_csv (line 106) fixed the problem.
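The KeyError comes from a mismatch between how iris.csv was written and how it is read. DataFrame.to_csv writes the index as an unnamed first column by default, and index_col=0 tells read_csv to turn the first column back into the index. A file downloaded directly has no index column, so index_col=0 consumes the SepalLength column instead. A small sketch with an in-memory CSV (the two rows are copied from the df.head() output above) shows the round trip that works for either source: write with index=False and read without index_col.

```python
import io

import pandas as pd

# Two rows copied from the df.head() output shown earlier.
df = pd.DataFrame({
    "SepalLength": [5.1, 4.9],
    "SepalWidth": [3.5, 3.0],
    "PetalLength": [1.4, 1.4],
    "PetalWidth": [0.2, 0.2],
    "Name": ["Iris-setosa", "Iris-setosa"],
})

# A CSV with no index column, like a file downloaded with wget.
raw = df.to_csv(index=False)

# index_col=0 turns SepalLength into the index, so
# bad["SepalLength"] would raise KeyError.
bad = pd.read_csv(io.StringIO(raw), index_col=0)
print("SepalLength" in bad.columns)  # False

# Reading without index_col keeps every column.
good = pd.read_csv(io.StringIO(raw))
print(list(good.columns))
```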

@jlgranda

Please explain this to me:
print("-- df[df['PetalLength'] <= 2.45]]['Name'].unique(): ",
df[df['PetalLength'] <= 2.45]['Name'].unique(), end="\n\n")

@BrazilForever11

Excellent example! Minor comment: y = df["Target"] and X = df[features] now require reshaping in scikit-learn.

@geraldstanje

@BrazilForever11 how should it be reshaped?
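With the code as written, no reshaping should be needed: df[features] selects a list of columns and is already two-dimensional with shape (n_samples, n_features), and df["Target"] is the one-dimensional label vector scikit-learn expects for y. Reshaping matters when fitting on a single feature, because recent scikit-learn versions reject a 1-D X. A sketch using two rows from the df.head() output above (Target 0 is Iris-setosa under encode_target):

```python
import pandas as pd

# Two rows copied from the df.head() output shown earlier.
df = pd.DataFrame({
    "PetalLength": [1.4, 1.3],
    "PetalWidth": [0.2, 0.2],
    "Target": [0, 0],
})

X_all = df[["PetalLength", "PetalWidth"]]  # list of columns -> 2-D DataFrame
y = df["Target"]                            # single column -> 1-D Series
print(X_all.shape, y.shape)                 # (2, 2) (2,)

# A single feature selected by name is 1-D; scikit-learn wants 2-D X.
x_one = df["PetalLength"]
print(x_one.shape)                          # (2,)

# Either select with a one-element list or reshape the values.
X_one = df[["PetalLength"]]
X_one_np = x_one.to_numpy().reshape(-1, 1)
print(X_one.shape, X_one_np.shape)          # (2, 1) (2, 1)
```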

@BhanuJyothi

map_to_int = {name: n for n, name in enumerate(targets)}
df_mod["Target"] = df_mod[target_column].replace(map_to_int)

Can anyone explain what the above code does? I'm running the program but got the error below:

AttributeError: 'Series' object has no attribute 'replace'

How can I resolve this error?
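The two lines build a lookup table from class name to integer (the first name found gets 0, the next gets 1, and so on) and then use it to create the numeric Target column that scikit-learn trains on. An AttributeError on replace may point to an unusual or very old pandas install; Series.map, used below, is an alternative way to apply the same lookup. A sketch with a few hand-written labels (hypothetical data, in the order the classes appear in iris.csv); pd.factorize is a pandas built-in that produces the same codes in one call.

```python
import pandas as pd

# Hypothetical labels in the same order the classes appear in iris.csv.
names = pd.Series(["Iris-setosa", "Iris-setosa", "Iris-versicolor",
                   "Iris-virginica", "Iris-versicolor"])

# Same steps as encode_target: unique() keeps first-appearance order.
targets = names.unique()
map_to_int = {name: n for n, name in enumerate(targets)}
print(map_to_int)

# map() replaces each name with its integer code.
encoded = names.map(map_to_int)
print(encoded.tolist())  # [0, 0, 1, 2, 1]

# pd.factorize assigns codes in the same first-appearance order.
codes, uniques = pd.factorize(names)
print(codes.tolist(), list(uniques))
```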
