Skip to content

Instantly share code, notes, and snippets.

@jonrobinson2
Forked from shanebutler/sql.export.gbm.R
Last active August 29, 2015 14:07
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jonrobinson2/0dcc99cffe12a106447a to your computer and use it in GitHub Desktop.
Save jonrobinson2/0dcc99cffe12a106447a to your computer and use it in GitHub Desktop.
# sql.export.gbm(): save a GBM model as SQL
# v0.11
# Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
#
# sql.export.gbm is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# sql.export.gbm is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with sql.export.gbm. If not, see <http://www.gnu.org/licenses/>.
#
#
## USAGE:
# sql.export.gbm(gbm1, file="model_output.SQL", input.table="data", id="id", n.trees=50)
#
## ARGUMENTS:
# trees.per.query: Optional argument that speeds up scoring by batching multiple trees into each query. Setting it higher than 1 is recommended; setting it too high, however, may cause the query to fail if the SQL engine cannot handle its complexity.
# variant: Optional SQL dialect. Use variant="teradata" to generate Teradata volatile tables; any other value produces generic SQL.
#
## NOTE:
# Due to a bug in predict.gbm() in versions prior to 2.0-9.5, the results may not align if factors are used and the new data has more or fewer levels than the training data.
# http://code.google.com/p/gradientboostedmodels/issues/detail?id=7
# http://stats.stackexchange.com/questions/23793/why-does-gbm-predict-different-values-for-the-same-data
# Write a SQL script that scores every row of `input.table` with the first
# `n.trees` trees of the gbm model `g`.
#
# The script creates `gbm_predictions` (id, pred) and a scratch table
# `tmp_gbm`. Each INSERT into `tmp_gbm` adds up the leaf values of a batch of
# at most `trees.per.query` trees. The final INSERT sums those partial scores
# per id, adds g$initF and applies the inverse link for the distribution
# (logistic for bernoulli/pairwise, exp for poisson/gamma/tweedie,
# 1/(1+exp(-2f)) for adaboost, identity otherwise).
#
# Arguments:
#   g               fitted object inheriting from "gbm"; multinomial models
#                   are rejected.
#   file            destination passed to sink().
#   input.table     table holding the id column and the predictor columns
#                   (predictor names have "." replaced by "_").
#   n.trees         number of trees to export; defaults to g$n.trees.
#   id              name of the key column (declared INT NOT NULL in the
#                   generated tables).
#   trees.per.query maximum number of trees summed in one INSERT statement.
#   variant         "teradata" for Teradata volatile tables; any other value
#                   produces generic SQL.
# Called for its side effect of writing `file`; the output diversion is
# always closed again, even if an error occurs part-way through.
sql.export.gbm <- function(g, file, input.table = "source_table", n.trees = NULL,
                           id = "id", trees.per.query = 1, variant = "generic") {
  if (!inherits(g, "gbm")) {
    stop("Expected a GBM object", call. = FALSE)
  }
  if (g$distribution$name == "multinomial") {
    stop(paste("Unsupported distribution ", g$distribution$name), call. = FALSE)
  } else if (!(g$distribution$name %in% c("gaussian", "bernoulli", "laplace"))) {
    # Untested but should work: "tdist","huberized","poisson","coxph","quantile","pairwise", "gamma","tweedie"
    warning(paste("Untested distribution", g$distribution$name, "- please validate your model works as expected"))
  }
  # TODO: leverage gbm.perf() to find the optimal number of trees?
  if (is.null(n.trees)) {
    n.trees <- g$n.trees
  }
  # Validate before opening the output file so bad arguments fail cleanly.
  if (length(n.trees) != 1L || is.na(n.trees) || n.trees < 1 || n.trees > g$n.trees) {
    stop("n.trees must be a single number between 1 and ", g$n.trees, call. = FALSE)
  }
  if (length(trees.per.query) != 1L || is.na(trees.per.query) || trees.per.query < 1) {
    stop("trees.per.query must be a single number >= 1", call. = FALSE)
  }
  if (!requireNamespace("gbm", quietly = TRUE)) {
    stop("Package 'gbm' is required to read the model's trees", call. = FALSE)
  }

  # Numeric literals are written with 15 significant digits. cat() would
  # otherwise round to getOption("digits") (7 by default), shifting split
  # thresholds so rows near a split could be routed differently than in R.
  fmt.num <- function(x) sprintf("%.15g", x)
  # SQL column name for a model variable ("." is replaced by "_").
  sql.col <- function(name) gsub("[.]", "_", name)
  # Comma-separated single-quoted literals; embedded quotes are doubled so
  # factor levels containing "'" still produce valid SQL.
  sql.in.list <- function(x) {
    paste0("'", gsub("'", "''", x, fixed = TRUE), "'", collapse = ",")
  }

  sink(file, type = "output")
  # Always restore console output, even if writing fails part-way through.
  on.exit(sink(), add = TRUE)

  if (variant == "teradata") {
    cat("CREATE VOLATILE TABLE gbm_predictions (\n\t", id, " INT NOT NULL,\n\tpred float\n) ON COMMIT PRESERVE ROWS;\n\n", sep = "")
    cat("CREATE MULTISET VOLATILE TABLE tmp_gbm (\n\t", id, " INT NOT NULL,\n\tpred float\n) ON COMMIT PRESERVE ROWS;\n\n", sep = "")
  } else {
    cat("CREATE TABLE gbm_predictions (\n\t", id, " INT NOT NULL,\n\tpred float\n);\n\n", sep = "")
    cat("DROP TABLE IF EXISTS tmp_gbm;\n\n")
    cat("CREATE TABLE tmp_gbm (\n\t", id, " INT NOT NULL,\n\tpred float\n);\n\n", sep = "")
  }

  # Emit a nested CASE expression for the subtree of `tree` (a data frame
  # from gbm::pretty.gbm.tree()) rooted at the 0-based node id `leaf`.
  recurse.gbm.tree <- function(tree, leaf = 0, indent = "") {
    if (leaf < 0) {
      return(invisible(NULL))
    }
    row <- leaf + 1  # node ids are 0-based, data frame rows are 1-based
    split.var <- tree[row, "SplitVar"]
    if (split.var == -1) {
      # Terminal node: its prediction is this tree's contribution.
      cat(fmt.num(tree[row, "Prediction"]))
      return(invisible(NULL))
    }
    split.var.name <- g$var.names[split.var + 1]
    col <- sql.col(split.var.name)
    left.node <- tree[row, "LeftNode"]
    right.node <- tree[row, "RightNode"]
    missing.node <- tree[row, "MissingNode"]
    child.indent <- paste(indent, " ")

    cat("\n", indent, "(CASE ")
    # NOTE(review): only class "factor" gets an IN-list split; ordered factors
    # fall through to the numeric comparison below — confirm this matches how
    # the input table stores them.
    if (attr(g$Terms, "dataClasses")[[split.var.name]] == "factor") {
      # Categorical split: SplitCodePred indexes g$c.split, whose entries
      # are -1 for levels sent left and 1 for levels sent right.
      categories <- unlist(g$var.levels[split.var + 1])
      direction <- g$c.split[[tree[row, "SplitCodePred"] + 1]]
      if (left.node != -1) {
        cat("WHEN", col, "in (", sql.in.list(categories[direction == -1]), ") THEN ")
        recurse.gbm.tree(tree, left.node, child.indent)
      }
      if (right.node != -1) {
        cat("\n", child.indent, "WHEN", col, "in (", sql.in.list(categories[direction == 1]), ") THEN ")
        recurse.gbm.tree(tree, right.node, child.indent)
      }
      if (missing.node != -1) {
        # NULLs and levels in neither list fall through to the missing branch.
        cat("\n", child.indent, "ELSE ")
        recurse.gbm.tree(tree, missing.node, child.indent)
      }
    } else {
      split.value <- fmt.num(tree[row, "SplitCodePred"])
      if (missing.node != -1) {
        cat("WHEN", col, "IS NULL THEN ")
        recurse.gbm.tree(tree, missing.node, child.indent)
      }
      if (left.node != -1) {
        cat("\n", child.indent, "WHEN", col, "< ", split.value, " THEN ")
        recurse.gbm.tree(tree, left.node, child.indent)
      }
      if (right.node != -1) {
        cat("\n", child.indent, "WHEN", col, ">= ", split.value, " THEN ")
        recurse.gbm.tree(tree, right.node, child.indent)
      }
    }
    cat(" END)")
    invisible(NULL)
  }

  # Partition trees 1..n.trees into consecutive batches of at most
  # trees.per.query trees. Each batch becomes one INSERT into tmp_gbm.
  # This also handles n.trees == 1, where the old break-point logic
  # indexed past the end.
  first.trees <- seq(1, n.trees, by = trees.per.query)
  last.trees <- pmin(first.trees + trees.per.query - 1, n.trees)
  for (grp in seq_along(first.trees)) {
    cat("INSERT INTO tmp_gbm\nSELECT ", id, ",", sep = "")
    for (which.tree in first.trees[grp]:last.trees[grp]) {
      recurse.gbm.tree(gbm::pretty.gbm.tree(g, which.tree), 0, "")
      cat(" + ")
    }
    cat("0 as tree", grp, "\n", sep = "")
    cat("FROM ", input.table, ";\n\n", sep = "")
  }

  # Combine batches: add the initial value, then map the summed link-scale
  # score to the response scale for this distribution.
  init <- fmt.num(g$initF)
  dist <- g$distribution$name
  pred.expr <- if (dist %in% c("bernoulli", "pairwise")) {
    paste0("1/(1 + exp(-(", init, " + SUM(pred))))")
  } else if (dist %in% c("poisson", "gamma", "tweedie")) {
    paste0("exp(", init, " + SUM(pred))")
  } else if (dist %in% c("adaboost")) {
    paste0("1 /(1 + exp(-2*(", init, " + SUM(pred))))")
  } else {
    paste0("(", init, " + SUM(pred))")
  }
  cat("INSERT INTO gbm_predictions\nSELECT ", id, ", ", pred.expr, " as pred\nFROM tmp_gbm\nGROUP BY ", id, ";\n\n", sep = "")

  if (variant == "teradata") {
    cat("DROP TABLE tmp_gbm;\n\n")
  } else {
    cat("DROP TABLE IF EXISTS tmp_gbm;\n\n")
  }
  if (!is.null(attr(g$Terms, "offset"))) {
    warning("offset not implemented")
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment