Skip to content

Instantly share code, notes, and snippets.

@ledell
Created March 9, 2016 01:35
Show Gist options
  • Save ledell/f3a87bd136ce06e0a5ff to your computer and use it in GitHub Desktop.
Save ledell/f3a87bd136ce06e0a5ff to your computer and use it in GitHub Desktop.
# Extract cross-validated predicted values (in order of original rows).
#
# Args:
#   object:     A cross-validated H2O model (H2OBinomialModel,
#               H2OMultinomialModel or H2ORegressionModel) that was trained
#               with keep_cross_validation_predictions = TRUE.
#   single_col: If TRUE (default), collapse the per-fold holdout predictions
#               into a single column. FALSE is not yet implemented.
#
# Returns:
#   The row-wise sum of the per-fold holdout prediction columns. For binomial
#   models the third column of each fold frame is used (p1 in the
#   predict / p0 / p1 layout); for all other models the `predict` column.
#
# Raises an error if `object` is not a supported H2OModel, if no
# cross-validation predictions are stored on it, or if single_col is FALSE.
#
# TO DO: Need to add support for returning a multiclass prediction and binary
#        (full frame: predict, p0, p1)
h2o.cvpreds <- function(object, single_col = TRUE) {
  if (!inherits(object, "H2OModel")) {
    stop("`object` must be an H2OModel", call. = FALSE)
  }
  # "H2OMulticlassModel" is kept only for backward compatibility with the
  # name this function used to test for; the h2o class is H2OMultinomialModel.
  supported <- c(
    "H2OBinomialModel", "H2OMultinomialModel",
    "H2OMulticlassModel", "H2ORegressionModel"
  )
  if (!inherits(object, supported)) {
    stop(
      "`object` must be a binomial, multinomial or regression H2OModel",
      call. = FALSE
    )
  }
  if (!single_col) {
    stop("single_col = FALSE not yet implemented")
  }

  # One holdout-prediction frame key per fold. Iterating over this list
  # (instead of 1:nfolds) also covers models whose folds were not defined
  # through `nfolds`, and avoids the 1:0 trap.
  cv_keys <- object@model$cross_validation_predictions
  if (length(cv_keys) == 0L) {
    stop(
      "No cross-validation predictions found; train the model with ",
      "cross-validation and keep_cross_validation_predictions = TRUE",
      call. = FALSE
    )
  }

  use_p1 <- inherits(object, "H2OBinomialModel")
  predlist <- lapply(seq_along(cv_keys), function(v) {
    fold_frame <- h2o.getFrame(cv_keys[[v]]$name)
    if (use_p1) fold_frame[, 3] else fold_frame$predict
  })

  # N x V frame: each row is all zeros except in the column of the fold that
  # held the row out, so the row sum recovers that row's holdout prediction.
  cvpred_sparse <- h2o.cbind(predlist)
  apply(cvpred_sparse, 1, sum)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment