Created April 9, 2020 22:34 (gist ledell/3a3a0547529c311b913e2e8590a6d705)
How to get k-fold metrics for all the H2O AutoML models in R
# How to get k-fold metrics for all the H2O AutoML models in R
# Adapted from: http://docs.h2o.ai/h2o/latest-stable/h2o-docs/automl.html
library(h2o)
h2o.init()

# Import a sample binary outcome train/test set into H2O
train <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_train_10k.csv")
test <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_test_5k.csv")

# Identify predictors and response
y <- "response"
x <- setdiff(names(train), y)

# For binary classification, response should be a factor
train[, y] <- as.factor(train[, y])
test[, y] <- as.factor(test[, y])

# Run AutoML for 5 base models (limited to 1 hour max runtime by default)
aml <- h2o.automl(x = x, y = y,
                  training_frame = train,
                  max_models = 5,
                  seed = 1)

# View the AutoML Leaderboard
lb <- aml@leaderboard
print(lb, n = nrow(lb))  # Print all rows instead of default (6 rows)

# Get the model ids
model_ids <- as.data.frame(lb)$model_id

# Get per-fold stats (add more code to capture the actual values)
for (model_id in model_ids) {
  cat("Model ID:", model_id, "\n")
  m <- h2o.getModel(model_id)
  if (grepl("StackedEnsemble", model_id)) {
    # For Stacked Ensembles, the per-fold stats are read from the metalearner model.
    # With h2o v3.30.0.1, this works:
    print(m@model$metalearner_model@model$cross_validation_metrics_summary)
    # With other versions, fetch the metalearner by name instead:
    # print(h2o.getModel(m@model$metalearner$name)@model$cross_validation_metrics_summary)
  } else {
    print(m@model$cross_validation_metrics_summary)
  }
}
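The loop above only prints each summary table; its comment notes that more code is needed to capture the actual values. The following is a minimal sketch of one way to do that, not part of the original gist. It defines a hypothetical helper, cv_summary_to_long() (not an h2o function), that turns one summary table into a long data frame with one row per metric and fold. It assumes, without checking every h2o release, that as.data.frame() on the summary gives metric names as row names and per-fold columns named cv_1_valid, cv_2_valid, and so on; adjust the pattern if your version lays the table out differently. The demo uses a small synthetic table, not real H2O output.

```r
# Hypothetical helper (not part of the h2o package): reshape one
# cross_validation_metrics_summary table into long format.
# ASSUMPTION: metrics are stored as row names and per-fold columns
# are named like "cv_1_valid", "cv_2_valid", ...
cv_summary_to_long <- function(summary_df, model_id) {
  df <- as.data.frame(summary_df)
  # Keep only the per-fold columns; "mean" and "sd" are dropped
  fold_cols <- grep("^cv_[0-9]+_valid$", names(df), value = TRUE)
  rows <- lapply(fold_cols, function(col) {
    data.frame(
      model_id = model_id,
      metric = rownames(df),
      fold = as.integer(sub("^cv_([0-9]+)_valid$", "\\1", col)),
      # Values may be stored as strings in some versions, so coerce to numeric
      value = as.numeric(as.character(df[[col]])),
      stringsAsFactors = FALSE
    )
  })
  out <- do.call(rbind, rows)
  rownames(out) <- NULL
  out
}

# Demo on a synthetic table shaped like the assumed layout
example <- data.frame(
  mean = c("0.75", "0.70"),
  sd = c("0.01", "0.02"),
  cv_1_valid = c("0.74", "0.68"),
  cv_2_valid = c("0.76", "0.72"),
  row.names = c("auc", "accuracy")
)
long <- cv_summary_to_long(example, "GBM_demo")
print(long)
```

To collect every model, you could create all_folds <- list() before the loop, store all_folds[[model_id]] <- cv_summary_to_long(summary_df, model_id) inside it (where summary_df is whichever table the loop currently prints), and finish with folds_long <- do.call(rbind, all_folds). A call such as aggregate(value ~ model_id + metric, data = folds_long, FUN = sd) then shows the fold-to-fold spread of each metric per model.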
I think that using functions like h2o.roc() would be a better way of getting those values. What does h2o.performance() give you?