Skip to content

Instantly share code, notes, and snippets.

@shanebutler
Last active July 4, 2023 08:10
Show Gist options
  • Star 16 You must be signed in to star a gist
  • Fork 13 You must be signed in to fork a gist
  • Save shanebutler/5456942 to your computer and use it in GitHub Desktop.
Save shanebutler/5456942 to your computer and use it in GitHub Desktop.
Deploy your GBM models in SQL! This tool enables in-database scoring of GBM models built using R. To use it, simply call the function with the GBM model, the output filename, the SQL input data table and the name of the unique key on that table. For example: sql.export.gbm(gbm1, file="model_output.SQL", input.table="source_table", id="id") Please let…
# sql.export.gbm(): save a GBM model as SQL
# v0.11
# Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
#
# sql.export.gbm is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# sql.export.gbm is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with sql.export.gbm. If not, see <http://www.gnu.org/licenses/>.
#
#
## USAGE:
# sql.export.gbm(gbm1, file="model_output.SQL", input.table="data", id="id", n.trees=50)
#
## ARGUMENTS:
# trees.per.query: Optional argument to speed up scoring by batching multiple trees per query. It is recommended to set this higher than 1. Setting it too high, however, may cause the query to fail if the SQL engine can't handle the complexity.
# variant: Optional SQL dialect; set variant="teradata" to generate Teradata volatile tables (default "generic").
#
## NOTE:
# Due to a bug in predict.gbm() in versions prior to 2.0-9.5, the results may not align if factors are used and the new data has more or fewer levels than the training data.
# http://code.google.com/p/gradientboostedmodels/issues/detail?id=7
# http://stats.stackexchange.com/questions/23793/why-does-gbm-predict-different-values-for-the-same-data
sql.export.gbm <- function(g, file, input.table = "source_table", n.trees = NULL,
                           id = "id", trees.per.query = 1, variant = "generic") {
  # Write a SQL script that scores `input.table` with the GBM model `g` and
  # stores the results in table gbm_predictions(<id>, pred).
  #
  # Arguments:
  #   g               a fitted "gbm" object (from gbm() or gbm.fit()).
  #   file            path of the SQL script to write (overwritten).
  #   input.table     table holding the predictor columns; dots in R variable
  #                   names are mapped to underscores in column names.
  #   n.trees         number of trees to use (default: all trees in g).
  #   id              unique integer key column of input.table.
  #   trees.per.query maximum number of trees summed in one INSERT statement.
  #   variant         "teradata" for volatile tables; anything else emits
  #                   generic SQL.
  #
  # Returns `file` invisibly. Stops on non-gbm input, multinomial models and
  # out-of-range n.trees / trees.per.query; warns for untested distributions
  # and for models with an offset (the offset is not exported).

  # Validate everything before touching the output file.
  if (!inherits(g, "gbm")) {
    stop("Expected a GBM object")
  }
  dist.name <- g$distribution$name
  if (!is.character(dist.name) || length(dist.name) != 1L) {
    stop("Could not determine the distribution of the GBM model")
  }
  if (dist.name == "multinomial") {
    stop(paste("Unsupported distribution", dist.name))
  } else if (!(dist.name %in% c("gaussian", "bernoulli", "laplace"))) {
    # Untested but should work: "tdist","huberized","poisson","coxph","quantile","pairwise", "gamma","tweedie"
    warning(paste("Untested distribution", dist.name, "- please validate your model works as expected"))
  }
  # TODO: leverage gbm.perf() to find the optimal number of trees?
  if (is.null(n.trees)) {
    n.trees <- g$n.trees
  }
  if (!is.numeric(n.trees) || length(n.trees) != 1L || is.na(n.trees) ||
      n.trees < 1 || n.trees > g$n.trees) {
    stop(sprintf("n.trees must be between 1 and %d", as.integer(g$n.trees)))
  }
  if (!is.numeric(trees.per.query) || length(trees.per.query) != 1L ||
      is.na(trees.per.query) || trees.per.query < 1) {
    stop("trees.per.query must be a positive integer")
  }
  n.trees <- as.integer(n.trees)
  trees.per.query <- as.integer(trees.per.query)

  # pretty.gbm.tree() comes from gbm; fail loudly rather than relying on
  # require(), which only returns FALSE when the package is missing.
  if (!requireNamespace("gbm", quietly = TRUE)) {
    stop("Package 'gbm' is required to export the model")
  }

  # cat() prints only getOption("digits") significant digits, which shifts
  # split thresholds and leaf values; 17 significant digits reproduce any
  # double exactly.
  fmt.num <- function(x) sprintf("%.17g", x)
  # Factor levels become SQL string literals: double any embedded quote.
  sql.string <- function(x) paste0("'", gsub("'", "''", x, fixed = TRUE), "'")
  # Column names in input.table use "_" where R variable names use ".".
  sql.col <- function(name) gsub("[.]", "_", name)

  data.classes <- attr(g$Terms, "dataClasses")
  # TRUE when the split variable is an (unordered) factor. gbm.fit() models
  # have no Terms, so fall back to var.type (assumed: 0 = continuous/ordered,
  # > 0 = number of factor levels).
  is.categorical <- function(var.index, var.name) {
    if (!is.null(data.classes) && !is.na(data.classes[var.name])) {
      return(data.classes[[var.name]] == "factor")
    }
    isTRUE(g$var.type[var.index + 1] > 0)
  }

  # Emit one tree (as returned by pretty.gbm.tree) as nested CASE
  # expressions, starting at 0-based node `leaf`.
  recurse.gbm.tree <- function(tree, leaf = 0, indent = "") {
    if (leaf < 0) {
      return(invisible(NULL))
    }
    node <- tree[leaf + 1, ]
    if (node$SplitVar == -1) {
      # Terminal node: emit its prediction.
      cat(fmt.num(node$Prediction))
      return(invisible(NULL))
    }
    split.var <- node$SplitVar
    split.var.name <- g$var.names[split.var + 1]
    col <- sql.col(split.var.name)
    cat("\n", indent, "(CASE ")
    if (is.categorical(split.var, split.var.name)) {
      # SplitCodePred indexes the categorical split list, whose entries mark
      # each level as going left (-1) or right (1).
      categories <- unlist(g$var.levels[split.var + 1])
      directions <- g$c.split[[node$SplitCodePred + 1]]
      if (node$LeftNode != -1) {
        cat("WHEN", col, "in (",
            paste(sql.string(categories[directions == -1]), collapse = ","), ") THEN ")
        recurse.gbm.tree(tree, node$LeftNode, paste(indent, " "))
        indent <- paste(indent, " ")
      }
      if (node$RightNode != -1) {
        cat("\n", indent, "WHEN", col, "in (",
            paste(sql.string(categories[directions == 1]), collapse = ","), ") THEN ")
        recurse.gbm.tree(tree, node$RightNode, paste(indent, " "))
      }
      if (node$MissingNode != -1) {
        cat("\n", indent, "ELSE ")
        recurse.gbm.tree(tree, node$MissingNode, paste(indent, " "))
      }
    } else {
      threshold <- fmt.num(node$SplitCodePred)
      if (node$MissingNode != -1) {
        cat("WHEN", col, "IS NULL THEN ")
        recurse.gbm.tree(tree, node$MissingNode, paste(indent, " "))
        indent <- paste(indent, " ")
      }
      if (node$LeftNode != -1) {
        cat("\n", indent, "WHEN", col, "< ", threshold, " THEN ")
        recurse.gbm.tree(tree, node$LeftNode, paste(indent, " "))
      }
      if (node$RightNode != -1) {
        cat("\n", indent, "WHEN", col, ">= ", threshold, " THEN ")
        recurse.gbm.tree(tree, node$RightNode, paste(indent, " "))
      }
    }
    cat(" END)")
    invisible(NULL)
  }

  sink(file, type = "output")
  # Restore console output even if tree extraction fails part-way.
  on.exit(sink(type = "output"), add = TRUE)

  if (variant == "teradata") {
    cat("CREATE VOLATILE TABLE gbm_predictions (\n\t", id, " INT NOT NULL,\n\tpred float\n) ON COMMIT PRESERVE ROWS;\n\n", sep = "")
    cat("CREATE MULTISET VOLATILE TABLE tmp_gbm (\n\t", id, " INT NOT NULL,\n\tpred float\n) ON COMMIT PRESERVE ROWS;\n\n", sep = "")
  } else {
    cat("CREATE TABLE gbm_predictions (\n\t", id, " INT NOT NULL,\n\tpred float\n);\n\n", sep = "")
    cat("DROP TABLE IF EXISTS tmp_gbm;\n\n")
    cat("CREATE TABLE tmp_gbm (\n\t", id, " INT NOT NULL,\n\tpred float\n);\n\n", sep = "")
  }

  # Split trees 1..n.trees into consecutive batches of at most
  # trees.per.query trees; each batch becomes one INSERT. This also works
  # when n.trees == 1.
  batch.starts <- seq.int(1L, n.trees, by = trees.per.query)
  batch.ends <- pmin(batch.starts + trees.per.query - 1L, n.trees)
  for (batch in seq_along(batch.starts)) {
    cat("INSERT INTO tmp_gbm\nSELECT ", id, ",", sep = "")
    for (which.tree in batch.starts[batch]:batch.ends[batch]) {
      recurse.gbm.tree(gbm::pretty.gbm.tree(g, which.tree), 0, "")
      cat(" + ")
    }
    cat("0 as tree", batch, "\n", sep = "")
    cat("FROM ", input.table, ";\n\n", sep = "")
  }

  # Add the initial value and apply the distribution's inverse link.
  init.f <- fmt.num(g$initF)
  pred.expr <- if (dist.name %in% c("bernoulli", "pairwise")) {
    paste0("1/(1 + exp(-(", init.f, " + SUM(pred))))")
  } else if (dist.name %in% c("poisson", "gamma", "tweedie")) {
    paste0("exp(", init.f, "+ SUM(pred))")
  } else if (dist.name == "adaboost") {
    paste0("1 /(1 + exp(-2*(", init.f, " + SUM(pred))))")
  } else {
    paste0("(", init.f, " + SUM(pred))")
  }
  cat("INSERT INTO gbm_predictions\nSELECT ", id, ", ", pred.expr,
      " as pred\nFROM tmp_gbm\nGROUP BY ", id, ";\n\n", sep = "")

  if (variant == "teradata") {
    cat("DROP TABLE tmp_gbm;\n\n")
  } else {
    cat("DROP TABLE IF EXISTS tmp_gbm;\n\n")
  }

  if (!is.null(attr(g$Terms, "offset"))) {
    warning("offset not implemented")
  }
  invisible(file)
}
@Rambeaux
Copy link

Brilliant!
Now for a SAS exporter... :-)

@Rambeaux
Copy link

I notice also that it works for the gbm() function with formula specified, but does not work for gbm.fit().

@shanebutler
Copy link
Author

Thanks Nathaniel, I hadn't noticed the difference between gbm() and gbm.fit()

@shanebutler
Copy link
Author

18/01/14: New version with a fix suggested by Chris Pouliot for batching several trees per query

@az0
Copy link

az0 commented Oct 29, 2014

@Rambeaux Please see https://github.com/az0/mlmeta for a GBM to SAS exporter

@shanebutler: Thank you for the code. Do you mind if I add sql.export.gbm.R and sql.export.randomForest.R to the mlmeta repository under the license GNU GPL version 2 or later? Eventually I would like to put the package on CRAN too. If you do not mind, let me know whether you want to do a pull request, or whether I should just copy it myself. Thank you

@shanebutler
Copy link
Author

@az0: Thanks for getting in touch, your package sounds interesting. You are free to incorporate this code in your package under the terms of the license. Cheers, Shane

@YvesCR
Copy link

YvesCR commented Jan 11, 2016

Thanks. One small remark: I replaced attr(g$Terms, "dataClasses")[[split.var.name]] == "factor" with is.factor(my.table[split.var.name]) and it worked for gbm.fit.

@cturner500
Copy link

Awesome, thanks for saving me a day. :) :) Works great with adaboost distribution stumps, mixed categorical and continuous for me, btw.

@mohit879
Copy link

mohit879 commented Jun 7, 2017

Great work Shane, What to do in case of Multinomial Distribution?

@mohit879
Copy link

Hey Shane, do you have it written in Python as well ?
Thanks!

@Spectrum2511
Copy link

Hi Shane, having some trouble with your script. What exactly is the "id" argument supposed to be in regards to the model/data?

@shanebutler
Copy link
Author

Hi @Spectrum2511 , "id" is a unique integer ID on the input table that will also be the key on the output table gbm_predictions . Make sense?

@Spectrum2511
Copy link

Thank you @shanebutler, makes perfect sense!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment