Created March 9, 2016 01:35
-
-
Save ledell/f3a87bd136ce06e0a5ff to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Extract cross-validated predicted values (in order of original rows)
#
# Reads the per-fold cross-validation prediction frames stored on an H2O
# model, takes one column from each, binds them side by side and sums each
# row. Per the original author's note, every per-fold frame has one row per
# training row and is zero outside that fold's holdout rows, so the row-wise
# sum is a single column of holdout predictions in original row order.
#
# Args:
#   object: an H2OBinomialModel, H2OMulticlassModel or H2ORegressionModel
#     trained with keep_cross_validation_predictions = TRUE.
#   single_col: must be TRUE; FALSE is not implemented yet and errors.
#
# Returns:
#   A 1-column H2OFrame. For binomial models it is built from column 3 of
#   each fold frame (assumed to be p1 in the layout predict, p0, p1);
#   otherwise from the "predict" column.
#
# TO DO: Need to add support for returning a multiclass prediction and
#   binary (full frame: predict, p0, p1)
h2o.cvpreds <- function(object, single_col = TRUE) {
  if (!single_col) {
    stop("single_col = FALSE not yet implemented")
  }

  # inherits() also accepts subclasses, and the final else guarantees an
  # informative error rather than an undefined branch variable.
  if (inherits(object, "H2OBinomialModel")) {
    is_binomial <- TRUE
  } else if (inherits(object, c("H2OMulticlassModel",
                                "H2ORegressionModel"))) {
    # NOTE(review): for multiclass models "predict" holds class labels;
    # summing it row-wise is presumably not meaningful -- confirm.
    is_binomial <- FALSE
  } else {
    stop("`object` must be an H2OBinomialModel, H2OMulticlassModel or ",
         "H2ORegressionModel", call. = FALSE)
  }

  # Iterate over the stored fold frames instead of trusting `nfolds`: the
  # list is empty when keep_cross_validation_predictions = FALSE, and
  # `nfolds` is presumably 0 when folds come from a fold_column.
  cv_frames <- object@model$cross_validation_predictions
  if (length(cv_frames) == 0L) {
    stop("No cross-validation predictions found; train the model with ",
         "keep_cross_validation_predictions = TRUE", call. = FALSE)
  }

  predlist <- lapply(seq_along(cv_frames), function(v) {
    preds <- h2o.getFrame(cv_frames[[v]]$name)
    if (is_binomial) preds[, 3] else preds$predict
  })
  # N x V frame: each row is zero except in the column of its holdout fold.
  cvpred_sparse <- h2o.cbind(predlist)
  apply(cvpred_sparse, 1, sum)
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.