Skip to content

Instantly share code, notes, and snippets.

@smly
Last active April 27, 2020 08:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save smly/584e4b9eabacecc7c102d47de4066a2f to your computer and use it in GitHub Desktop.
import numpy as np
import sklearn.datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import mlflow
import optuna
import optuna.integration.lightgbm as lgb
def mlflow_callback(study, trial, *, metric_name="binary_logloss"):
    """Record one finished Optuna trial as its own MLflow run.

    Meant to be passed through ``optuna_callbacks`` so that Optuna invokes it
    as ``callback(study, trial)`` after each trial. Every call opens a new
    MLflow run named after the study and logs:

    * the trial's sampled hyper-parameters,
    * the LightGBM-tuner step that produced the trial (``step_name``),
    * the trial number, its wall-clock duration in seconds, and its
      objective value.

    Args:
        study: The Optuna study; only ``study.study_name`` is read.
        trial: The finished trial. Its ``value``, ``params``, ``number``,
            ``datetime_start``, ``datetime_complete`` and ``user_attrs`` are
            read.
        metric_name: MLflow metric key under which the trial's objective
            value is stored. The default matches the ``binary_logloss``
            metric optimized by this script.
    """
    # trial.value may be None (e.g. a trial that did not complete); log NaN
    # so the run is still recorded instead of failing on a None metric.
    trial_value = trial.value if trial.value is not None else float("nan")

    # Duration is only computable when both timestamps are present.
    trial_runtime = float("nan")
    if trial.datetime_start and trial.datetime_complete:
        trial_runtime = (
            trial.datetime_complete - trial.datetime_start
        ).total_seconds()

    # The LightGBM tuner tags its trials with this user attribute; use .get so
    # a trial without it does not abort the callback with a KeyError.
    step_name = trial.user_attrs.get("lightgbm_tuner:step_name", "unknown")

    with mlflow.start_run(run_name=study.study_name):
        mlflow.log_params(trial.params)
        mlflow.log_params({"step_name": step_name})
        mlflow.log_metrics({
            "trial_number": trial.number,
            "elapsed_time": trial_runtime,
            metric_name: trial_value,
        })
if __name__ == "__main__":
    # Hold out 25% of the breast-cancer dataset for validation.
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, val_x, train_y, val_y = train_test_split(data, target, test_size=0.25)
    dtrain = lgb.Dataset(train_x, label=train_y)
    dval = lgb.Dataset(val_x, label=val_y)

    params = {
        "objective": "binary",
        "metric": "binary_logloss",
        "verbosity": -1,
        "boosting_type": "gbdt",
    }

    # binary_logloss is a loss, hence direction="minimize".
    study = optuna.create_study(study_name="lightgbm_tuner_breast_cancer", direction="minimize")

    # Hand the study to the tuner explicitly. Without it the tuner runs on a
    # study of its own, so the name/direction configured above would be
    # ignored and mlflow_callback would log runs under a different study name.
    model = lgb.train(
        params,
        dtrain,
        valid_sets=[dtrain, dval],
        verbose_eval=100,
        early_stopping_rounds=100,
        study=study,
        optuna_callbacks=[mlflow_callback],
    )

    # With the "binary" objective predict() yields probabilities; round them
    # to 0/1 class labels before scoring.
    prediction = np.rint(model.predict(val_x, num_iteration=model.best_iteration))
    accuracy = accuracy_score(val_y, prediction)
    best_params = model.params

    print("Best params:", best_params)
    print(" Accuracy = {}".format(accuracy))
    print(" Params: ")
    for key, value in best_params.items():
        print(" {}: {}".format(key, value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment