Skip to content

Instantly share code, notes, and snippets.

@cstrelioff
Last active October 19, 2020 17:20
Show Gist options
  • Save cstrelioff/4cfd65d224c89604dc2b to your computer and use it in GitHub Desktop.
Save cstrelioff/4cfd65d224c89604dc2b to your computer and use it in GitHub Desktop.

decision trees: cross-validation

This script provides an example of using cross-validation to fine-tune parameters for learning a decision tree with scikit-learn.

A blog post about this code is available here, check it out!

Requirements

  • python -- developed with 2.7.6
  • scikit-learn -- using version 0.16.1
  • pandas -- using version 0.16.1
  • numpy -- using version 1.9.2

To create the graphic of the tree, you must also have graphviz (the `dot` command) installed.

Usage

1. Run script from command line

This provides an example of using the available functions -- see the `if __name__ == "__main__":` block at the end of the script for how the data is obtained and how the functions are used.

$ python dt_cross_validation.py

The resulting output is:

-- get data:
-- iris.csv found locally

-- 10-fold cross-validation [using setup from previous post]
mean: 0.960 (std: 0.033)

-- Grid Parameter Search via 10-fold CV

GridSearchCV took 5.10 seconds for 288 candidate parameter settings.
Model with rank: 1
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 10, 'max_leaf_nodes': 5, 'criterion': 'gini', 'max_depth': None, 'min_samples_leaf': 1}

Model with rank: 2
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 20, 'max_leaf_nodes': 5, 'criterion': 'gini', 'max_depth': None, 'min_samples_leaf': 1}

Model with rank: 3
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 10, 'max_leaf_nodes': 5, 'criterion': 'gini', 'max_depth': 5, 'min_samples_leaf': 1}


-- Best Parameters:
parameter: min_samples_split    setting: 10
parameter: max_leaf_nodes       setting: 5
parameter: criterion            setting: gini
parameter: max_depth            setting: None
parameter: min_samples_leaf     setting: 1


-- Testing best parameters [Grid]...
mean: 0.967 (std: 0.033)


-- get_code for best parameters [Grid]:

if ( PetalLength <= 2.45000004768 ) {
    return Iris-setosa ( 50 examples )
}
else {
    if ( PetalWidth <= 1.75 ) {
        if ( PetalLength <= 4.94999980927 ) {
            if ( PetalWidth <= 1.65000009537 ) {
                return Iris-versicolor ( 47 examples )
            }
            else {
                return Iris-virginica ( 1 examples )
            }
        }
        else {
            return Iris-versicolor ( 2 examples )
            return Iris-virginica ( 4 examples )
        }
    }
    else {
        return Iris-versicolor ( 1 examples )
        return Iris-virginica ( 45 examples )
    }
}
-- Random Parameter Search via 10-fold CV

RandomizedSearchCV took 1.55 seconds for 288 candidates parameter settings.
Model with rank: 1
Mean validation score: 0.967 (std: 0.033)
Parameters: {'min_samples_split': 14, 'max_leaf_nodes': 5, 'criterion': 'gini', 'max_depth': 9, 'min_samples_leaf': 1}

Model with rank: 2
Mean validation score: 0.960 (std: 0.042)
Parameters: {'min_samples_split': 1, 'max_leaf_nodes': 11, 'criterion': 'gini', 'max_depth': 11, 'min_samples_leaf': 4}

Model with rank: 3
Mean validation score: 0.960 (std: 0.042)
Parameters: {'min_samples_split': 11, 'max_leaf_nodes': 4, 'criterion': 'gini', 'max_depth': 16, 'min_samples_leaf': 5}


-- Best Parameters:
parameters: min_samples_split    setting: 14
parameters: max_leaf_nodes       setting: 5
parameters: criterion            setting: gini
parameters: max_depth            setting: 9
parameters: min_samples_leaf     setting: 1


-- Testing best parameters [Random]...
mean: 0.967 (std: 0.033)


-- get_code for best parameters [Random]:
if ( PetalLength <= 2.45000004768 ) {
    return Iris-setosa ( 50 examples )
}
else {
    if ( PetalWidth <= 1.75 ) {
        if ( PetalLength <= 4.94999980927 ) {
            if ( PetalWidth <= 1.65000009537 ) {
                return Iris-versicolor ( 47 examples )
            }
            else {
                return Iris-virginica ( 1 examples )
            }
        }
        else {
            return Iris-versicolor ( 2 examples )
            return Iris-virginica ( 4 examples )
        }
    }
    else {
        return Iris-versicolor ( 1 examples )
        return Iris-virginica ( 45 examples )
    }
}
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2015 Christopher C. Strelioff <chris.strelioff@gmail.com>
#
# Distributed under terms of the MIT license.
"""dt_cross_validation.py -- use cross-validation to choose best decision
tree parameters.
"""
from __future__ import print_function

import os
import subprocess
import sys
from operator import itemgetter
from time import time

import numpy as np
import pandas as pd
from scipy.stats import randint
from sklearn.cross_validation import cross_val_score
from sklearn.grid_search import GridSearchCV
from sklearn.grid_search import RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
def get_code(tree, feature_names, target_names,
             spacer_base=" "):
    """Produce pseudo-code for decision tree.

    Prints nested ``if ( feature <= threshold ) { ... } else { ... }``
    pseudo-code to stdout. Each internal node becomes one split. Each leaf
    becomes one ``return <class> ( <count> examples )`` line per non-zero
    entry in that leaf's value array.

    Args
    ----
    tree -- fitted scikit-learn Decision Tree (its ``tree_`` attribute
            is read).
    feature_names -- list of feature names, indexed by the tree's
                     feature indices.
    target_names -- list of target (class) names, indexed by class index.
    spacer_base -- string repeated once per depth level to indent the
                   code (default: " ").

    Notes
    -----
    based on http://stackoverflow.com/a/30104792.
    """
    tree_ = tree.tree_
    left = tree_.children_left
    right = tree_.children_right
    threshold = tree_.threshold
    value = tree_.value
    # Leaf nodes carry a negative sentinel feature index (-2). Looking that
    # up in feature_names would silently wrap around to a wrong name, or
    # raise IndexError when there is only one feature, so leaves map to None.
    features = [feature_names[i] if i >= 0 else None
                for i in tree_.feature]

    def recurse(node, depth):
        spacer = spacer_base * depth
        # Leaves are recognised by the sentinel threshold -2.
        if threshold[node] != -2:
            print(spacer + "if ( " + str(features[node]) + " <= " +
                  str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left[node], depth + 1)
            print(spacer + "}\n" + spacer + "else {")
            if right[node] != -1:
                recurse(right[node], depth + 1)
            print(spacer + "}")
        else:
            target = value[node]
            # Emit one line per non-zero (output, class) entry; the column
            # index is the class index into target_names.
            rows, cols = np.nonzero(target)
            for i, v in zip(cols, target[rows, cols]):
                target_name = target_names[i]
                target_count = int(v)
                print(spacer + "return " + str(target_name) +
                      " ( " + str(target_count) + " examples )")

    recurse(0, 0)
def visualize_tree(tree, feature_names, fn="dt"):
    """Create tree png using graphviz.

    Writes ``<fn>.dot`` with scikit-learn's ``export_graphviz``, then runs
    the external ``dot`` command to render ``<fn>.png``.

    Args
    ----
    tree -- scikit-learn Decision Tree.
    feature_names -- list of feature names.
    fn -- [string], root of filename, default `dt`.

    Raises
    ------
    SystemExit -- if ``dot`` cannot be started or exits with an error.
    """
    dotfile = fn + ".dot"
    pngfile = fn + ".png"
    with open(dotfile, 'w') as f:
        export_graphviz(tree, out_file=f,
                        feature_names=feature_names)
    command = ["dot", "-Tpng", dotfile, "-o", pngfile]
    try:
        subprocess.check_call(command)
    except (OSError, subprocess.CalledProcessError):
        # OSError: ``dot`` is missing or not executable.
        # CalledProcessError: ``dot`` ran but returned a non-zero status.
        # Anything else (e.g. KeyboardInterrupt) propagates normally.
        sys.exit("Could not run dot, ie graphviz, "
                 "to produce visualization")
def encode_target(df, target_column):
    """Add column to df with integers for the target.

    Args
    ----
    df -- pandas Data Frame (left unmodified; a copy is returned).
    target_column -- column to map to int, producing new
                     Target column.

    Returns
    -------
    df -- copy of the Data Frame with an added integer "Target" column.
    targets -- array of target names in order of first appearance;
               targets[k] is the name encoded as k.
    """
    df_mod = df.copy()
    targets = df_mod[target_column].unique()
    map_to_int = {name: n for n, name in enumerate(targets)}
    # Series.map with a dict produces the integer codes directly.
    # Series.replace(dict) depends on object->int downcasting, which newer
    # pandas deprecates; without it the Target column would stay object dtype.
    df_mod["Target"] = df_mod[target_column].map(map_to_int)
    return (df_mod, targets)
def get_iris_data(url=("https://raw.githubusercontent.com/pandas-dev/"
                       "pandas/main/pandas/tests/io/data/csv/iris.csv")):
    """Get the iris data, from local csv or pandas repo.

    If ``iris.csv`` exists in the working directory it is read (its first
    column is the index). Otherwise the data is downloaded from ``url`` and
    cached as ``iris.csv`` for later runs.

    Args
    ----
    url -- where to download iris.csv from when no local copy exists.
           Defaults to the file's current location in the pandas repo
           (the previous ``pandas/tests/data/iris.csv`` path has moved).

    Returns
    -------
    df -- pandas Data Frame with the iris data.

    Raises
    ------
    SystemExit -- if the download or parse fails.
    """
    if os.path.exists("iris.csv"):
        print("-- iris.csv found locally")
        return pd.read_csv("iris.csv", index_col=0)

    print("-- trying to download from github")
    try:
        df = pd.read_csv(url)
    except (OSError, ValueError):
        # Network failures (URLError/HTTPError) are OSError subclasses;
        # pandas parse errors (e.g. an HTML error page) are ValueErrors.
        sys.exit("-- Unable to download iris.csv")
    print("-- writing to local iris.csv file")
    # The index is written deliberately: the cached branch reads it back
    # with index_col=0.
    df.to_csv("iris.csv")
    return df
def report(grid_scores, n_top=3):
    """Print the n_top best parameter settings and return the winner.

    Entries are ranked by their second field (the mean validation score)
    in descending order; ties keep their original relative order.

    Args
    ----
    grid_scores -- ``grid_scores_`` from a grid or random search; each
                   entry exposes ``parameters``, ``mean_validation_score``
                   and ``cv_validation_scores``.
    n_top -- how many of the top models to print (default 3).

    Returns
    -------
    top_params -- [dict] parameter setting of the highest-ranked entry.
    """
    ranked = sorted(grid_scores, key=itemgetter(1), reverse=True)
    leaders = ranked[:n_top]
    for rank, entry in enumerate(leaders, start=1):
        print("Model with rank: {0}".format(rank))
        spread = np.std(entry.cv_validation_scores)
        line = "Mean validation score: {0:.3f} (std: {1:.3f})"
        print(line.format(entry.mean_validation_score, spread))
        print("Parameters: {0}".format(entry.parameters))
        print("")
    best = leaders[0]
    return best.parameters
def run_gridsearch(X, y, clf, param_grid, cv=5):
    """Exhaustively search ``param_grid`` for the best Decision Tree settings.

    Fits a GridSearchCV over every combination in ``param_grid``, prints
    the elapsed time and number of candidates, then prints the top three
    models via report().

    Args
    ----
    X -- features
    y -- targets (classes)
    clf -- scikit-learn Decision Tree (the estimator to tune)
    param_grid -- [dict] parameter settings to test
    cv -- fold of cross-validation, default 5

    Returns
    -------
    top_params -- [dict] best parameter setting, from report()
    """
    searcher = GridSearchCV(clf, param_grid=param_grid, cv=cv)
    t0 = time()
    searcher.fit(X, y)
    elapsed = time() - t0
    n_candidates = len(searcher.grid_scores_)
    message = ("\nGridSearchCV took {:.2f} "
               "seconds for {:d} candidate "
               "parameter settings.")
    print(message.format(elapsed, n_candidates))
    return report(searcher.grid_scores_, 3)
def run_randomsearch(X, y, clf, para_dist, cv=5,
                     n_iter_search=20):
    """Run a random search for best Decision Tree parameters.

    Args
    ----
    X -- features
    y -- targets (classes)
    clf -- scikit-learn Decision Tree
    para_dist -- [dict] list, distributions of parameters
                 to sample
    cv -- fold of cross-validation, default 5
    n_iter_search -- number of random parameter sets to try,
                     default 20.

    Returns
    -------
    top_params -- [dict] from report()
    """
    # Use the function's own argument. The previous body read the
    # module-global ``param_dist`` (only defined when run as a script) and
    # never forwarded ``cv``, so the search silently used the library's
    # default fold count.
    random_search = RandomizedSearchCV(clf,
                                       param_distributions=para_dist,
                                       n_iter=n_iter_search,
                                       cv=cv)
    start = time()
    random_search.fit(X, y)
    print(("\nRandomizedSearchCV took {:.2f} seconds "
           "for {:d} candidates parameter "
           "settings.").format((time() - start),
                               n_iter_search))
    top_params = report(random_search.grid_scores_, 3)
    return top_params
if __name__ == "__main__":
    # Example driver: load iris, then compare a baseline tree with trees
    # tuned by grid search and by random search (10-fold CV throughout).
    print("\n-- get data:")
    df = get_iris_data()
    print("")

    features = ["SepalLength", "SepalWidth",
                "PetalLength", "PetalWidth"]
    df, targets = encode_target(df, "Name")
    y = df["Target"]
    X = df[features]

    print("-- 10-fold cross-validation "
          "[using setup from previous post]")
    dt_old = DecisionTreeClassifier(min_samples_split=20,
                                    random_state=99)
    # cross_val_score fits fresh clones on each fold, so no separate
    # dt_old.fit() is needed.
    scores = cross_val_score(dt_old, X, y, cv=10)
    print("mean: {:.3f} (std: {:.3f})".format(scores.mean(),
                                              scores.std()),
          end="\n\n")

    print("-- Grid Parameter Search via 10-fold CV")
    # set of parameters to test
    param_grid = {"criterion": ["gini", "entropy"],
                  "min_samples_split": [2, 10, 20],
                  "max_depth": [None, 2, 5, 10],
                  "min_samples_leaf": [1, 5, 10],
                  "max_leaf_nodes": [None, 5, 10, 20],
                  }
    dt = DecisionTreeClassifier()
    ts_gs = run_gridsearch(X, y, dt, param_grid, cv=10)
    print("\n-- Best Parameters:")
    for k, v in ts_gs.items():
        print("parameter: {:<20s} setting: {}".format(k, v))

    # test the returned best parameters
    print("\n\n-- Testing best parameters [Grid]...")
    dt_ts_gs = DecisionTreeClassifier(**ts_gs)
    scores = cross_val_score(dt_ts_gs, X, y, cv=10)
    print("mean: {:.3f} (std: {:.3f})".format(scores.mean(),
                                              scores.std()),
          end="\n\n")
    print("\n-- get_code for best parameters [Grid]:", end="\n\n")
    dt_ts_gs.fit(X, y)
    get_code(dt_ts_gs, features, targets)
    visualize_tree(dt_ts_gs, features, fn="grid_best")

    print("-- Random Parameter Search via 10-fold CV")
    # dict of parameter list/distributions to sample.
    # randint's upper bound is exclusive. min_samples_split starts at 2
    # because scikit-learn >= 0.18 rejects 1 (a split needs 2 samples anyway).
    param_dist = {"criterion": ["gini", "entropy"],
                  "min_samples_split": randint(2, 20),
                  "max_depth": randint(1, 20),
                  "min_samples_leaf": randint(1, 20),
                  "max_leaf_nodes": randint(2, 20)}
    dt = DecisionTreeClassifier()
    ts_rs = run_randomsearch(X, y, dt, param_dist, cv=10,
                             n_iter_search=288)
    print("\n-- Best Parameters:")
    for k, v in ts_rs.items():
        print("parameters: {:<20s} setting: {}".format(k, v))

    # test the returned best parameters
    print("\n\n-- Testing best parameters [Random]...")
    dt_ts_rs = DecisionTreeClassifier(**ts_rs)
    scores = cross_val_score(dt_ts_rs, X, y, cv=10)
    print("mean: {:.3f} (std: {:.3f})".format(scores.mean(),
                                              scores.std()),
          end="\n\n")
    print("\n-- get_code for best parameters [Random]:")
    dt_ts_rs.fit(X, y)
    get_code(dt_ts_rs, features, targets)
    visualize_tree(dt_ts_rs, features, fn="rand_best")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment