# This file contains bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in
# an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
"""Minimal H2O AutoML example: binary classification on a HIGGS sample.

Imports a sample train/test split into H2O, runs AutoML for a fixed
wall-clock budget, prints the resulting leaderboard and leader model, and
scores the test set.
"""
import h2o
from h2o.automl import H2OAutoML

# Wall-clock time budget for the AutoML run, in seconds.
MAX_RUNTIME_SECS = 30

h2o.init()

# Import a sample binary-outcome train/test set into H2O.
train = h2o.import_file("https://s3.amazonaws.com/erin-data/higgs/higgs_train_10k.csv")
test = h2o.import_file("https://s3.amazonaws.com/erin-data/higgs/higgs_test_5k.csv")

# Identify the response column and the predictors (every other column).
# Filtering into a new list avoids mutating the list returned by `columns`.
y = "response"
x = [col for col in train.columns if col != y]

# For binary classification, convert the response to a categorical (factor).
train[y] = train[y].asfactor()
test[y] = test[y].asfactor()

# Run AutoML for MAX_RUNTIME_SECS seconds; models are ranked on `test`.
aml = H2OAutoML(max_runtime_secs=MAX_RUNTIME_SECS)
aml.train(x=x, y=y, training_frame=train, leaderboard_frame=test)

# View the AutoML leaderboard. An explicit print is needed when this runs as
# a script; a bare `lb` expression only displays in an interactive session.
lb = aml.leaderboard
print(lb.head(rows=lb.nrows))  # show every model, not only the default head
# Example output (the models and scores vary from run to run):
# model_id                                                  auc    logloss
# --------------------------------------------------  --------  ---------
# StackedEnsemble_model_1494643945817_1709            0.780384   0.561501
# GBM_grid__95ebce3d26cd9d3997a3149454984550_model_0  0.764791   0.664823
# GBM_grid__95ebce3d26cd9d3997a3149454984550_model_2  0.758109   0.593887
# DRF_model_1494643945817_3                           0.736786   0.614430
# XRT_model_1494643945817_461                         0.735946   0.602142
# GBM_grid__95ebce3d26cd9d3997a3149454984550_model_3  0.729492   0.667036
# GBM_grid__95ebce3d26cd9d3997a3149454984550_model_1  0.727456   0.675624
# GLM_grid__95ebce3d26cd9d3997a3149454984550_model_1  0.685216   0.635137
# GLM_grid__95ebce3d26cd9d3997a3149454984550_model_0  0.685216   0.635137

# The leader model (top of the leaderboard) is stored here.
leader = aml.leader
print(leader.model_id)

# Predict on the test set. `aml.leader.predict(test)` is an equivalent
# alternative; calling both would just run the same prediction twice.
preds = aml.predict(test)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.