Skip to content

Instantly share code, notes, and snippets.

@smly
Last active June 15, 2020 02:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save smly/33036bb6b4833865854d94072f2c2902 to your computer and use it in GitHub Desktop.
Save smly/33036bb6b4833865854d94072f2c2902 to your computer and use it in GitHub Desktop.
diff --git a/optuna/integration/lightgbm_tuner/optimize.py b/optuna/integration/lightgbm_tuner/optimize.py
index f11c8326..4434e51c 100644
--- a/optuna/integration/lightgbm_tuner/optimize.py
+++ b/optuna/integration/lightgbm_tuner/optimize.py
@@ -112,7 +112,14 @@ class BaseTuner(object):
else:
raise NotImplementedError
+ if self.lgbm_params.get("metric") == "None":
+ if len(booster.best_score[valid_name].keys()) > 0:
+ metric = list(booster.best_score[valid_name].keys())[0]
+ else:
+ raise ValueError("No given metric in parameters.")
+
val_score = booster.best_score[valid_name][metric]
+
return val_score
def _metric_with_eval_at(self, metric):
import numpy as np
import sklearn.datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import optuna.integration.lightgbm as lgb
def custom_metrics_func(preds, train_data):
    """Custom LightGBM evaluation function: sum of absolute prediction errors.

    Intended to be passed to ``lgb.train`` via ``feval``.

    Args:
        preds: Array of predictions for the rows of ``train_data``.
            NOTE(review): with the built-in ``binary`` objective these are
            presumably probabilities; confirm for the LightGBM version in use.
        train_data: Dataset object exposing ``get_label()``.

    Returns:
        Tuple ``(name, value, is_higher_better)``. ``is_higher_better`` is
        ``False`` because a smaller total error is better.
    """
    y = train_data.get_label()
    # Absolute residuals: a signed sum lets over- and under-predictions cancel,
    # and because the metric is treated as lower-is-better, minimizing a signed
    # sum would simply reward under-prediction instead of accuracy.
    error_sum = np.abs(preds - y).sum()
    return ('custom_metrics', error_sum, False)
if __name__ == "__main__":
    # Repro script: train/tune a binary classifier with LightGBM's built-in
    # metric disabled (the string "None"), so the custom ``feval`` is the
    # score reported for the validation sets. Then print hold-out accuracy
    # and the resulting booster parameters.
    features, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
    x_train, x_valid, y_train, y_valid = train_test_split(
        features, labels, test_size=0.25
    )
    train_set = lgb.Dataset(x_train, label=y_train)
    valid_set = lgb.Dataset(x_valid, label=y_valid)

    lgbm_params = dict(
        objective="binary",
        metric="None",  # disable built-in metrics; rely on feval below
        verbosity=-1,
        boosting_type="gbdt",
    )
    booster = lgb.train(
        lgbm_params,
        train_set,
        valid_sets=[train_set, valid_set],
        verbose_eval=100,
        early_stopping_rounds=100,
        feval=custom_metrics_func,
    )

    # Round the model output at the best iteration to the nearest integer to
    # obtain class labels for the accuracy computation.
    raw_output = booster.predict(x_valid, num_iteration=booster.best_iteration)
    predicted_labels = np.rint(raw_output)
    accuracy = accuracy_score(y_valid, predicted_labels)

    tuned_params = booster.params
    print("Best params:", tuned_params)
    print(" Accuracy = {}".format(accuracy))
    print(" Params: ")
    for param_name, param_value in tuned_params.items():
        print(" {}: {}".format(param_name, param_value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment