
@tomasgreif
Created July 19, 2013 12:36
Parse rpart model.
#' Create SQL statement from rpart rules
#'
#' Each leaf of the rpart tree becomes one WHEN branch of an SQL CASE expression that returns the leaf's node number.
#'
#' @param df data frame used to fit the rpart model
#' @param model rpart model
#' @return character string containing the SQL CASE expression
#' @export
#' @examples
#' library(rpart)
#' parse_tree(df=kyphosis,model=rpart(data=kyphosis,formula=Kyphosis~.))
#' parse_tree(df=mtcars,model=rpart(data=mtcars,formula=am~.))
#' parse_tree(df=iris,model=rpart(data=iris,formula=Species~.))
#' x <- german_data
#' x$gbbin <- NULL
#' model <- rpart(data=x,formula=gb~.)
#' parse_tree(x,model)
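parse_tree <- function(df = NULL, model = NULL) {
  # path.rpart() prints every path it returns; capture.output() keeps the console quiet
  log <- capture.output({
    rpart.rules <- path.rpart(model, rownames(model$frame)[model$frame$var == "<leaf>"])
  })
  # Split operators; "<=" and ">=" come first so they win over "<", ">" and "="
  args <- c("<=", ">=", "<", ">", "=")
  rules_out <- "case "
  i <- 1
  for (rule in rpart.rules) {
    rule_out <- character(0)
    for (component in rule) {
      # Flag each operator that splits the condition into more than one piece
      sep <- lapply(args, function(x) length(unlist(strsplit(component,x)))) > 1
      # The first flagged operator separates the column name from the value
      elements <- unlist(strsplit(component, (args[sep])[1]))
      # Every path starts with the pseudo-condition "root", which is skipped
      if (!(elements[1] == "root")) {
        if (is.numeric(df[,elements[[1]]])) {
          # Numeric split: keep the comparison, e.g. Age >= 9.5
          rule_out <- c(rule_out, paste(elements[1], (args[sep])[1], elements[2]))
        } else {
          # Factor split: the comma-separated levels become an IN list
          rule_out <- c(rule_out, paste0(elements[1], " in (",
                                         paste0("'", unlist(strsplit(elements[2], ",")), "'", collapse = ","),
                                         ")"))
        }
      }
    }
    # One WHEN branch per leaf, returning the leaf's node number
    rules_out <- c(rules_out, paste0("when ", paste(rule_out, collapse = " AND "),
                                     " then 'node_", names(rpart.rules)[i], "'"))
    if (i == length(rpart.rules)) rules_out <- c(rules_out, " end ")
    i <- i + 1
  }
  sql_out <- paste(rules_out, collapse = " ")
  sql_out
}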
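To see how the operator detection in `parse_tree()` behaves on single conditions, the `sep` / `args[sep][1]` step can be pulled out into a standalone helper. This is a minimal sketch: `ops` and `detect_op()` are hypothetical names introduced only for illustration, and the sample conditions assume `path.rpart()` labels of the form `Age>=9.5` or `Start< 8.5`.

```r
# Operators in the same order parse_tree() tries them; two-character ones first
ops <- c("<=", ">=", "<", ">", "=")

# Hypothetical helper mirroring the sep / args[sep][1] logic in parse_tree():
# return the first operator whose split yields more than one piece (NA if none)
detect_op <- function(component) {
  hits <- vapply(ops, function(op) length(strsplit(component, op)[[1]]) > 1, logical(1))
  ops[hits][1]
}

detect_op("Age>=9.5")        # ">=" (">" and "=" also split, but ">=" is listed first)
detect_op("Start< 8.5")      # "<"
detect_op("Species=setosa")  # "="
```

Ordering the operators this way is what keeps `Age>=9.5` from being split on `>`, which would leave `=9.5` as the value.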
@kamashay

kamashay commented Feb 7, 2018

Hi, thanks so much for the great function, it helped me a lot!
Could there be a bug when a rule in the tree indicates a missing value,
for example 'featureX=', which really means that 'featureX' is NA?
In that case the line
' sep <- lapply(args, function(x) length(unlist(strsplit(component,x)))) > 1'
would return all FALSE, because strsplit() does not split on an '=' that terminates the string, so the following split leaves the full string, '=' included.
As a result, the later 'df[,elements[[1]]]' would crash the code, as it would try to match a column name that does not exist ('featureX=').
Thanks so much,
kamashay
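One way to handle the trailing-operator case is to locate the operator with `regexpr()` instead of splitting on it, so an empty right-hand side survives as an empty string. This sketch assumes, as the comment above does, that an empty value after `=` means a missing value, and renders it as `is null`. `split_condition()` and `condition_to_sql()` are hypothetical helpers, not part of the gist, and would still need to be wired into `parse_tree()`.

```r
# Hypothetical helper: split one path.rpart() condition into column, operator, value.
# regexpr() finds the operator directly, so a condition ending in "=" still parses.
split_condition <- function(component) {
  m <- regexpr("<=|>=|<|>|=", component)
  if (m == -1) return(NULL)  # no operator, e.g. the "root" pseudo-condition
  start <- as.integer(m)
  len <- attr(m, "match.length")
  list(
    column = substr(component, 1, start - 1),
    op     = substr(component, start, start + len - 1),
    value  = trimws(substr(component, start + len, nchar(component)))
  )
}

# Hypothetical SQL rendering. Assumption: an empty value means a missing value.
condition_to_sql <- function(parts, is_num) {
  if (parts$value == "") return(paste(parts$column, "is null"))
  if (is_num) return(paste(parts$column, parts$op, parts$value))
  levels <- strsplit(parts$value, ",", fixed = TRUE)[[1]]
  paste0(parts$column, " in (", paste0("'", levels, "'", collapse = ","), ")")
}

condition_to_sql(split_condition("featureX="), is_num = FALSE)  # "featureX is null"
condition_to_sql(split_condition("Start< 8.5"), is_num = TRUE)  # "Start < 8.5"
```

Inside `parse_tree()`, the pair could stand in for the `sep` / `elements` lines, with `is.numeric(df[,parts$column])` supplying `is_num`.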

@dennisliub

Very good work, thank you!
