Skip to content

Instantly share code, notes, and snippets.

@keisuke-umezawa
Created August 22, 2021 08:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save keisuke-umezawa/e58940ca71c9f2e061837c8b8017d7a0 to your computer and use it in GitHub Desktop.
Save keisuke-umezawa/e58940ca71c9f2e061837c8b8017d7a0 to your computer and use it in GitHub Desktop.
simple_dask
import optuna
from dask.distributed import Client
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
def objective(trial):
    """Evaluate one Optuna trial and return its accuracy score.

    Loads the digits dataset, splits it with a fixed ``random_state``, fits a
    random forest whose ``max_depth`` and ``n_estimators`` are suggested by
    ``trial``, and scores its predictions on the held-out split.

    Args:
        trial: Object providing ``suggest_int`` for the two hyperparameters.

    Returns:
        The value of ``accuracy_score`` on the test split.
    """
    features, labels = load_digits(return_X_y=True)
    train_x, test_x, train_y, test_y = train_test_split(
        features, labels, random_state=2
    )

    # Dict entries are evaluated top to bottom, so "max_depth" is suggested
    # before "n_estimators".
    forest_params = {
        "max_depth": trial.suggest_int("max_depth", 2, 10),
        "n_estimators": trial.suggest_int("n_estimators", 1, 100),
    }

    model = RandomForestClassifier(**forest_params)
    model.fit(train_x, train_y)
    return accuracy_score(test_y, model.predict(test_x))
if __name__ == "__main__":
    # Demo driver: run the same optimization against each storage backend
    # listed below, wrapping each one in a DaskStorage.
    study_name = "burabura"
    storages = [
        None,  # presumably DaskStorage's default (non-persistent) backend -- confirm
        "sqlite:///example.db",
    ]
    for s in storages:
        # One Dask client per backend; the context manager closes it when the
        # iteration ends, including on the early `continue` below.
        with Client() as client:
            print(f"Dask dashboard is available at {client.dashboard_link}")
            storage = optuna.integration.dask.DaskStorage(storage=s)
            study = optuna.create_study(
                study_name=study_name,
                storage=storage,
                direction="maximize",
                # example.db survives between runs of this script; without
                # this flag a second run fails because the study name is
                # already taken in the database.
                load_if_exists=True,
            )
            study.optimize(objective, n_trials=100)
            print(f"Best params: {study.best_params}")

            # The re-open step below is skipped for the None backend.
            if s is None:
                continue

            # Re-open the study by name through the same storage, show what
            # it holds, then run a few more trials on it.
            study = optuna.integration.dask.DaskStudy(study_name, storage=storage)
            print(f"Best params: {study.best_params}")
            study.optimize(objective, n_trials=10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment