Skip to content

Instantly share code, notes, and snippets.

@krsnewwave
Created March 11, 2022 15:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krsnewwave/68c7d5c641ed69d523ef0ffdc306fa19 to your computer and use it in GitHub Desktop.
Save krsnewwave/68c7d5c641ed69d523ef0ffdc306fa19 to your computer and use it in GitHub Desktop.
Random forest hyperparameter optimization using Optuna, Kedro, and MLflow
# in <root>/src/<project>/pipelines/data_science/nodes.py
def rr_objective(X_train: pd.DataFrame, y_train: pd.Series,
                 X_test: pd.DataFrame, y_test: pd.Series,
                 trial: optuna.trial.Trial) -> float:
    """Optuna objective: score one random-forest hyperparameter sample.

    Draws ``max_depth``, ``min_samples_split`` and ``ccp_alpha`` from
    ``trial``, fits a ``RandomForestClassifier`` on the training split and
    returns its average precision on the evaluation split.

    Args:
        X_train: Training features.
        y_train: Training labels.
        X_test: Features the candidate model is scored on.
        y_test: Labels the candidate model is scored on.
        trial: Optuna trial used to sample the hyperparameters. It comes
            last so the data arguments can be bound with
            ``functools.partial`` before handing the objective to a study.

    Returns:
        Average precision of the fitted model on ``(X_test, y_test)``.
    """
    # Depth and pruning strength are sampled on a log scale (log=True).
    max_depth = trial.suggest_int("max_depth", 8, 64, log=True)
    min_samples_split = trial.suggest_int("min_samples_split", 50, 1000)
    ccp_alpha = trial.suggest_float("ccp_alpha", 0.001, 0.03, log=True)

    rr_clf = RandomForestClassifier(max_depth=max_depth,
                                    min_samples_split=min_samples_split,
                                    ccp_alpha=ccp_alpha,
                                    class_weight='balanced_subsample',
                                    verbose=1)
    rr_clf.fit(X_train, y_train)

    # Column 1 of predict_proba is used as the score; this presumes a binary
    # target whose positive class sits at index 1 of ``classes_``.
    y_proba = rr_clf.predict_proba(X_test)[:, 1]
    return average_precision_score(y_test, y_proba)
def fit_rr_ho(X_train: pd.DataFrame, y_train: pd.Series,
              X_test: pd.DataFrame, y_test: pd.Series,
              n_trials: int = 5) -> dict:
    """Tune a random forest with Optuna, then refit and evaluate the best one.

    Runs an Optuna study that maximizes ``rr_objective``, logs the best
    hyperparameters to MLflow, refits a classifier with them on the training
    split and evaluates it with ``evaluate_model``.

    Args:
        X_train: Training features.
        y_train: Training labels.
        X_test: Features used both to score trials and for the final
            evaluation.
        y_test: Labels used both to score trials and for the final
            evaluation.
        n_trials: Number of Optuna trials to run. The default of 5 keeps the
            run short; use a larger value (e.g. more than 100) for a more
            thorough search.

    Returns:
        A dict with the fitted classifier under ``"clf"`` and whatever
        ``evaluate_model`` returns under ``"model_metrics"``.
    """
    study = optuna.create_study(direction="maximize")
    # Bind the data up front so the study only has to supply ``trial``.
    fun_rr_object = partial(rr_objective, X_train, y_train, X_test, y_test)
    study.optimize(fun_rr_object, n_trials=n_trials)

    best_params = study.best_params
    mlflow.log_params(best_params)

    # Every trial was scored with class_weight='balanced_subsample'; keep it
    # here so the final model matches the configuration that was tuned.
    rr_clf = RandomForestClassifier(class_weight='balanced_subsample',
                                    **best_params)
    rr_clf.fit(X_train, y_train)
    dict_metrics = evaluate_model(rr_clf, X_test, y_test)
    return {"clf": rr_clf, "model_metrics": dict_metrics}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment