Skip to content

Instantly share code, notes, and snippets.

@ledell
Last active May 8, 2020 03:51
Show Gist options
  • Save ledell/ba1fcc3ca1c8cf7db5147c7a54849578 to your computer and use it in GitHub Desktop.
Save ledell/ba1fcc3ca1c8cf7db5147c7a54849578 to your computer and use it in GitHub Desktop.
Stacked Ensembles with clustered observations (pooled repeated measures data) in H2O
# Stacking with clustered ("pooled repeated measures") observations.
# Stacking relies on cross-validation, so all observations belonging to the
# same cluster must land in the same CV fold. We borrow SuperLearner::CVFolds()
# to build cluster-respecting folds, then train H2O Stacked Ensembles / AutoML.
library(SuperLearner)
library(h2o)
h2o.init()

# Import a sample binary-outcome train/test set into the H2O cluster
train <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_train_10k.csv")
test <- h2o.importFile("https://s3.amazonaws.com/erin-data/higgs/higgs_test_5k.csv")

# Response column; every other column is a predictor
y <- "response"
x <- setdiff(names(train), y)

# H2O treats a factor response as a binary classification target
train[, y] <- as.factor(train[, y])
test[, y] <- as.factor(test[, y])

# Number of CV folds
nfolds <- 5

# The dataset ships without a cluster/ID column, so fabricate one:
# assign the rows round-robin to 30 dummy clusters
n_clusters <- 30
ids <- factor(rep(seq_len(n_clusters), length.out = nrow(train)))

# Build the fold assignment with SuperLearner::CVFolds(), which keeps all
# rows sharing an id together in one fold
folds <- SuperLearner::CVFolds(N = nrow(train),
                               id = ids,
                               Y = train[, y],
                               cvControl = list(V = nfolds,
                                                stratifyCV = FALSE,
                                                shuffle = TRUE))
# Convert a list of fold index vectors (the format returned by
# SuperLearner::CVFolds) into a single fold-assignment vector suitable for
# use as an H2O fold_column.
#
# @param folds A list of length V; element i contains the row indices that
#   belong to fold i. Folds are assumed disjoint.
# @return An integer vector of length sum(lengths(folds)); position j holds
#   the fold number (1..V) of observation j. Any position not covered by a
#   fold remains NA.
convert_foldlist_to_vec <- function(folds) {
  N <- length(unlist(folds))
  # NA_integer_ keeps the result type-stable (integer even for an empty list)
  fold_column <- rep(NA_integer_, N)
  # seq_along() rather than 1:length(folds): safe when folds is empty
  for (i in seq_along(folds)) {
    fold_column[folds[[i]]] <- i
  }
  fold_column
}
# Reshape the fold list into per-row fold ids and attach them to the H2O
# training frame so the learners can cross-validate on the custom folds
fold_column <- convert_foldlist_to_vec(folds)
train$fold_id <- as.h2o(fold_column)

# Base learner 1: cross-validated GBM. The holdout (cross-validation)
# predictions are kept because the ensemble metalearner trains on them.
my_gbm <- h2o.gbm(x = x, y = y,
                  training_frame = train,
                  fold_column = "fold_id",
                  keep_cross_validation_predictions = TRUE,
                  seed = 1)

# Base learner 2: cross-validated Random Forest, configured the same way
my_rf <- h2o.randomForest(x = x, y = y,
                          training_frame = train,
                          fold_column = "fold_id",
                          keep_cross_validation_predictions = TRUE,
                          seed = 1)
# Combine the cross-validated GBM and RF into a stacked ensemble
# http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/stacked-ensembles.html
ensemble <- h2o.stackedEnsemble(x = x, y = y,
                                training_frame = train,
                                base_models = list(my_gbm, my_rf))

# Score the ensemble and both base learners on the held-out test set
perf <- h2o.performance(ensemble, newdata = test)
perf_gbm_test <- h2o.performance(my_gbm, newdata = test)
perf_rf_test <- h2o.performance(my_rf, newdata = test)

# Compare the ensemble's test AUC against the best single base learner
ensemble_auc_test <- h2o.auc(perf)
baselearner_best_auc_test <- max(h2o.auc(perf_gbm_test), h2o.auc(perf_rf_test))
print(sprintf("Best Base-learner Test AUC: %s", baselearner_best_auc_test))
print(sprintf("Ensemble Test AUC: %s", ensemble_auc_test))
# [1] "Best Base-learner Test AUC: 0.78168458478629"
# [1] "Ensemble Test AUC: 0.785058580788397"

# Generate test-set predictions from the ensemble (if necessary)
pred <- h2o.predict(ensemble, newdata = test)
# Alternatively, let AutoML build and stack the models for superior results:
# http://docs.h2o.ai/h2o/latest-stable/h2o-docs/automl.html
# The same fold_id column keeps clusters together inside AutoML's CV too.
# Up to 20 base models are trained (default cap: 1 hour of runtime).
aml <- h2o.automl(x = x, y = y,
                  training_frame = train,
                  fold_column = "fold_id",
                  max_models = 20,
                  seed = 1)

# Evaluate the AutoML leader model on the test set
automl_auc_test <- h2o.auc(h2o.performance(aml@leader, newdata = test))
print(sprintf("AutoML Test AUC: %s", automl_auc_test))
# [1] "AutoML Test AUC: 0.793495420924985"

# Inspect the full leaderboard (ranked by cross-validated AUC, not test AUC)
lb <- aml@leaderboard
print(lb, n = nrow(lb))  # print every row, not just the default 6
# model_id auc logloss mean_per_class_error rmse mse
# 1 StackedEnsemble_AllModels_AutoML_20191029_203448 0.7884892 0.5527148 0.3151245 0.4329294 0.1874278
# 2 StackedEnsemble_BestOfFamily_AutoML_20191029_203448 0.7882598 0.5525626 0.3195189 0.4328554 0.1873638
# 3 XGBoost_3_AutoML_20191029_203448 0.7847433 0.5571271 0.3097956 0.4347480 0.1890058
# 4 XGBoost_grid_1_AutoML_20191029_203448_model_3 0.7831363 0.5591175 0.3078232 0.4355900 0.1897386
# 5 XGBoost_2_AutoML_20191029_203448 0.7827626 0.5568081 0.3233692 0.4350913 0.1893044
# 6 GBM_5_AutoML_20191029_203448 0.7825846 0.5569374 0.3236149 0.4351555 0.1893603
# 7 XGBoost_grid_1_AutoML_20191029_203448_model_4 0.7824331 0.5587759 0.3325449 0.4357166 0.1898490
# 8 XGBoost_1_AutoML_20191029_203448 0.7819913 0.5578475 0.3170729 0.4354584 0.1896240
# 9 XGBoost_grid_1_AutoML_20191029_203448_model_1 0.7810482 0.5618187 0.3535456 0.4367870 0.1907829
# 10 GBM_1_AutoML_20191029_203448 0.7787226 0.5617122 0.3132779 0.4371347 0.1910867
# 11 GBM_2_AutoML_20191029_203448 0.7781562 0.5619227 0.3263315 0.4373559 0.1912802
# 12 GBM_3_AutoML_20191029_203448 0.7756294 0.5647284 0.3281380 0.4386296 0.1923959
# 13 XGBoost_grid_1_AutoML_20191029_203448_model_2 0.7735268 0.5772599 0.3299787 0.4434771 0.1966719
# 14 GBM_4_AutoML_20191029_203448 0.7717646 0.5694947 0.3318445 0.4409754 0.1944593
# 15 GBM_grid_1_AutoML_20191029_203448_model_1 0.7694388 0.5730569 0.3293395 0.4426899 0.1959744
# 16 DRF_1_AutoML_20191029_203448 0.7673043 0.5785543 0.3322677 0.4444543 0.1975396
# 17 XRT_1_AutoML_20191029_203448 0.7639914 0.5815739 0.3535071 0.4459000 0.1988268
# 18 GBM_grid_1_AutoML_20191029_203448_model_2 0.7528762 0.9222735 0.3492733 0.4948733 0.2448996
# 19 DeepLearning_grid_1_AutoML_20191029_203448_model_2 0.7340179 0.6042518 0.3911436 0.4564639 0.2083593
# 20 GLM_grid_1_AutoML_20191029_203448_model_1 0.6813971 0.6390391 0.3973750 0.4729697 0.2237004
# 21 DeepLearning_grid_1_AutoML_20191029_203448_model_1 0.6743707 0.6701134 0.4271833 0.4815416 0.2318823
# 22 DeepLearning_1_AutoML_20191029_203448 0.6711025 0.6531100 0.4154329 0.4785532 0.2290131
#
# [22 rows x 6 columns]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment