Created
July 3, 2014 00:28
-
-
Save lendle/ec3d92dca08a8de7f139 to your computer and use it in GitHub Desktop.
Create SuperLearner wrappers that run caret models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create SuperLearner wrappers that run caret models.
#
# For each <name> in `alg_names`, a wrapper function named SL.caret.<name> is
# created and stored in a new environment.
#
# Args:
#   alg_names: names of models provided by the caret package
#     (see http://caret.r-forge.r-project.org/modelList.html).
#
# Returns:
#   An environment holding one SL.caret.<name> function per entry of
#   `alg_names`. Each wrapper has the signature
#   function(Y, X, newX, family, obsWeights, id) and delegates to SL.caret().
#
# Notes:
#   * Only set up for the binomial family (two-class outcomes).
#   * caret is configured to use a single train/validation split internally;
#     change `fit_control` below to use a different resampling scheme.
make_caret_wrappers <- function(alg_names) {
  require("caret")
  require("SuperLearner")

  # caret summary function: model selection inside caret is driven by the
  # "nloglik" entry (the Bernoulli log likelihood scaled by -2), and the
  # output of twoClassSummary() is appended to it.
  nbll_summary <- function(data, lev = NULL, model = NULL) {
    if (length(lev) > 2) {
      stop("negative Bernoulli log likelihood is only for 2 classes")
    }
    is_class1 <- ifelse(data$obs == lev[1], 1, 0)
    # Clip probabilities to [eps, 1 - eps] so log() never receives 0 or 1.
    eps <- 0.01
    # The column named after the first level is read as the predicted
    # probability of that level.
    prob_class1 <- pmin(pmax(data[, lev[1]], eps), 1 - eps)
    log_lik <- is_class1 * log(prob_class1) +
      (1 - is_class1) * log(1 - prob_class1)
    c(nloglik = -2 * mean(log_lik), twoClassSummary(data, lev = lev))
  }

  # Single split CV: LGOCV with number = 1 and p = 0.9.
  fit_control <- trainControl(
    method = "LGOCV",
    number = 1,
    p = 0.9,
    allowParallel = TRUE,
    classProbs = TRUE,
    returnData = FALSE,
    summaryFunction = nbll_summary,
    verboseIter = FALSE
  )

  # Factory that binds one method name per wrapper. Defining the wrappers
  # directly inside the loop would make every closure look up the loop
  # variable when it is *called*, i.e. after the loop has finished, so all
  # wrappers would silently fit the last model in `alg_names`.
  make_wrapper <- function(method_name) {
    force(method_name)
    function(Y, X, newX, family, obsWeights, id) {
      # NOTE(review): `id` is forwarded positionally, so it binds to whatever
      # formal of SL.caret() comes sixth -- confirm against the installed
      # SuperLearner version.
      SL.caret(Y, X, newX, family, obsWeights, id,
               method = method_name, trControl = fit_control,
               metric = "nloglik")
    }
  }

  alg_env <- new.env()
  for (name in alg_names) {
    assign(paste("SL.caret", name, sep = "."), make_wrapper(name),
           envir = alg_env)
  }
  alg_env
}
# Example usage ----

# caret model names to wrap; one SL.caret.<name> wrapper is built for each.
alg_names <- c(
  "bayesglm", "bdk", "C5.0", "C5.0Rules", "C5.0Tree", "ctree2",
  "earth", "fda", "gamboost", "gbm", "gcvEarth", "glm", "glmboost",
  "glmnet", "hda", "knn", "lda", "lda2", "LogitBoost", "mda", "multinom",
  "nnet", "pam", "pcaNNet", "pda", "pda2", "plr", "qda", "rpart",
  "rpartCost", "rrlda", "sda", "sddaLDA", "sddaQDA", "slda", "sparseLDA",
  "spls", "xyf", "QdaCov"
)

wrappers_env <- make_caret_wrappers(alg_names)

# The library passed to SuperLearner is simply the names of the wrappers.
SL.library <- ls(wrappers_env)

# SuperLearner searches for objects whose names are in SL.library, so the
# wrappers are made visible by attaching wrappers_env to R's search path.
#
# SuperLearner is attached first so that, if any names clash, the versions
# in wrappers_env are the ones that get used.
library(SuperLearner)
attach(wrappers_env)

# Y and X are not defined in this snippet; they must already exist in the
# workspace before this call is run.
test.NNloglik <- SampleSplitSuperLearner(
  Y = Y,
  X = X,
  family = binomial(),
  SL.library = SL.library,
  method = "method.CC_nloglik",
  verbose = TRUE,
  split = 0.9
)

# Show the ensemble coefficients as a one-column matrix, rounded to 4 digits.
cbind(round(test.NNloglik$coef, 4))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment