Last active
June 25, 2016 16:19
-
-
Save szilard/b87233bbf41a4b366c26eede7bb1a0f3 to your computer and use it in GitHub Desktop.
ML with H2O.ai
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library("h2o")

## Cluster setup ----
# Connect to H2O (the printout below shows a cluster started from R) using
# every core (nthreads = -1) and a 20 GB maximum JVM heap.
h2o.init(nthreads = -1, max_mem_size = "20g")
# Connection summary printed by the original run:
# R is connected to the H2O cluster:
# H2O cluster uptime: 1 seconds 704 milliseconds
# H2O cluster version: 3.8.2.8
# H2O cluster name: H2O_started_from_R_szilard_lcr105
# H2O cluster total nodes: 1
# H2O cluster total memory: 17.78 GB
# H2O cluster total cores: 16
# H2O cluster allowed cores: 16
# H2O cluster healthy: TRUE
# H2O Connection ip: localhost
# H2O Connection port: 54321
# H2O Connection proxy: NA
# R Version: R version 3.3.0 (2016-05-03)
## Data ----
# Airline-delay training file from the benchm-ml project (train-1m.csv).
dx <- h2o.importFile("https://s3.amazonaws.com/benchm-ml--main/train-1m.csv")
dx
# First rows as printed in the original run:
# Month DayofMonth DayOfWeek DepTime UniqueCarrier Origin Dest Distance dep_delayed_15min
# 1 c-4 c-26 c-2 1828 XE LEX IAH 828 N
# 2 c-12 c-11 c-1 1212 UA DEN MCI 533 N
# 3 c-10 c-1 c-6 935 OH HSV CVG 325 N
# 4 c-11 c-26 c-6 930 OH JFK PNS 1028 N
# 5 c-12 c-6 c-2 1350 MQ DFW LBB 282 Y
# 6 c-11 c-24 c-5 1525 FL ATL TPA 406 N

# Random split into train (~90%), validation (~5%) and test (the remaining ~5%).
dx_split <- h2o.splitFrame(dx, ratios = c(0.9, 0.05), seed = 123)
dx_train <- dx_split[[1]]
dx_valid <- dx_split[[2]]
dx_test  <- dx_split[[3]]

# Every model below uses columns 1..p as predictors and column p + 1
# (the last one, dep_delayed_15min) as the response.
p <- ncol(dx) - 1
## Random forest ----
# 500 trees, max depth 20, nbins = 100; no validation frame is used.
system.time({
  md <- h2o.randomForest(
    x = 1:p, y = p + 1,
    training_frame = dx_train,
    ntrees = 500,
    max_depth = 20,
    nbins = 100,
    seed = 123
  )
})
# AUC as reported on the training data, then on the held-out test frame.
h2o.auc(md, train = TRUE)
h2o.auc(h2o.performance(md, dx_test))
# Original run:
# user system elapsed
# 9.836 2.474 774.053
# [1] 0.7804589
# [1] 0.7840077
## GBM ----
# ntrees is only a ceiling: H2O's early stopping on validation AUC
# (scored every 10 trees; stopping_rounds = 3, stopping_tolerance = 1e-3)
# decides how many trees are actually built.
system.time({
  md <- h2o.gbm(
    x = 1:p, y = p + 1,
    training_frame = dx_train,
    validation_frame = dx_valid,
    ntrees = 10000,
    max_depth = 15, learn_rate = 0.1, nbins = 100,
    score_tree_interval = 10,
    stopping_metric = "AUC",
    stopping_rounds = 3,
    stopping_tolerance = 1e-3,
    seed = 123
  )
})
# AUC on train and validation frames, then on the held-out test frame.
h2o.auc(md, train = TRUE, valid = TRUE)
h2o.auc(h2o.performance(md, dx_test))
# Original run:
# user system elapsed
# 31.568 4.835 2540.175
# 0.9992960 0.8344899
# [1] 0.8294485
## GBM - CV ----
# Same GBM settings as above, but 5-fold cross-validation on dx_train is
# used instead of a separate validation frame (none is passed).
system.time({
  md <- h2o.gbm(
    x = 1:p, y = p + 1,
    training_frame = dx_train,
    nfolds = 5,
    ntrees = 10000,
    max_depth = 15, learn_rate = 0.1, nbins = 100,
    score_tree_interval = 10,
    stopping_metric = "AUC",
    stopping_rounds = 3,
    stopping_tolerance = 1e-3,
    seed = 123
  )
})
# Training and cross-validated AUC, then AUC on the held-out test frame.
h2o.auc(md, train = TRUE, xval = TRUE)
h2o.auc(h2o.performance(md, dx_test))
# Original run:
# user system elapsed
# 173.216 28.586 13931.253
# 0.9991971 0.8189935
# [1] 0.8290205
## GBM - tuning ----
# Random-discrete grid search over GBM hyperparameters.
#
# ntrees is fixed at a ceiling of 10000; early stopping (passed to
# h2o.grid() below) determines the real tree count of each model.
# NOTE(review): several vectors repeat a value (1 in learn_rate_annealing,
# sample_rate and col_sample_rate), presumably to make that value more
# likely in the random draw -- confirm this weighting is intended.
hyper_params <- list(
  ntrees = 10000, ## early stopping
  max_depth = 10:25,
  min_rows = c(1, 3, 10, 30, 100),
  learn_rate = c(0.03, 0.1),
  learn_rate_annealing = c(0.99, 0.995, 1, 1),
  sample_rate = c(0.4, 0.7, 1, 1),
  col_sample_rate = c(0.7, 1, 1),
  nbins = c(30, 100, 300),
  nbins_cats = c(64, 256, 1024),
  min_split_improvement = c(0, 1e-8, 1e-6, 1e-4),
  histogram_type = c("UniformAdaptive", "QuantilesGlobal", "RoundRobin")
)

# Budget for the random search: at most 30 models or 24 hours of runtime.
search_criteria <- list(
  strategy = "RandomDiscrete",
  max_runtime_secs = 24 * 3600,
  max_models = 30
)

system.time({
  mds <- h2o.grid(
    algorithm = "gbm", grid_id = "grd",
    x = 1:p, y = p + 1,
    training_frame = dx_train,
    validation_frame = dx_valid,
    hyper_params = hyper_params,
    search_criteria = search_criteria,
    score_tree_interval = 10, stopping_metric = "AUC",
    stopping_tolerance = 1e-3, stopping_rounds = 3,
    seed = 123
  )
})

# Re-fetch the grid sorted by AUC (best first), then score the top model
# on the held-out test frame.
mds_sort <- h2o.getGrid(grid_id = "grd", sort_by = "auc", decreasing = TRUE)
mds_sort
md_best <- h2o.getModel(mds_sort@model_ids[[1]])
h2o.auc(h2o.performance(md_best, dx_test))
# Original run:
# user system elapsed
# 531.794 93.188 43022.395
# [1] 0.8301675
## Logistic regression ----
# lambda = 0 turns the elastic-net penalty off, so alpha (the L1/L2 mix)
# has no effect: this is plain unregularized logistic regression.
system.time({
  md <- h2o.glm(
    training_frame = dx_train,
    x = 1:p, y = p + 1,
    family = "binomial",
    lambda = 0,
    alpha = 1
  )
})
# Training AUC, then AUC on the held-out test frame.
h2o.auc(md, train = TRUE)
h2o.auc(h2o.performance(md, dx_test))
# Original run:
# user system elapsed
# 0.346 0.000 3.475
# [1] 0.7139426
# [1] 0.7144322
## Neural Nets / Deep Learning ----
# Two hidden layers of 200 rectifier units. epochs is only a ceiling:
# training stops early on validation AUC (stopping_rounds = 5,
# stopping_tolerance = 0). The commented lines are an alternative,
# deeper and regularized, configuration kept for reference.
system.time({
  md <- h2o.deeplearning(
    x = 1:p, y = p + 1,
    training_frame = dx_train,
    validation_frame = dx_valid,
    activation = "Rectifier", hidden = c(200, 200), epochs = 10000,
    ## activation = "RectifierWithDropout", hidden = c(200,200,200,200),
    ## l1 = 1e-5, l2 = 1e-5, hidden_dropout_ratios=c(0.2,0.1,0.1,0),
    stopping_rounds = 5, stopping_metric = "AUC", stopping_tolerance = 0,
    seed = 123
  )
})
# AUC on train and validation frames, then on the held-out test frame.
h2o.auc(md, train = TRUE, valid = TRUE)
# Use the public h2o.auc() accessor, as every other section does, instead of
# reaching into the performance object's internal @metrics slot.
h2o.auc(h2o.performance(md, dx_test))
# Original run:
# user system elapsed
# 9.681 0.814 850.018
# 0.8077432 0.7768504
# [1] 0.7739333
## Ensembles ----
# Stacking support (h2o.ensemble() and the h2o.*.wrapper learner functions).
library("h2oEnsemble")
# Base learners for h2o.ensemble(). Each function fixes a set of
# hyperparameter defaults and forwards everything else (`...`), such as the
# data arguments h2o.ensemble() supplies, to the matching h2oEnsemble wrapper.
# NOTE(review): unlike the standalone RF/GBM fits above, none of these set
# nbins = 100 -- confirm that difference is intended.

# Random forest: 500 trees, max depth 20.
bl_rf1 <- function(..., ntrees = 500, max_depth = 20) {
  h2o.randomForest.wrapper(
    ...,
    ntrees = ntrees,
    max_depth = max_depth
  )
}

# GBM, learn rate 0.1 / depth 15, with AUC-based early stopping.
bl_gbm1 <- function(..., ntrees = 10000, learn_rate = 0.1, max_depth = 15,
                    score_tree_interval = 10, stopping_metric = "AUC",
                    stopping_tolerance = 1e-3, stopping_rounds = 3) {
  h2o.gbm.wrapper(
    ...,
    ntrees = ntrees,
    learn_rate = learn_rate,
    max_depth = max_depth,
    score_tree_interval = score_tree_interval,
    stopping_metric = stopping_metric,
    stopping_tolerance = stopping_tolerance,
    stopping_rounds = stopping_rounds
  )
}

# GBM, slower learn rate 0.03 / shallower depth 10, same early stopping.
bl_gbm2 <- function(..., ntrees = 10000, learn_rate = 0.03, max_depth = 10,
                    score_tree_interval = 10, stopping_metric = "AUC",
                    stopping_tolerance = 1e-3, stopping_rounds = 3) {
  h2o.gbm.wrapper(
    ...,
    ntrees = ntrees,
    learn_rate = learn_rate,
    max_depth = max_depth,
    score_tree_interval = score_tree_interval,
    stopping_metric = stopping_metric,
    stopping_tolerance = stopping_tolerance,
    stopping_rounds = stopping_rounds
  )
}

# Deep learning: two rectifier layers of 200 units, AUC-based early stopping.
bl_dl1 <- function(..., activation = "Rectifier", hidden = c(200,200),
                   epochs = 10000, stopping_metric = "AUC",
                   stopping_tolerance = 0, stopping_rounds = 5) {
  h2o.deeplearning.wrapper(
    ...,
    activation = activation,
    hidden = hidden,
    epochs = epochs,
    stopping_metric = stopping_metric,
    stopping_tolerance = stopping_tolerance,
    stopping_rounds = stopping_rounds
  )
}
# Stack the four base learners defined above. The commented arguments are
# alternative metalearner / CV settings kept for reference.
system.time({
  md <- h2o.ensemble(
    x = 1:p, y = p + 1,
    training_frame = dx_train,
    learner = c("bl_rf1", "bl_gbm1", "bl_gbm2", "bl_dl1"),
    ## metalearner = "h2o.glm.wrapper",
    ## cvControl = list(V = 5, shuffle = TRUE),
    seed = 123
  )
})
# Original run:
# user system elapsed
# 527.176 88.102 43250.288

# Test-set AUC of each base learner and of the stacked ensemble.
h2o.ensemble_performance(md, newdata = dx_test)
# Base learner performance, sorted by specified metric:
# learner AUC
# 4 bl_dl1 0.7713471
# 1 bl_rf1 0.7817813
# 3 bl_gbm2 0.8186214
# 2 bl_gbm1 0.8281934
# H2O Ensemble Performance on <newdata>:
# Ensemble performance (AUC): 0.82988283994141

# Fitted metalearner (a GLM over the base learners' predictions, per the
# printout below).
md$metafit
# Coefficients: glm coefficients
# names coefficients standardized_coefficients
# 1 Intercept -3.018938 -1.726670
# 2 bl_rf1 0.395931 0.059080
# 3 bl_gbm1 3.076788 0.622808
# 4 bl_gbm2 2.903059 0.506143
# 5 bl_dl1 0.660071 0.129077
## Scoring from R ----
# Save the model to the working directory, reload it and score a few rows.
#
# h2o.saveModel() returns the path it wrote, so the model is reloaded from
# that path. The original code used a hard-coded model id
# ("./DeepLearning_model_R_1466713759449_6139"), which embeds a
# session-specific number and will not exist on a rerun.
# NOTE(review): at this point in the script `md` holds the h2o.ensemble()
# fit, but the output below came from a DeepLearning model. h2o.saveModel()
# is meant for a single H2O model, so confirm which model should be saved here.
md_path <- h2o.saveModel(md, path = "./")
d_new <- as.data.frame(dx_test[1:3, 1:p])
md <- h2o.loadModel(md_path)
h2o.predict(md, as.h2o(d_new))
# Original run:
# predict N Y
# 1 N 0.7528926 0.24710742
# 2 N 0.8047686 0.19523137
# 3 N 0.9135341 0.08646589
## Scoring from Java ----
# Export the model as a POJO into the working directory, together with the
# h2o-genmodel.jar runtime (getjar = TRUE).
h2o.download_pojo(md, getjar = TRUE, path = "./")
# Files written in the original run:
# DeepLearning_model_R_1466713759449_6139.java
# h2o-genmodel.jar
#
# Example main.java using the exported POJO. modelClassName is not defined
# in this snippet; presumably it is the generated class name (here
# "DeepLearning_model_R_1466713759449_6139") -- confirm.
#
# hex.genmodel.GenModel rawModel;
# rawModel = (hex.genmodel.GenModel) Class.forName(modelClassName).newInstance();
# EasyPredictModelWrapper model = new EasyPredictModelWrapper(rawModel);
#
# RowData row = new RowData();
# row.put("Month", "c-7");
# row.put("DayofMonth", "c-25");
# row.put("DayOfWeek", "c-3");
# row.put("DepTime", "615");
# row.put("UniqueCarrier", "YV");
# row.put("Origin", "MRY");
# row.put("Dest", "PHX");
# row.put("Distance", "598");
#
# BinomialModelPrediction p = model.predictBinomial(row);
#
# System.out.print(p.classProbabilities[1]);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment