# sql.export.rf(): save a randomForest model as SQL | |
# v0.04 | |
# Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com> | |
# | |
# sql.export.rf is free software: you can redistribute it and/or modify it | |
# under the terms of the GNU General Public License as published by | |
# the Free Software Foundation, either version 2 of the License, or | |
# (at your option) any later version. | |
# | |
# sql.export.rf is distributed in the hope that it will be useful, but | |
# WITHOUT ANY WARRANTY; without even the implied warranty of | |
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU | |
# General Public License for more details. | |
# | |
# You should have received a copy of the GNU General Public License | |
# along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>. | |
# | |
# | |
## NOTE:
# This code generates SQL scoring code from your randomForest model.
# The generated code is not optimal: it makes as many passes over the
# input data as there are trees (i.e. a 500-tree forest produces 500
# INSERT ... SELECT statements).
# | |
## USAGE: | |
# sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id") | |
# | |
## ARGUMENTS:
# variant: optional; set variant="teradata" to emit Teradata SQL
#          (VOLATILE tables). Any other value emits generic SQL.
# | |
# sql.export.rf: write SQL that scores `input.table` with a randomForest model.
#
# Arguments:
#   model        a randomForest object (type "classification" or "regression").
#   file         destination for the generated SQL; anything sink() accepts
#                (a file name or a connection).
#   input.table  table holding the predictors. Its column names must equal the
#                model's variable names with "." replaced by "_".
#   id           name of the row key column in input.table (declared INT).
#   variant      "teradata" emits Teradata VOLATILE tables; any other value
#                emits generic SQL.
#
# The generated script creates rf_predictions(<id>, pred) and a scratch table
# tmp_rf, inserts one row per input row per tree into tmp_rf, aggregates the
# per-tree results (majority class for classification, mean for regression)
# into rf_predictions, and finally drops tmp_rf.
#
# Returns NULL invisibly; called for its side effect of writing `file`.
sql.export.rf <- function(model, file, input.table = "source_table",
                          id = "id",
                          variant = "generic") {
  # Attaches randomForest so that getTree() below can be found.
  require(randomForest, quietly = TRUE)
  if (!inherits(model, "randomForest")) {
    stop("Expected a randomForest object")
  }
  teradata <- identical(variant, "teradata")

  # Render values as single-quoted SQL string literals, doubling embedded
  # quotes so labels such as "o'range" cannot break the generated SQL.
  sql.quote <- function(x) {
    paste0("'", gsub("'", "''", as.character(x), fixed = TRUE), "'")
  }

  # Column name used in SQL for a model variable ("." -> "_").
  sql.col <- function(var.name) {
    gsub(".", "_", var.name, fixed = TRUE)
  }

  # randomForest encodes a categorical split as an integer whose bit k-1 is
  # set when category k is sent to the left daughter. Returns one logical per
  # category (TRUE = goes left). Plain arithmetic keeps this exact for up to
  # 53 categories, unlike the 32-bit bitwAnd().
  goes.left <- function(ncat, split.point) {
    (split.point %/% 2^(seq_len(ncat) - 1)) %% 2 == 1
  }

  # Recursively emit a nested CASE expression for the subtree rooted at row
  # `row.num` of `tree.data` (a getTree(..., labelVar = TRUE) data frame).
  write.node <- function(tree.data, row.num, depth = 0) {
    node <- tree.data[row.num, ]
    pad <- strrep("\t", depth)

    if (node[, "status"] == -1) {
      # Terminal node. paste0() converts via as.character(), keeping 15
      # significant digits for regression predictions (cat() would use 7).
      pred <- node[, "prediction"]
      if (is.numeric(tree.data$prediction)) {
        cat(paste0(pred, " "))
      } else {
        cat(paste0(sql.quote(pred), " "))
      }
      return(invisible(NULL))
    }

    split.var <- as.character(node[, "split var"])
    col <- sql.col(split.var)
    split.point <- node[, "split point"]
    var.levels <- model$forest$xlevels[[split.var]]

    if (is.numeric(var.levels)) {
      # Numeric split: value <= split point goes left; NULL inputs score NULL.
      cat(paste0("\n", pad, "CASE WHEN ", col, " IS NULL THEN NULL",
                 "\n", pad, "WHEN ", col, " <= ", split.point, " THEN "))
      write.node(tree.data, node[, "left daughter"], depth + 1)
      cat(paste0("\n", pad, "ELSE "))
      write.node(tree.data, node[, "right daughter"], depth + 1)
      cat("END ")
    } else {
      # Categorical split. FIXME: values are always quoted, whatever the
      # column's SQL type. Categories outside both lists score NULL.
      left <- goes.left(model$forest$ncat[[split.var]], split.point)
      cat(paste0("\n", pad, "CASE WHEN ", col, " IN (",
                 paste(sql.quote(var.levels[left]), collapse = ", "),
                 ") THEN "))
      write.node(tree.data, node[, "left daughter"], depth + 1)
      cat(paste0("\n", pad, "WHEN ", col, " IN (",
                 paste(sql.quote(var.levels[!left]), collapse = ", "),
                 ") THEN "))
      write.node(tree.data, node[, "right daughter"], depth + 1)
      cat(paste0("\n", pad, "ELSE NULL END "))
    }
    invisible(NULL)
  }

  pred.type <- if (model$type == "classification") "VARCHAR" else "FLOAT"
  table.cols <- paste0("\t", id, " INT NOT NULL,\n",
                       "\tpred ", pred.type, "\n")

  sink(file, type = "output")
  # Undo the diversion even if generation fails part-way; otherwise all
  # later console output would keep going to `file`.
  on.exit(sink(type = "output"), add = TRUE)

  if (teradata) {
    cat(paste0("CREATE VOLATILE TABLE rf_predictions (\n", table.cols,
               ") ON COMMIT PRESERVE ROWS;\n\n",
               "CREATE MULTISET VOLATILE TABLE tmp_rf (\n", table.cols,
               ") ON COMMIT PRESERVE ROWS;\n\n"))
  } else {
    cat(paste0("CREATE TABLE rf_predictions (\n", table.cols, ");\n\n",
               "DROP TABLE IF EXISTS tmp_rf;\n\n",
               "CREATE TABLE tmp_rf (\n", table.cols, ");\n\n"))
  }

  # One INSERT ... SELECT (one pass over input.table) per tree.
  for (tree.num in seq_len(model$ntree)) {
    cat(paste0("INSERT INTO tmp_rf\nSELECT ", id, ","))
    write.node(getTree(model, k = tree.num, labelVar = TRUE), 1)
    cat(paste0("as tree", tree.num, "\nFROM ", input.table, ";\n\n"))
  }

  if (model$type == "classification") {
    # Majority vote. Window functions and WITH are avoided because not every
    # target (e.g. SQLite) supports them. The outer GROUP BY with MIN() picks
    # one class deterministically when several tie for the most votes;
    # without it a tied id would get one rf_predictions row per tied class.
    vote.counts <- paste0("SELECT ", id, " as id, pred, COUNT(*) as cnt ",
                          "FROM tmp_rf GROUP BY ", id, ", pred")
    cat(paste0("INSERT INTO rf_predictions\n",
               "SELECT a.id, MIN(a.pred)\n",
               "FROM (", vote.counts, ") a\n",
               "INNER JOIN (SELECT id, MAX(cnt) as cnt\n",
               "\t\t\tFROM (", vote.counts, ") cc\n",
               "\t\t\tGROUP BY id) b\n",
               "ON a.id = b.id AND a.cnt = b.cnt\n",
               "GROUP BY a.id;\n\n"))
  } else {
    # Regression: average the per-tree predictions.
    cat(paste0("INSERT INTO rf_predictions\n",
               "SELECT ", id, ", AVG(pred)\n",
               "FROM tmp_rf\n",
               "GROUP BY ", id, ";\n\n"))
  }

  if (teradata) {
    cat("DROP TABLE tmp_rf;\n\n")
  } else {
    cat("DROP TABLE IF EXISTS tmp_rf;\n\n")
  }
  invisible(NULL)
}
This comment has been minimized.
This comment has been minimized.
@pjankiewicz a function would probably be more efficient; however, there is no standard for SQL functions that works across multiple platforms.
This comment has been minimized.
This comment has been minimized.
That is very useful. However, how can I get the class probabilities instead of the 0/1 output?
This comment has been minimized.
This comment has been minimized.
How do you get the probabilities from Random Forest as the member levels or record levels? |
This comment has been minimized.
This comment has been minimized.
How do you get the probabilities from Random Forest as the member levels or record levels in R I meant? |
This comment has been minimized.
This comment has been minimized.
for anyone interested in doing something similar with xgboost, here you go: https://github.com/ras44/articles/blob/master/20181018_xgboost_scoring_via_sql.md |
This comment has been minimized.
This comment has been minimized.
Great, thank you @ras44!
This comment has been minimized.
This comment has been minimized.
it gives an error because there are too many "case"; Msg 125, Level 15, State 4, Line 20 |
This comment has been minimized.
I was looking for a way to export an R randomForest object. This is a smart approach. It would probably be more convenient to create an SQL function that accepts the variables as arguments. Ultimately, I would recommend using the pmml R package to export the randomForest model to a standardized XML file.