Skip to content

Instantly share code, notes, and snippets.

@shanebutler
Last active September 14, 2020 14:04
Show Gist options
  • Star 24 You must be signed in to star a gist
  • Fork 19 You must be signed in to fork a gist
  • Save shanebutler/96f0e78a02c84cdcf558 to your computer and use it in GitHub Desktop.
Save shanebutler/96f0e78a02c84cdcf558 to your computer and use it in GitHub Desktop.
Deploy your RandomForest models in SQL! This tool enables in-database scoring of Random Forest models built using R. To use it, you simply call the function with the Random Forest model, output filename, SQL input data table and the name of the unique key on that table. For example: sql.export.rf(rf.mdl, file="model_output.SQL", input.table="sour…
# sql.export.rf(): save a randomForest model as SQL
# v0.04
# Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
#
# sql.export.rf is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# sql.export.rf is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>.
#
#
## NOTE:
# This code generates SQL scoring code from your randomForest model.
# Currently the generated code is not optimal since it makes as many
# passes over the input data as there are trees (ie. if there are 500
# trees there will be 500 INSERT... SELECT statements)
#
## USAGE:
# sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id")
#
## ARGUMENTS:
# variant: Optional. Pass variant="teradata" to generate Teradata-specific
#          SQL (volatile tables); the default "generic" emits generic SQL.
#
sql.export.rf <- function (model, file, input.table="source_table",
                           id="id",
                           variant="generic") {
  # Write a SQL script that scores a randomForest model inside a database.
  #
  # Arguments:
  #   model:       a randomForest object that still carries its forest
  #   file:        file name (or connection) handed to sink() for the script
  #   input.table: name of the table holding the rows to score
  #   id:          name of the unique key column of input.table
  #   variant:     "teradata" emits volatile tables; any other value emits
  #                the generic dialect
  #
  # The script creates rf_predictions (id, pred) and a scratch table tmp_rf
  # that receives one INSERT ... SELECT per tree. Per-tree results are then
  # combined by majority vote (classification) or by AVG (regression).
  #
  # Returns NULL invisibly; the function is called for its side effect.

  # Validate before touching the output file so a bad call leaves no trace.
  if (!inherits(model, "randomForest")) {
    stop("Expected a randomForest object")
  }
  if (!requireNamespace("randomForest", quietly=TRUE)) {
    stop("The randomForest package is required to export a model")
  }
  if (is.null(model$forest)) {
    stop("The model has no forest component; refit it with keep.forest=TRUE")
  }

  # Divert printed output to the file, and guarantee the diversion is undone
  # even when an error interrupts the export half way through.
  sink(file, type="output")
  on.exit(sink(), add=TRUE)

  # Double embedded single quotes so a value is safe inside a SQL literal.
  sql.escape <- function(x) {
    gsub("'", "''", as.character(x), fixed=TRUE)
  }

  # Decode the integer code of an unordered-factor split into one flag per
  # level (first level = least significant bit); TRUE means "goes left".
  decode.left <- function(ncat, code) {
    if (code >= 2^ncat) {
      stop("Categorical split code ", code, " does not fit in ", ncat, " levels")
    }
    (code %/% 2^(seq_len(ncat) - 1)) %% 2 == 1
  }

  # Emit the nested CASE expression for the subtree rooted at tree.row.num.
  recurse.rf <- function(tree.data, tree.row.num, ind=0) {
    tree.row <- tree.data[tree.row.num,]

    if (tree.row[,"status"] == -1) { # terminal node
      if (is.numeric(tree.data$prediction)) {
        cat(paste(tree.row[,"prediction"], " ", sep=""))
      } else {
        cat(paste("'", sql.escape(tree.row[,"prediction"]), "' ", sep=""))
      }
      return(invisible(NULL))
    }

    indent.str <- paste(rep("\t", ind), collapse="")
    split.var <- as.character(tree.row[,"split var"])
    split.point <- tree.row[,"split point"]
    column <- gsub("[.]", "_", split.var)
    var.levels <- unlist(model$forest$xlevels[split.var], use.names=FALSE)

    if (!is.character(var.levels)) { # numeric predictor
      cat(paste("\n",indent.str,"CASE WHEN", column, "IS NULL THEN NULL",
                "\n",indent.str,"WHEN", column, "<=", split.point, "THEN "))
      recurse.rf(tree.data, tree.row[,"left daughter"], ind=(ind+1))
      cat("\n",indent.str,"ELSE ")
      recurse.rf(tree.data, tree.row[,"right daughter"], ind=(ind+1))
      cat("END ")
      return(invisible(NULL))
    }

    # Factor predictor. randomForest records ncat == 1 for ordered factors
    # and splits them on the level code like a number, so only unordered
    # factors (ncat > 1) carry a bit-coded split point.
    n.cat <- unname(model$forest$ncat[split.var])
    if (isTRUE(n.cat > 1)) {
      goes.left <- decode.left(n.cat, split.point)
    } else {
      goes.left <- seq_along(var.levels) <= split.point
    }

    #FIXME replace quotes dependent on var type
    cat(paste("\n",indent.str,"CASE WHEN ", column, " IN ('",
              paste(sql.escape(var.levels[goes.left]), collapse="', '"),
              "') THEN ", sep=""))
    recurse.rf(tree.data, tree.row[,"left daughter"], ind=(ind+1))
    cat(paste("\n",indent.str,"WHEN ", column, " IN ('",
              paste(sql.escape(var.levels[!goes.left]), collapse="', '"),
              "') THEN ", sep=""))
    recurse.rf(tree.data, tree.row[,"right daughter"], ind=(ind+1))
    cat(paste("\n", indent.str,"ELSE NULL END ", sep="")) #FIXME: null or a new category
    invisible(NULL)
  }

  if (model$type == "classification") {
    pred.type <- "VARCHAR"
  } else {
    pred.type <- "FLOAT"
  }

  if (variant == "teradata") {
    cat(paste("CREATE VOLATILE TABLE rf_predictions (\n",
              "\t",id," INT NOT NULL,\n",
              "\tpred ",pred.type,"\n",
              ") ON COMMIT PRESERVE ROWS;\n\n",
              "CREATE MULTISET VOLATILE TABLE tmp_rf (\n",
              "\t",id," INT NOT NULL,\n",
              "\tpred ",pred.type,"\n",
              ") ON COMMIT PRESERVE ROWS;\n\n",sep=""))
  } else {
    cat(paste("CREATE TABLE rf_predictions (\n",
              "\t",id," INT NOT NULL,\n",
              "\tpred ",pred.type,"\n",
              ");\n\n",
              "DROP TABLE IF EXISTS tmp_rf;\n\n",
              "CREATE TABLE tmp_rf (\n",
              "\t",id," INT NOT NULL,\n",
              "\tpred ",pred.type,"\n",
              ");\n\n",sep=""))
  }

  # One pass over the input table per tree.
  for (tree.num in seq_len(model$ntree)) {
    cat(paste("INSERT INTO tmp_rf\nSELECT ",id,",", sep=""))
    recurse.rf(randomForest::getTree(model, k=tree.num, labelVar=TRUE), 1)
    cat(paste("as tree",tree.num,"\nFROM ",input.table,";\n\n", sep=""))
  }

  if (model$type == "classification") {
    # This code is not optimal but many SQL implementations do not support window functions (eg. SQLite)
    # Had to remove use of WITH because not supported by all SQL variants
    # MIN(a.pred) + GROUP BY a.id keeps exactly one row per id when two or
    # more classes tie for the highest vote count.
    cat(paste("INSERT INTO rf_predictions\n",
              "SELECT a.id, MIN(a.pred)\n",
              "FROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) a\n",
              "INNER JOIN (SELECT id, MAX(cnt) as cnt\n",
              "\t\t\tFROM (SELECT ",id," as id, pred, COUNT(*) as cnt FROM tmp_rf GROUP BY ",id,", pred) cc\n",
              "\t\t\tGROUP BY id) b\n",
              "ON a.id = b.id AND a.cnt = b.cnt\n",
              "GROUP BY a.id;\n\n", sep=""))
  } else {
    cat(paste("INSERT INTO rf_predictions\n",
              "SELECT ",id,", AVG(pred)\n",
              "FROM tmp_rf\n",
              "GROUP BY ",id,";\n\n", sep=""))
  }

  if (variant == "teradata") {
    cat("DROP TABLE tmp_rf;\n\n")
  } else {
    cat("DROP TABLE IF EXISTS tmp_rf;\n\n")
  }

  # The output file is released by the on.exit() handler registered above.
  invisible(NULL)
}
@shanebutler
Copy link
Author

for anyone interested in doing something similar with xgboost, here you go: https://github.com/ras44/articles/blob/master/20181018_xgboost_scoring_via_sql.md

great, thankyou @ras44!

@nezorepla
Copy link

it gives an error because there are too many "case";

Msg 125, Level 15, State 4, Line 20
Case expressions may only be nested to level 10.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment