Last active
June 29, 2023 05:06
-
-
Save eukaryo/d7d80f2b7abea39a6d157645dee873f1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# pip install optuna==3.2.0 matplotlib lightgbm botorch
import lightgbm as lgb
import numpy as np
import optuna
import sklearn.datasets
import sklearn.ensemble
import sklearn.metrics  # explicit: scorer() uses sklearn.metrics.accuracy_score
from optuna.terminator.callback import TerminatorCallback
from optuna.terminator.erroreval import report_cross_validation_scores
from optuna.terminator.terminator import Terminator
from sklearn.model_selection import KFold, cross_val_score, train_test_split
# Load the breast-cancer binary classification dataset and hold out 20% for
# the final test evaluation (fixed seed for reproducibility).
data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
train_x, test_x, train_y, test_y = train_test_split(
    data,
    target,
    test_size=0.2,
    random_state=111,
)

# LightGBM parameters shared by every trial; tunable ones are merged in later.
fixed_param = dict(
    boosting_type="gbdt",
    verbosity=-1,
    seed=333,
)
def scorer(estimator, X, y):
    """Accuracy scorer for a regressor used as a binary classifier.

    The estimator's continuous predictions are rounded to the nearest
    integer label before being compared against ``y``.
    """
    rounded = np.rint(estimator.predict(X))
    return sklearn.metrics.accuracy_score(y, rounded)
def objective(trial):
    """Optuna objective: mean 5-fold CV accuracy of a tuned LightGBM model.

    Per-fold scores are also reported to the Optuna Terminator via
    ``report_cross_validation_scores`` so it can estimate remaining
    improvement.
    """
    search_space = {
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
        "num_leaves": trial.suggest_int("num_leaves", 2, 256),
        "feature_fraction": trial.suggest_float("feature_fraction", 0.4, 1.0),
        "bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),
        "bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
        "min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
    }
    model = lgb.LGBMRegressor(**(fixed_param | search_space))
    cv = KFold(n_splits=5, shuffle=True, random_state=1)
    scores = cross_val_score(model, train_x, train_y, scoring=scorer, cv=cv)
    report_cross_validation_scores(trial, scores)
    return scores.mean()
if __name__ == "__main__":
    # Reproducible TPE sampling; the study has no fixed trial budget — the
    # TerminatorCallback stops it once the Terminator judges further trials
    # unlikely to help (after at least min_n_trials=10 trials).
    sampler = optuna.samplers.TPESampler(seed=12345)
    study = optuna.create_study(direction="maximize", sampler=sampler)
    study.optimize(
        objective, callbacks=[TerminatorCallback(Terminator(min_n_trials=10))]
    )

    # Save the diagnostic plots. (Fixed: the history filename was an f-string
    # with no placeholders — plain literal now.)
    fig = optuna.visualization.matplotlib.plot_optimization_history(study)
    fig.figure.savefig("fig-terminator-history-matplotlib.png")
    fig = optuna.visualization.matplotlib.plot_timeline(study)
    fig.figure.savefig("fig-terminator-timeline-matplotlib.png")
    fig = optuna.visualization.matplotlib.plot_terminator_improvement(
        study, plot_error=True
    )
    fig.figure.savefig("fig-terminator-improvement-matplotlib.png")

    print(f"Number of finished trials: {len(study.trials)}")

    # Retrain on the full training split with the best params, then evaluate
    # once on the held-out test split.
    best_trial = study.best_trial
    best_param = fixed_param | best_trial.params
    dtrain = lgb.Dataset(train_x, label=train_y)
    best_gbm = lgb.train(best_param, dtrain)
    test_accuracy = scorer(best_gbm, test_x, test_y)

    print(f"Best trial: number={best_trial.number}")
    print(f"validation accuracy: {best_trial.value}")
    print(f"test accuracy: {test_accuracy}")
    print("Params:")
    for key, value in best_trial.params.items():
        print(f"{key}: {value}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment