@crcrpar
Last active January 21, 2020 10:00
import sklearn
import sklearn.datasets
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
import optuna


def str_to_class(model):
    # Map the suggested classifier name to its scikit-learn class.
    if model == 'LogisticRegression':
        return LogisticRegression
    else:
        return SVC


def objective(trial: optuna.Trial) -> float:
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_X, test_X, train_y, test_y = train_test_split(data, target, test_size=.25)

    # Choose which classifier to tune in this trial.
    clf = trial.suggest_categorical('classifier', ['LogisticRegression', 'SVC'])
    clf = str_to_class(clf)

    if clf == LogisticRegression:
        # The valid solvers depend on the penalty, so the solver search space
        # is split into two separately named parameters.
        penalty = trial.suggest_categorical('lr_penalty', ['l1', 'l2'])
        if penalty == 'l2':
            solver = trial.suggest_categorical(
                'lr_solver_l2', ['newton-cg', 'lbfgs', 'sag', 'liblinear', 'saga'])
        else:
            solver = trial.suggest_categorical('lr_solver_l1', ['liblinear', 'saga'])
        params = {
            'C': trial.suggest_loguniform('lr_C', 1e-5, 1e2),
            'penalty': penalty,
            # `dual` is only valid for the l2 penalty with the liblinear solver.
            'dual': trial.suggest_categorical('lr_dual', [True, False])
            if penalty == 'l2' and solver == 'liblinear' else False,
            'solver': solver,
        }
    else:
        params = {
            'C': trial.suggest_loguniform('svm_C', 1e-5, 1e2)
        }

    model = clf()
    model.set_params(**params)
    model.fit(train_X, train_y)
    # Score the configuration by cross-validation on the held-out split.
    return cross_val_score(model, test_X, test_y).mean()
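

# A minimal, optional sketch (not part of the original gist): once the study
# finishes, the winning configuration can be turned back into a fitted model.
# The helper name `build_best_model` and the key handling below are
# assumptions based on the parameter names suggested in `objective`.
def build_best_model(best_params, train_X, train_y):
    cls = str_to_class(best_params['classifier'])
    if cls is LogisticRegression:
        params = {
            'C': best_params['lr_C'],
            'penalty': best_params['lr_penalty'],
            # Only one of the two solver keys exists, depending on the penalty.
            'solver': best_params.get('lr_solver_l2', best_params.get('lr_solver_l1')),
            # `lr_dual` is only suggested for the l2/liblinear combination.
            'dual': best_params.get('lr_dual', False),
        }
    else:
        params = {'C': best_params['svm_C']}
    model = cls(**params)
    model.fit(train_X, train_y)
    return model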


def main():
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=50)
    print(study.best_params)


if __name__ == '__main__':
    main()