Skip to content

Instantly share code, notes, and snippets.

@grigory93
Last active October 7, 2021 05:43
Show Gist options
  • Save grigory93/149fb361fdfe933cd6317601e8a7107c to your computer and use it in GitHub Desktop.
Save grigory93/149fb361fdfe933cd6317601e8a7107c to your computer and use it in GitHub Desktop.
Plotting decision trees with H2O-3
# Fit a GBM made of exactly one tree (max depth 5) on the full Titanic frame,
# with row and column sampling rates both set to 1.
titanic_1tree <- h2o.gbm(
  x = predictors,
  y = response,
  training_frame = titanicHex,
  ntrees = 1,
  min_rows = 1,
  sample_rate = 1,
  col_sample_rate = 1,
  max_depth = 5,
  # Early-stopping settings: 3 stopping rounds, tolerance 0.01, metric AUC.
  stopping_rounds = 3,
  stopping_tolerance = 0.01,
  stopping_metric = "AUC",
  seed = 1
)

# Pull tree number 1 out of the fitted model.
titanicH2oTree <- h2o.getModelTree(model = titanic_1tree, tree_number = 1)
# Lookup table from the raw title found in a passenger name (`from`) to the
# coarser title group used as a feature (`to`).
TITLES <- local({
  title_groups <- c(
    "Capt"         = "Officer",
    "Col"          = "Officer",
    "Major"        = "Officer",
    "Jonkheer"     = "Royalty",
    "Don"          = "Royalty",
    "Sir"          = "Royalty",
    "Dr"           = "Officer",
    "Rev"          = "Officer",
    "the Countess" = "Royalty",
    "Mme"          = "Mrs",
    "Mlle"         = "Miss",
    "Ms"           = "Mrs",
    "Mr"           = "Mr",
    "Mrs"          = "Mrs",
    "Miss"         = "Miss",
    "Master"       = "Master",
    "Lady"         = "Royalty"
  )
  data.frame(
    from = names(title_groups),
    to = unname(title_groups),
    stringsAsFactors = FALSE
  )
})
# Create features
# A single data.table `:=` call that overwrites sex, embarked, survived and
# pclass with labelled factors and adds cabin_type, family_size, family_type
# and title. The column names in c(...) and the expressions in list(...) are
# matched by position, so their order must stay in sync.
titanicDT[,
c("sex",
"embarked",
"survived",
"pclass",
"cabin_type",
"family_size",
"family_type",
"title") := list(
# factor() applies labels to the sorted unique values; this presumably relies
# on the raw sex values sorting as female, male -- TODO confirm against the data.
factor(sex, labels = c("Female","Male")),
# Four labels, the first being the empty string; assumes embarked has exactly
# four distinct values with the blank one sorting first -- TODO confirm.
factor(embarked, labels = c("", "Cherbourg","Queenstown","Southampton")),
# Negating reverses the sort order, so the larger survived value receives the
# first label 'Yes' (presumably survived is coded 0/1).
factor(-survived, labels = c('Yes','No')),
factor(pclass, labels = c("Class 1","Class 2","Class 3")),
# cabin_type: first character of the cabin string
as.factor(substring(cabin, 1, 1)),
# family_size: siblings/spouses plus parents/children
sibsp + parch,
# family_type: "SINGLE" when the sum is <= 1, "SMALL" when <= 3, otherwise "LARGE"
as.factor(ifelse(sibsp + parch <= 1, "SINGLE", ifelse(sibsp + parch <= 3, "SMALL", "LARGE"))),
# title: split each name on runs of '.', ',' or space, keep the tokens listed
# in TITLES$from, and map the first such token to its TITLES$to group; NA when
# no token matches.
# NOTE(review): because the split pattern also breaks on spaces, the two-word
# entry "the Countess" in TITLES$from can never equal a single token, so such
# names fall through to NA -- confirm whether that is intended.
as.factor(sapply(strsplit(name, "[\\., ]+"), function(x) {
words = trimws(x)
words = words[!words=="" ]
words = words[words %in% TITLES$from]
if (length(words) > 0)
title_word = words[[1]]
else
return(NA)
return(TITLES[title_word == TITLES$from, 'to'])
}))
)]
# Handle missing values: replace each missing age / fare with the mean of the
# non-missing values within the same (survived, sex, embarked) group.
titanicDT[, c("age", "fare") :=
            list(ifelse(is.na(age), mean(age, na.rm = TRUE), age),
                 ifelse(is.na(fare), mean(fare, na.rm = TRUE), fare)),
          by = c("survived", "sex", "embarked")]
# Assemble the dataset for the Titanic survival model.
response <- "survived"

# Predictors are every column except the response and the listed columns.
predictors <- setdiff(
  colnames(titanicDT),
  c(response, "name", "ticket", "cabin", "boat", "body", "home.dest")
)

# Keep only the response followed by the predictors.
titanicDT <- titanicDT[, c(response, predictors), with = FALSE]
# Split the H2O frame into training and validation parts (ratio 0.8).
splits <- h2o.splitFrame(data = titanicHex, ratios = 0.8, seed = 1234)
trainHex <- splits[[1]]
validHex <- splits[[2]]

# GBM hyperparameters: every max_depth from 2 to 10.
gbm_params <- list(max_depth = seq(2, 10))

# Train and validate a cartesian grid of single-tree GBMs.
gbm_grid <- h2o.grid(
  "gbm",
  x = predictors,
  y = response,
  grid_id = "gbm_grid_1tree",
  training_frame = trainHex,
  validation_frame = validHex,
  ntrees = 1,
  min_rows = 1,
  sample_rate = 1,
  col_sample_rate = 1,
  learn_rate = 0.01,
  seed = 111,
  hyper_params = gbm_params
)

# Fetch the grid again with its models sorted by decreasing AUC.
gbm_gridperf <- h2o.getGrid(
  grid_id = "gbm_grid_1tree",
  sort_by = "auc",
  decreasing = TRUE
)
# Plot each grid model's AUC against its max_depth, as points joined by a line.
library(ggplot2)
library(ggthemes)
ggplot(
  as.data.frame(sapply(gbm_gridperf@summary_table, as.numeric)),
  aes(max_depth, auc)
) +
  geom_point() +
  # group = 1 joins all points into a single line
  geom_line(aes(group = 1)) +
  labs(
    x = "max depth",
    y = "AUC",
    title = "Grid Search for Single Tree Models"
  ) +
  theme_pander(base_family = 'Palatino', base_size = 12)
# Remove any previously installed H2O package for R: detach it if it is
# attached, then uninstall it if it is installed.
if ("package:h2o" %in% search()) {
  detach("package:h2o", unload = TRUE)
}
if ("h2o" %in% rownames(installed.packages())) {
  remove.packages("h2o")
}

# Install the packages H2O depends on, skipping any that are already present.
pkgs <- c("RCurl", "jsonlite")
for (pkg in pkgs) {
  if (!(pkg %in% rownames(installed.packages()))) {
    install.packages(pkg)
  }
}

# Download and install the H2O package for R from source.
install.packages(
  "h2o",
  type = "source",
  repos = "http://h2o-release.s3.amazonaws.com/h2o/rel-xia/2/R"
)

# Load H2O and start up an H2O cluster.
library(h2o)
h2o.init()
# Convert the Titanic data.table into an H2O frame.
titanicHex <- as.h2o(titanicDT)
library(data.table)
# Read the Titanic CSV from the H2O public test-data bucket.
titanicDT <- fread(
  "https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv"
)
library(data.tree)
# Build a data.tree from an H2O tree.
#
# h2oTree: object exposing a `root_node` slot whose `split_feature` names the
#          root of the new tree.
# Returns the root data.tree node (tagged type 'split'), populated by
# addChildren().
createDataTree <- function(h2oTree) {
  root <- h2oTree@root_node
  tree <- Node$new(root@split_feature)
  tree$type <- 'split'
  addChildren(tree, root)
  tree
}
# Recursively mirror the children of an H2O tree node onto a data.tree node.
#
# dtree: tree node representing `node`; must provide AddChild() and accept
#        `edgeLabel` / `type` fields.
# node:  H2O tree node. Only objects whose first class is 'H2OSplitNode' are
#        expanded.
# Returns TRUE when `node` is not a split node (nothing is added), and FALSE
# after adding both children and recursing into them.
addChildren <- function(dtree, node) {
  if (class(node)[1] != 'H2OSplitNode') return(TRUE)

  # isTRUE() keeps a missing na_direction from leaking a literal "NA" into
  # the edge labels.
  naGoesLeft <- isTRUE(node@na_direction == 'LEFT')
  naGoesRight <- isTRUE(node@na_direction == 'RIGHT')

  if (is.na(node@threshold)) {
    # No numeric threshold: label the edges with the level sets instead.
    leftEdgeLabel <- printValues(node@left_levels, naGoesLeft, 4)
    rightEdgeLabel <- printValues(node@right_levels, naGoesRight, 4)
  } else {
    # Same ", NA" suffix style as printValues(), with no stray spaces.
    leftEdgeLabel <- paste0("< ", node@threshold,
                            if (naGoesLeft) ", NA" else "")
    rightEdgeLabel <- paste0(">= ", node@threshold,
                             if (naGoesRight) ", NA" else "")
  }

  # Display name of a child: its prediction for a leaf, else its split feature.
  labelFor <- function(child) {
    if (class(child)[[1]] == 'H2OLeafNode') {
      paste("prediction:", child@prediction)
    } else {
      child@split_feature
    }
  }
  # Node kind consumed later by the plot styling helpers.
  typeFor <- function(child) {
    if (class(child)[1] == 'H2OSplitNode') 'split' else 'leaf'
  }

  leftNode <- node@left_child
  rightNode <- node@right_child
  leftLabel <- labelFor(leftNode)
  rightLabel <- labelFor(rightNode)

  # Two children with the same label get (L)/(R) suffixes to tell them apart.
  if (leftLabel == rightLabel) {
    leftLabel <- paste(leftLabel, "(L)")
    rightLabel <- paste(rightLabel, "(R)")
  }

  dtreeLeft <- dtree$AddChild(leftLabel)
  dtreeLeft$edgeLabel <- leftEdgeLabel
  dtreeLeft$type <- typeFor(leftNode)

  dtreeRight <- dtree$AddChild(rightLabel)
  dtreeRight$edgeLabel <- rightEdgeLabel
  dtreeRight$type <- typeFor(rightNode)

  addChildren(dtreeLeft, leftNode)
  addChildren(dtreeRight, rightNode)
  FALSE
}
# Format a set of levels as an edge label.
#
# values:          the levels to show.
# is_na_direction: when TRUE, "NA" is included in the label.
# n:               maximum number of levels shown; ",..." is appended when
#                  there are more.
# Returns a single string; with no levels it is "NA" or "".
printValues <- function(values, is_na_direction, n=4) {
  count <- length(values)
  if (count == 0) {
    return(ifelse(is_na_direction, "NA", ""))
  }
  shown <- paste0(values[1:min(n, count)], collapse = ', ')
  ellipsis <- ifelse(count > n, ",...", "")
  na_suffix <- ifelse(is_na_direction, ", NA", "")
  paste0(shown, ellipsis, na_suffix)
}
# Convert the extracted H2O tree into a data.tree structure.
titanicDataTree <- createDataTree(titanicH2oTree)
# Edge-label accessor used as a style callback: returns the node's edgeLabel.
GetEdgeLabel <- function(node) {
  node$edgeLabel
}
# Node-shape callback: "diamond" for type 'split', "oval" for type 'leaf'.
GetNodeShape <- function(node) {
  switch(node$type,
    split = "diamond",
    leaf = "oval"
  )
}
# Font callback: bold Palatino for type 'split', regular Palatino for 'leaf'.
GetFontName <- function(node) {
  switch(node$type,
    split = 'Palatino-bold',
    leaf = 'Palatino'
  )
}
# Edge styling: italic font, floating labels taken from each node's edgeLabel.
SetEdgeStyle(
  titanicDataTree,
  fontname = 'Palatino-italic',
  label = GetEdgeLabel,
  labelfloat = TRUE,
  fontsize = "26",
  fontcolor = 'royalblue4'
)

# Node styling: font and shape are chosen per node by the callbacks.
SetNodeStyle(
  titanicDataTree,
  fontname = GetFontName,
  shape = GetNodeShape,
  fontsize = "26",
  fontcolor = 'royalblue4',
  height = "0.75",
  width = "1"
)

# Graph-level settings: rank direction "LR", 70 dpi.
SetGraphStyle(titanicDataTree, rankdir = "LR", dpi = 70)

plot(titanicDataTree, output = "graph")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment