@ledell
Created April 9, 2020 22:34
How to get k-fold metrics for all the H2O AutoML models in R
# How to get k-fold metrics for all the H2O AutoML models in R
# Adapted from: http://docs.h2o.ai/h2o/latest-stable/h2o-docs/automl.html
library(h2o)
h2o.init()
# Import a sample binary outcome train/test set into H2O
train <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_train_10k.csv")
test <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_test_5k.csv")
# Identify predictors and response
y <- "response"
x <- setdiff(names(train), y)
# For binary classification, response should be a factor
train[,y] <- as.factor(train[,y])
test[,y] <- as.factor(test[,y])
# Run AutoML for 5 base models (limited to 1 hour max runtime by default)
aml <- h2o.automl(x = x, y = y,
                  training_frame = train,
                  max_models = 5,
                  seed = 1)
# View the AutoML Leaderboard
lb <- aml@leaderboard
print(lb, n = nrow(lb)) # Print all rows instead of default (6 rows)
# Get the model ids
model_ids <- as.data.frame(lb)$model_id
# Get per-fold stats (add more code to capture the actual values)
for (model_id in model_ids) {
  cat("Model ID:", model_id, "\n")  # newline so the table starts on its own line
  m <- h2o.getModel(model_id)
  if (grepl("StackedEnsemble", model_id)) {
    # Stacked ensembles keep their CV summary on the metalearner model.
    # If you have h2o v3.30.0.1, this works:
    print(m@model$metalearner_model@model$cross_validation_metrics_summary)
    # Otherwise, you have to do this:
    # print(h2o.getModel(m@model$metalearner$name)@model$cross_validation_metrics_summary)
  } else {
    print(m@model$cross_validation_metrics_summary)
  }
}
@alexpghayes

Some helper code in case anybody else wants to do computations on these numbers. One oddity: when the h2o data frame is brought back into R, all the columns come through as character rather than numeric. For now I'm hacking around this with readr::parse_number().

library(tidyverse)

get_h2o_cv_summary <- function(model_id) {

  m <- h2o.getModel(model_id)

  if (grepl("StackedEnsemble", model_id)) {
    cv <- m@model$metalearner_model@model$cross_validation_metrics_summary
  } else {
    cv <- m@model$cross_validation_metrics_summary
  }

  as.data.frame(cv) %>%
    mutate_all(readr::parse_number) %>%  # h2o frame imports chr columns back into R
    mutate(model_id = model_id)
}

automl_metrics <- function(aml) {
  lb <- aml@leaderboard
  model_ids <- as.data.frame(lb)$model_id
  map_dfr(model_ids, get_h2o_cv_summary)
}

automl_metrics(aml)
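
One thing to watch with the helper above: the metric names (auc, logloss, and so on) appear to live in the row names of the summary table, so they may get dropped once rows from several models are stacked. Below is a minimal base-R sketch of keeping them as a column. It uses a small hand-made data frame as a stand-in for the real summary table, so the values and the `fake_cv` and `cv_to_numeric` names are made up for illustration. It also swaps `as.numeric()` in for `readr::parse_number()`, which assumes the values are plain numeric strings.

```r
# Hypothetical stand-in for a cross_validation_metrics_summary table:
# metric names in the row names, every value stored as character.
fake_cv <- data.frame(
  mean = c("0.78", "0.56"),
  sd = c("0.01", "0.02"),
  cv_1_valid = c("0.77", "0.58"),
  cv_2_valid = c("0.79", "0.54"),
  row.names = c("auc", "logloss"),
  stringsAsFactors = FALSE
)

# Copy the row names into a `metric` column and convert the value
# columns from character to numeric.
cv_to_numeric <- function(cv, model_id) {
  values <- as.data.frame(lapply(cv, as.numeric))
  data.frame(
    model_id = model_id,
    metric = rownames(cv),
    values,
    stringsAsFactors = FALSE
  )
}

res <- cv_to_numeric(fake_cv, "GBM_1")
str(res)
```

With the metric name in its own column, the per-model results can be row-bound and then filtered or pivoted by metric.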

@topepo

topepo commented Apr 24, 2020

I think that using functions like h2o.roc() would be a better way of getting those values. What does h2o.performance() give you?
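
For reference, a sketch of what that could look like, assuming the `aml` object from the gist. `h2o.performance()` with `xval = TRUE` returns a model's cross-validation metrics object, and `h2o.auc()` extracts the AUC from it. Note that this is likely a single metric over the combined holdout predictions rather than one value per fold, so it complements the summary table instead of replacing it. The `xval_auc` helper name is made up here, and stacked ensembles are skipped to keep the sketch simple.

```r
# Sketch only: needs the h2o package, a running cluster, and model ids
# from an AutoML run. `xval_auc` is a hypothetical helper name.
xval_auc <- function(model_ids) {
  # Stacked ensembles are left out to keep the sketch simple
  base_ids <- model_ids[!grepl("StackedEnsemble", model_ids)]
  vapply(base_ids, function(id) {
    m <- h2o::h2o.getModel(id)
    perf <- h2o::h2o.performance(m, xval = TRUE)  # cross-validation metrics
    h2o::h2o.auc(perf)
  }, numeric(1))
}

# Usage, once AutoML has finished:
# xval_auc(as.data.frame(aml@leaderboard)$model_id)
```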
