Skip to content

Instantly share code, notes, and snippets.

@jonrobinson2
Forked from shanebutler/sql.export.randomForest.R
Last active August 29, 2015 14:07
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jonrobinson2/68a0060fe9133b884bbd to your computer and use it in GitHub Desktop.
Save jonrobinson2/68a0060fe9133b884bbd to your computer and use it in GitHub Desktop.
# sql.export.rf(): save a randomForest model as SQL
# v0.03
# Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
#
# sql.export.rf is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# sql.export.rf is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>.
#
#
## NOTE:
# This code generates SQL scoring code from your randomForest model.
# Currently the generated code is not optimal because it makes as many
# passes over the input data as there are trees (i.e., a model with 500
# trees produces 500 INSERT ... SELECT statements).
#
## USAGE:
# sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id")
#
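## ARGUMENTS:
# model:       a fitted randomForest object
# file:        file name or connection that receives the SQL script
# input.table: name of the table to score (default "source_table")
# id:          name of the integer key column in input.table (default "id")
# variant:     optional; set variant="teradata" to emit Teradata VOLATILE tables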
#
# sql.export.rf(): write a SQL script that scores a randomForest model.
#
# The script creates rf_predictions and a scratch table tmp_rf. It emits one
# INSERT ... SELECT per tree into tmp_rf, then aggregates into rf_predictions
# and finally drops tmp_rf:
#   - classification: majority vote; ties are broken deterministically by
#     taking the smallest tied class label (MIN).
#   - other model types: the average of the per-tree predictions.
#
# Arguments:
#   model        fitted object inheriting from "randomForest".
#   file         file name or connection handed to sink(); receives the SQL.
#   input.table  name of the table to score.
#   id           name of the key column of input.table (declared INT NOT NULL
#                in the generated tables).
#   variant      "teradata" emits Teradata VOLATILE tables; any other value
#                emits generic CREATE TABLE / DROP TABLE IF EXISTS statements.
#
# Predictor names are mapped to SQL column names by replacing "." with "_".
# Returns NULL invisibly; called for the side effect of writing `file`.
sql.export.rf <- function(model, file, input.table = "source_table",
                          id = "id",
                          variant = "generic") {
  if (!inherits(model, "randomForest")) {
    stop("Expected a randomForest object")
  }
  if (!requireNamespace("randomForest", quietly = TRUE)) {
    stop("Package 'randomForest' is required to read the model's trees")
  }

  teradata <- identical(variant, "teradata")
  # With getTree(labelVar = TRUE), classification leaves are character class
  # labels; every other model type has numeric leaves.
  pred.type <- if (identical(model$type, "classification")) "VARCHAR" else "FLOAT"
  xlevels <- model$forest$xlevels
  ncat <- model$forest$ncat

  # SQL column name for a predictor ("." is not valid in unquoted identifiers).
  sql.col <- function(var) gsub(".", "_", var, fixed = TRUE)

  # Single-quoted SQL string literal(s); embedded single quotes are doubled.
  sql.str <- function(x) {
    paste0("'", gsub("'", "''", as.character(x), fixed = TRUE), "'")
  }

  # Body of an IN (...) list; "NULL" for an empty set keeps the SQL valid.
  sql.in.list <- function(values) {
    if (length(values) == 0) "NULL" else paste(sql.str(values), collapse = ", ")
  }

  # Write the nested CASE expression for the subtree rooted at row `node` of
  # `tree` (a data frame returned by getTree(..., labelVar = TRUE)).
  emit.node <- function(tree, node, depth = 0) {
    indent <- strrep("\t", depth)

    if (tree[["status"]][node] == -1) {
      # Terminal node. as.character() keeps 15 significant digits (cat alone
      # would print only 7).
      pred <- tree[["prediction"]][node]
      leaf <- if (is.numeric(tree[["prediction"]])) as.character(pred) else sql.str(pred)
      cat(leaf, " ", sep = "")
      return(invisible(NULL))
    }

    left <- tree[["left daughter"]][node]
    right <- tree[["right daughter"]][node]
    split.var <- as.character(tree[["split var"]][node])
    split.point <- tree[["split point"]][node]
    col <- sql.col(split.var)
    var.levels <- unlist(xlevels[split.var], use.names = FALSE)
    n.cat <- ncat[split.var]

    if (!is.na(n.cat) && n.cat > 1) {
      # Unordered factor: bit i (least significant = 0) of the split point
      # flags level i + 1 as going to the left daughter.
      goes.left <- floor(split.point / 2^(seq_along(var.levels) - 1)) %% 2 == 1
    } else if (is.character(var.levels)) {
      # Ordered factor: randomForest splits on the integer level codes, so
      # levels whose code is <= split point go left.
      goes.left <- seq_along(var.levels) <= split.point
    } else {
      # Numeric predictor.
      cat("\n", indent, "CASE WHEN ", col, " IS NULL THEN NULL",
          "\n", indent, "WHEN ", col, " <= ", as.character(split.point), " THEN ",
          sep = "")
      emit.node(tree, left, depth + 1)
      cat("\n", indent, "ELSE ", sep = "")
      emit.node(tree, right, depth + 1)
      cat("END ")
      return(invisible(NULL))
    }

    # Categorical split; values outside both level sets (including NULL) map
    # to NULL.
    cat("\n", indent, "CASE WHEN ", col, " IN (",
        sql.in.list(var.levels[goes.left]), ") THEN ", sep = "")
    emit.node(tree, left, depth + 1)
    cat("\n", indent, "WHEN ", col, " IN (",
        sql.in.list(var.levels[!goes.left]), ") THEN ", sep = "")
    emit.node(tree, right, depth + 1)
    cat("\n", indent, "ELSE NULL END ", sep = "")
    invisible(NULL)
  }

  sink(file, type = "output")
  # Always release the diversion, even if generation fails part-way.
  on.exit(sink(), add = TRUE)

  if (teradata) {
    cat("CREATE VOLATILE TABLE rf_predictions (\n",
        "\t", id, " INT NOT NULL,\n",
        "\tpred ", pred.type, "\n",
        ") ON COMMIT PRESERVE ROWS;\n\n",
        "CREATE MULTISET VOLATILE TABLE tmp_rf (\n",
        "\t", id, " INT NOT NULL,\n",
        "\tpred ", pred.type, "\n",
        ") ON COMMIT PRESERVE ROWS;\n\n", sep = "")
  } else {
    cat("CREATE TABLE rf_predictions (\n",
        "\t", id, " INT NOT NULL,\n",
        "\tpred ", pred.type, "\n",
        ");\n\n",
        "DROP TABLE IF EXISTS tmp_rf;\n\n",
        "CREATE TABLE tmp_rf (\n",
        "\t", id, " INT NOT NULL,\n",
        "\tpred ", pred.type, "\n",
        ");\n\n", sep = "")
  }

  # One INSERT ... SELECT (one pass over input.table) per tree.
  for (tree.num in seq_len(model$ntree)) {
    tree <- randomForest::getTree(model, k = tree.num, labelVar = TRUE)
    cat("INSERT INTO tmp_rf\nSELECT ", id, ",", sep = "")
    emit.node(tree, 1)
    cat("as tree", tree.num, "\nFROM ", input.table, ";\n\n", sep = "")
  }

  if (identical(model$type, "classification")) {
    # Majority vote without window functions or WITH (unsupported by some SQL
    # engines, e.g. older SQLite). The outer GROUP BY / MIN keeps exactly one
    # row per id when several classes tie for the most votes.
    cat("INSERT INTO rf_predictions\n",
        "SELECT a.id, MIN(a.pred)\n",
        "FROM (SELECT ", id, " as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ", id, ", pred) a\n",
        "INNER JOIN (SELECT id, MAX(cnt) as cnt\n",
        "\t\t\tFROM (SELECT ", id, " as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ", id, ", pred) cc\n",
        "\t\t\tGROUP BY id) b\n",
        "ON a.id = b.id AND a.cnt = b.cnt\n",
        "GROUP BY a.id;\n\n", sep = "")
  } else {
    cat("INSERT INTO rf_predictions\n",
        "SELECT ", id, ", AVG(pred)\n",
        "FROM tmp_rf\n",
        "GROUP BY ", id, ";\n\n", sep = "")
  }

  if (teradata) {
    cat("DROP TABLE tmp_rf;\n\n")
  } else {
    cat("DROP TABLE IF EXISTS tmp_rf;\n\n")
  }

  invisible(NULL)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment