Created April 9, 2020 22:34
How to get k-fold metrics for all the H2O AutoML models in R
# How to get k-fold metrics for all the H2O AutoML models in R
# Adapted from: http://docs.h2o.ai/h2o/latest-stable/h2o-docs/automl.html

library(h2o)
h2o.init()

# Import a sample binary outcome train/test set into H2O
train <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_train_10k.csv")
test <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_test_5k.csv")

# Identify predictors and response
y <- "response"
x <- setdiff(names(train), y)

# For binary classification, response should be a factor
train[, y] <- as.factor(train[, y])
test[, y] <- as.factor(test[, y])

# Run AutoML for 5 base models (limited to 1 hour max runtime by default)
aml <- h2o.automl(x = x, y = y,
                  training_frame = train,
                  max_models = 5,
                  seed = 1)

# View the AutoML Leaderboard
lb <- aml@leaderboard
print(lb, n = nrow(lb))  # Print all rows instead of default (6 rows)

# Get the model ids
model_ids <- as.data.frame(lb)$model_id

# Get per-fold stats (add more code to capture the actual values)
for (model_id in model_ids) {
  # End with a newline so the table starts on its own line
  cat("Model ID:", model_id, "\n")
  m <- h2o.getModel(model_id)
  if (grepl("StackedEnsemble", model_id)) {
    # if you have h2o v3.30.0.1, this works
    print(m@model$metalearner_model@model$cross_validation_metrics_summary)
    # otherwise, you have to do this:
    #print(h2o.getModel(m@model$metalearner$name)@model$cross_validation_metrics_summary)
  } else {
    print(m@model$cross_validation_metrics_summary)
  }
}
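The loop above only prints each summary table. A minimal sketch of capturing the tables instead, assuming the same slot layout the loop uses (the stacked-ensemble branch assumes h2o v3.30.0.1 or later). `collect_cv_summaries` is a hypothetical helper name, not part of h2o:

```r
# Hypothetical helper (not part of h2o): return a named list holding one
# cross-validation summary table per model id.
collect_cv_summaries <- function(model_ids) {
  # Leaderboard ids may arrive as a factor; work with plain strings
  model_ids <- as.character(model_ids)
  summaries <- lapply(model_ids, function(model_id) {
    m <- h2o::h2o.getModel(model_id)
    if (grepl("StackedEnsemble", model_id)) {
      # Assumes h2o >= 3.30.0.1, where the metalearner model is attached
      m <- m@model$metalearner_model
    }
    as.data.frame(m@model$cross_validation_metrics_summary)
  })
  names(summaries) <- model_ids
  summaries
}

# Usage, with a running H2O cluster and `model_ids` from the script above:
# cv_summaries <- collect_cv_summaries(model_ids)
# cv_summaries[[1]]
```

Returning a named list keeps each model's table separate, since different algorithms may not report the same set of metric rows.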
I think that using functions like h2o.roc() would be a better way of getting those values. What does h2o.performance() give you?
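As a rough sketch of what `h2o.performance()` offers here: it returns a metrics object for a model, and with `xval = TRUE` it returns the cross-validation metrics for the model as a whole rather than a per-fold breakdown. The helper name `xval_and_test_auc` is hypothetical, and whether cross-validation metrics are available for stacked ensembles may depend on the h2o version:

```r
# Hypothetical helper (not part of h2o): compare a model's cross-validated
# AUC with its AUC on a held-out H2OFrame.
xval_and_test_auc <- function(model_id, test_frame) {
  m <- h2o::h2o.getModel(model_id)
  # Cross-validation metrics stored on the model (assumes it was trained with CV)
  perf_xval <- h2o::h2o.performance(m, xval = TRUE)
  # Metrics computed fresh on the supplied test frame
  perf_test <- h2o::h2o.performance(m, newdata = test_frame)
  c(xval_auc = h2o::h2o.auc(perf_xval),
    test_auc = h2o::h2o.auc(perf_test))
}

# Usage, with a running H2O cluster and objects from the script above:
# xval_and_test_auc(model_ids[1], test)
```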
Some helper code if anybody else wants to do computation on these numbers. One oddity: when the h2o data frame is brought back into R, all the columns have character type rather than numeric. Currently I'm working around this with readr::parse_number().
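A minimal base-R sketch of one such workaround, using `as.numeric()` in place of `readr::parse_number()`. The `cv_summary` data frame is made up for illustration and only mimics a table whose columns came back as character. `to_numeric_columns` is a hypothetical helper:

```r
# Hypothetical helper: convert every character or factor column to numeric,
# leaving row names and other columns untouched.
to_numeric_columns <- function(df) {
  df[] <- lapply(df, function(col) {
    if (is.character(col) || is.factor(col)) {
      as.numeric(as.character(col))
    } else {
      col
    }
  })
  df
}

# Illustrative placeholder values, not real model output
cv_summary <- data.frame(
  mean = c("0.78", "0.55"),
  sd = c("0.01", "0.02"),
  cv_1_valid = c("0.79", "0.54"),
  row.names = c("auc", "logloss"),
  stringsAsFactors = FALSE
)

cv_numeric <- to_numeric_columns(cv_summary)
str(cv_numeric)
```

Note that `as.numeric()` yields `NA` (with a warning) for strings that are not plain numbers, whereas `readr::parse_number()` is more forgiving of surrounding text.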