Skip to content

Instantly share code, notes, and snippets.

@pkgandhi
Last active August 16, 2020 21:10
Show Gist options
  • Save pkgandhi/e76ae55c77e3aaa0ca03c0ae16737835 to your computer and use it in GitHub Desktop.
Save pkgandhi/e76ae55c77e3aaa0ca03c0ae16737835 to your computer and use it in GitHub Desktop.
# Importing the Packages:
import optuna
import pandas as pd
from sklearn import datasets
from sklearn import ensemble
from sklearn import linear_model
from sklearn import model_selection
# Load scikit-learn's breast-cancer classification dataset; ``as_frame=True``
# requests pandas objects for both the features and the target.
X, y = datasets.load_breast_cancer(return_X_y=True, as_frame=True)

# Distinct target labels that occur in the data.
classes = [*set(y)]

# Hold out a validation split. No ``test_size`` or ``random_state`` is passed,
# so the library defaults apply and the split differs from run to run.
(
    x_train,
    x_valid,
    y_train,
    y_valid,
) = model_selection.train_test_split(X, y)
# Step 1. Define an objective function to be maximized.
def objective(trial: optuna.trial.Trial) -> float:
    """Evaluate one sampled classifier configuration for the Optuna study.

    The trial first chooses between a logistic regression and a random
    forest, then samples that model's hyperparameters. The model is fitted on
    the module-level ``x_train``/``y_train`` split and scored on
    ``x_valid``/``y_valid``; each score is reported to the trial so the
    study's pruner can stop the trial early.

    Args:
        trial: Optuna trial used to sample hyperparameters, report
            intermediate values and query the pruner.

    Returns:
        The value of ``classifier_obj.score(x_valid, y_valid)`` from the
        final step.

    Raises:
        optuna.TrialPruned: If ``trial.should_prune()`` is true after a
            reported step.
    """
    classifier_name = trial.suggest_categorical(
        "classifier", ["LogReg", "RandomForest"]
    )

    # Step 2. Set up values for the hyperparameters of the chosen model.
    if classifier_name == "LogReg":
        logreg_c = trial.suggest_float("logreg_c", 1e-10, 1e10, log=True)
        classifier_obj = linear_model.LogisticRegression(C=logreg_c)
    else:
        rf_n_estimators = trial.suggest_int("rf_n_estimators", 10, 1000)
        rf_max_depth = trial.suggest_int("rf_max_depth", 2, 32, log=True)
        # ``ensemble`` must be imported from sklearn at the top of the file;
        # without that import this branch fails with NameError.
        classifier_obj = ensemble.RandomForestClassifier(
            max_depth=rf_max_depth, n_estimators=rf_n_estimators
        )

    n_steps = 100
    # NOTE(review): ``fit`` is called in full on every step; presumably
    # neither estimator trains incrementally with these settings, so the
    # steps are repeated refits rather than a learning curve -- confirm
    # before relying on the pruning statistics.
    for step in range(n_steps):
        classifier_obj.fit(x_train, y_train)

        # Report intermediate objective value.
        intermediate_value = classifier_obj.score(x_valid, y_valid)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

    return intermediate_value
# Step 4: Create a study that maximizes the objective and run 100 trials.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100)

# Separate the recorded trials by final state: pruned versus completed.
pruned_trials = list(
    filter(
        lambda t: t.state == optuna.trial.TrialState.PRUNED,
        study.trials,
    )
)
complete_trials = list(
    filter(
        lambda t: t.state == optuna.trial.TrialState.COMPLETE,
        study.trials,
    )
)

# Print how many trials ran in total and how many ended in each state.
print(" Number of finished trials: ", len(study.trials))
print(" Number of pruned trials: ", len(pruned_trials))
print(" Number of complete trials: ", len(complete_trials))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment