-
-
Save shanebutler/5456942 to your computer and use it in GitHub Desktop.
# sql.export.gbm(): save a GBM model as SQL
# v0.11
# Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
#
# sql.export.gbm is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# sql.export.gbm is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with sql.export.gbm. If not, see <http://www.gnu.org/licenses/>.
#
#
## USAGE:
# sql.export.gbm(gbm1, file="model_output.SQL", input.table="data", id="id", n.trees=50)
#
## ARGUMENTS:
# id: Name of a unique integer ID column on the input table; it is also the key of the output table gbm_predictions.
# trees.per.query: Optional argument to speed up scoring by batching multiple trees per query. It is recommended to set this higher than 1. Setting it too high, however, may cause the query to fail if the SQL engine can't handle the complexity.
# variant: Optional argument; use variant="teradata" to generate Teradata SQL.
#
## NOTE:
# Due to a bug in predict.gbm() in versions prior to 2.0-9.5, the results may not align if factors are used and the new data has more or fewer levels.
# http://code.google.com/p/gradientboostedmodels/issues/detail?id=7
# http://stats.stackexchange.com/questions/23793/why-does-gbm-predict-different-values-for-the-same-data
#' Export a fitted GBM model as a SQL scoring script.
#'
#' Writes SQL to `file` that creates the output table `gbm_predictions` and
#' a scratch table `tmp_gbm`. It then inserts the summed tree contributions
#' for every row of `input.table` into `tmp_gbm`, in batches of
#' `trees.per.query` trees per INSERT. Finally it aggregates those
#' contributions plus the model's `initF` per `id` into `gbm_predictions`.
#' The logistic transform is applied for bernoulli/pairwise, exp() for
#' poisson/gamma/tweedie, and 1/(1+exp(-2f)) for adaboost; other
#' distributions get the raw score.
#'
#' @param g A fitted object of class "gbm".
#' @param file Destination handed to sink(): a file path or connection.
#' @param input.table Name of the SQL table holding the rows to score.
#' @param n.trees Number of trees to export; defaults to `g$n.trees`. Must be
#'   between 1 and `g$n.trees`.
#' @param id Name of a unique integer ID column on `input.table`; it is also
#'   the key column of `gbm_predictions`.
#' @param trees.per.query Number of trees summed inside one INSERT (>= 1).
#' @param variant "teradata" for Teradata volatile tables; any other value
#'   produces generic SQL.
#' @return `file`, invisibly. Called for its side effect of writing SQL.
sql.export.gbm <- function(g, file, input.table = "source_table",
                           n.trees = NULL, id = "id", trees.per.query = 1,
                           variant = "generic") {
  # Validate everything before touching the output, so a bad call never
  # leaves a half-written file or an active sink behind.
  if (!inherits(g, "gbm")) {
    stop("Expected a GBM object")
  }
  dist.name <- g$distribution$name
  if (dist.name == "multinomial") {
    stop(paste("Unsupported distribution ", dist.name))
  } else if (!(dist.name %in% c("gaussian", "bernoulli", "laplace"))) {
    # Untested but should work: "tdist", "huberized", "poisson", "coxph",
    # "quantile", "pairwise", "gamma", "tweedie"
    warning(paste("Untested distribution", dist.name,
                  "- please validate your model works as expected"))
  }

  # TODO: leverage gbm.perf() to find the optimal number of trees?
  max.trees <- g[["n.trees"]]
  if (is.null(n.trees)) {
    n.trees <- max.trees
  }
  if (!is.numeric(n.trees) || length(n.trees) != 1L || is.na(n.trees) ||
      n.trees < 1 || (!is.null(max.trees) && n.trees > max.trees)) {
    stop("n.trees must be a single number between 1 and the number of ",
         "trees in the model", call. = FALSE)
  }
  if (!is.numeric(trees.per.query) || length(trees.per.query) != 1L ||
      is.na(trees.per.query) || trees.per.query < 1) {
    stop("trees.per.query must be a single number >= 1", call. = FALSE)
  }
  n.trees <- as.integer(n.trees)
  trees.per.query <- as.integer(trees.per.query)

  if (!is.null(attr(g[["Terms"]], "offset"))) {
    warning("offset not implemented")
  }
  if (!requireNamespace("gbm", quietly = TRUE)) {
    stop("Package 'gbm' is required to read the model's trees", call. = FALSE)
  }

  # Print a number with 17 significant digits so the double round-trips
  # exactly; plain cat() truncates to getOption("digits") (7 by default),
  # which shifts split thresholds and leaf values in the generated SQL.
  fmt.num <- function(x) sprintf("%.17g", x)

  # Quote factor levels as SQL string literals, doubling embedded quotes.
  sql.literal <- function(x) {
    paste0("'", gsub("'", "''", x, fixed = TRUE), "'")
  }

  # TRUE when the 0-based split variable should be split as an unordered
  # factor. Prefer g$var.type (non-zero for factors in gbm's model objects,
  # and present for gbm.fit() models, which carry no Terms); otherwise fall
  # back to the formula's dataClasses as the original code did.
  is.factor.var <- function(split.var) {
    var.type <- g[["var.type"]]
    if (!is.null(var.type)) {
      return(var.type[split.var + 1] > 0)
    }
    classes <- attr(g[["Terms"]], "dataClasses")
    isTRUE(unname(classes[g[["var.names"]][split.var + 1]]) == "factor")
  }

  # Build the SQL expression for the subtree rooted at the 0-based node index
  # `node` of a pretty.gbm.tree() data frame (row node + 1).
  node.sql <- function(tree, node, indent) {
    row <- tree[node + 1, ]
    split.var <- row[["SplitVar"]]
    if (split.var == -1) {
      # Terminal node: emit its contribution to the score.
      return(fmt.num(row[["Prediction"]]))
    }
    # Dots in R variable names become underscores in SQL column names.
    col <- gsub("[.]", "_", g[["var.names"]][split.var + 1])
    if (is.factor.var(split.var)) {
      # SplitCodePred indexes c.splits; levels marked -1 go left, 1 go
      # right, and anything else (or NULL) falls through to ELSE.
      levs <- g[["var.levels"]][[split.var + 1]]
      dirs <- g[["c.splits"]][[row[["SplitCodePred"]] + 1]]
      in.list <- function(side) {
        paste(sql.literal(levs[dirs == side]), collapse = ",")
      }
      heads <- c(paste0("WHEN ", col, " in (", in.list(-1), ") THEN "),
                 paste0("WHEN ", col, " in (", in.list(1), ") THEN "),
                 "ELSE ")
      kids <- c(row[["LeftNode"]], row[["RightNode"]], row[["MissingNode"]])
    } else {
      threshold <- fmt.num(row[["SplitCodePred"]])
      heads <- c(paste0("WHEN ", col, " IS NULL THEN "),
                 paste0("WHEN ", col, " < ", threshold, " THEN "),
                 paste0("WHEN ", col, " >= ", threshold, " THEN "))
      kids <- c(row[["MissingNode"]], row[["LeftNode"]], row[["RightNode"]])
    }
    child.indent <- paste0(indent, "  ")
    # A child index of -1 means the branch is absent; skip it.
    present <- which(kids != -1)
    clauses <- vapply(present, function(k) {
      paste0(heads[k], node.sql(tree, kids[k], child.indent))
    }, character(1))
    paste0("\n", indent, "(CASE ",
           paste(clauses, collapse = paste0("\n", child.indent)),
           " END)")
  }

  emit <- function(...) cat(..., sep = "")

  sink(file, type = "output")
  # Restore console output even if tree extraction fails midway.
  on.exit(sink(type = "output"), add = TRUE)

  if (variant == "teradata") {
    emit("CREATE VOLATILE TABLE gbm_predictions (\n\t", id, " INT NOT NULL,\n\tpred float\n) ON COMMIT PRESERVE ROWS;\n\n")
    emit("CREATE MULTISET VOLATILE TABLE tmp_gbm (\n\t", id, " INT NOT NULL,\n\tpred float\n) ON COMMIT PRESERVE ROWS;\n\n")
  } else {
    emit("CREATE TABLE gbm_predictions (\n\t", id, " INT NOT NULL,\n\tpred float\n);\n\n")
    emit("DROP TABLE IF EXISTS tmp_gbm;\n\n")
    emit("CREATE TABLE tmp_gbm (\n\t", id, " INT NOT NULL,\n\tpred float\n);\n\n")
  }

  # One INSERT per batch of up to trees.per.query consecutive trees; the
  # final batch is truncated at n.trees (works for n.trees == 1 too).
  first.trees <- seq.int(1L, n.trees, by = trees.per.query)
  for (grp in seq_along(first.trees)) {
    last.tree <- min(first.trees[grp] + trees.per.query - 1L, n.trees)
    emit("INSERT INTO tmp_gbm\nSELECT ", id, ",")
    for (which.tree in first.trees[grp]:last.tree) {
      emit(node.sql(gbm::pretty.gbm.tree(g, which.tree), 0, ""), " + ")
    }
    emit("0 as tree", grp, "\n")
    emit("FROM ", input.table, ";\n\n")
  }

  # Sum per-tree contributions with initF and apply the response transform.
  score <- paste0(fmt.num(g$initF), " + SUM(pred)")
  pred.sql <- if (dist.name %in% c("bernoulli", "pairwise")) {
    paste0("1/(1 + exp(-(", score, ")))")
  } else if (dist.name %in% c("poisson", "gamma", "tweedie")) {
    paste0("exp(", score, ")")
  } else if (dist.name == "adaboost") {
    paste0("1 /(1 + exp(-2*(", score, ")))")
  } else {
    paste0("(", score, ")")
  }
  emit("INSERT INTO gbm_predictions\nSELECT ", id, ", ", pred.sql,
       " as pred\nFROM tmp_gbm\nGROUP BY ", id, ";\n\n")

  if (variant == "teradata") {
    emit("DROP TABLE tmp_gbm;\n\n")
  } else {
    emit("DROP TABLE IF EXISTS tmp_gbm;\n\n")
  }
  invisible(file)
}
I also noticed that it works for the gbm() function with a formula specified, but not for gbm.fit().
Thanks Nathaniel, I hadn't noticed the difference between gbm() and gbm.fit()
18/01/14: New version with a fix suggested by Chris Pouliot for batching several trees per query.
@Rambeaux Please see https://github.com/az0/mlmeta for a GBM to SAS exporter
@shanebutler: Thank you for the code. Do you mind if I add sql.export.gbm.R and sql.export.randomForest.R to the mlmeta repository under the license GNU GPL version 2 or later? Eventually I would like to put the package on CRAN too. If you do not mind, let me know whether you want to do a pull request, or whether I should just copy it myself. Thank you
@az0: Thanks for getting in touch, your package sounds interesting. You are free to incorporate this code in your package under the terms of the license. Cheers, Shane
Thanks. One small remark, however: I replaced attr(g$Terms, "dataClasses")[[split.var.name]] == "factor" with is.factor(my.table[split.var.name]) and it worked for gbm.fit.
Awesome, thanks for saving me a day. :) :) Works great with adaboost distribution stumps, mixed categorical and continuous for me, btw.
Great work Shane, What to do in case of Multinomial Distribution?
Hey Shane, do you have it written in Python as well ?
Thanks!
Hi Shane, having some trouble with your script. What exactly is the "id" argument supposed to be in regards to the model/data?
Hi @Spectrum2511 , "id" is a unique integer ID on the input table that will also be the key on the output table gbm_predictions . Make sense?
Thank you @shanebutler, makes perfect sense!
Brilliant!
Now for a SAS exporter... :-)