Skip to content

Instantly share code, notes, and snippets.

@smly
Created February 3, 2020 03:15
Show Gist options
  • Save smly/99328876b27869fa5067ff8f8a4083d6 to your computer and use it in GitHub Desktop.
Save smly/99328876b27869fa5067ff8f8a4083d6 to your computer and use it in GitHub Desktop.
import sklearn.datasets
import sklearn.metrics
import xgboost as xgb
import optuna
# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
def objective(trial):
    """Train XGBoost with trial-sampled hyperparameters and return its CV AUC.

    Loads scikit-learn's breast-cancer dataset and samples a booster
    configuration from ``trial``. It then runs 2-fold cross-validation for
    50 boosting rounds. An Optuna pruning callback watches the ``test-auc``
    metric so that unpromising trials can be stopped early.

    Args:
        trial: The Optuna trial object passed in by ``Study.optimize``.

    Returns:
        The mean held-out AUC across folds at the last boosting round.
    """
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    dtrain = xgb.DMatrix(data, label=target)

    param = {
        # 'silent' is deprecated in favor of 'verbosity'. Newer XGBoost
        # releases report 'silent' as an unused parameter, so it no longer
        # suppresses output.
        'verbosity': 0,
        'objective': 'binary:logistic',
        'eval_metric': 'auc',
        'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']),
        # suggest_float(..., log=True) replaces the deprecated
        # suggest_loguniform and samples the same log-uniform distribution.
        'lambda': trial.suggest_float('lambda', 1e-8, 1.0, log=True),
        'alpha': trial.suggest_float('alpha', 1e-8, 1.0, log=True),
    }

    # Tree-shape parameters are only sampled for tree boosters, so
    # 'gblinear' trials never see them.
    if param['booster'] in ('gbtree', 'dart'):
        param['max_depth'] = trial.suggest_int('max_depth', 1, 9)
        param['eta'] = trial.suggest_float('eta', 1e-8, 1.0, log=True)
        param['gamma'] = trial.suggest_float('gamma', 1e-8, 1.0, log=True)
        param['grow_policy'] = trial.suggest_categorical('grow_policy', ['depthwise', 'lossguide'])

    # The dropout-related parameters only exist for the 'dart' booster.
    if param['booster'] == 'dart':
        param['sample_type'] = trial.suggest_categorical('sample_type', ['uniform', 'weighted'])
        param['normalize_type'] = trial.suggest_categorical('normalize_type', ['tree', 'forest'])
        param['rate_drop'] = trial.suggest_float('rate_drop', 1e-8, 1.0, log=True)
        param['skip_drop'] = trial.suggest_float('skip_drop', 1e-8, 1.0, log=True)

    # Report 'test-auc' to Optuna after each round so the study's pruner
    # can stop this trial early.
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'test-auc')
    history = xgb.cv(param, dtrain, nfold=2, num_boost_round=50, callbacks=[pruning_callback])
    # The last row of the CV history holds the metrics for the final round.
    return history.iloc[-1]['test-auc-mean']
if __name__ == '__main__':
    # Prune trials whose intermediate 'test-auc' is below the median of
    # earlier trials. n_warmup_steps=5 keeps each trial from being pruned
    # during its first few reported boosting rounds.
    median_pruner = optuna.pruners.MedianPruner(n_warmup_steps=5)

    # AUC is better when larger, so the study maximizes the objective.
    study = optuna.create_study(direction='maximize', pruner=median_pruner)
    study.optimize(objective, n_trials=100)

    # Show the full record (params, value, state) of the best trial found.
    print(study.best_trial)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment