-
-
Save Noobzik/cdf7a4754067e587010d4819fae671f4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import mlflow | |
from mlflow.tracking import MlflowClient | |
def push_to_model_registry(registry_name: str, run_id: int): | |
""" | |
Pushes a model's version to the specified registry. | |
""" | |
mlflow.set_tracking_uri(os.getenv("MLFLOW_SERVER")) | |
result = mlflow.register_model( | |
"runs:/{}/artifacts/model".format(run_id), registry_name | |
) | |
return result.version | |
def stage_model(registry_name: str, version: int): | |
""" | |
Stages a model version pushed to model registry. | |
""" | |
env = os.getenv("ENV") | |
if env not in ["staging", "production"]: | |
return | |
client = MlflowClient() | |
client.transition_model_version_stage( | |
name=registry_name, version=str(version), stage=env[0].upper() + env[1:] | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This is a boilerplate pipeline 'deployement' | |
generated using Kedro 0.19.10 | |
""" | |
from kedro.pipeline import Pipeline, node | |
from .nodes import push_to_model_registry, stage_model | |
def create_pipeline(**kwargs):
    """
    Assemble the deployment pipeline.

    Two nodes run in sequence: ``push_to_model_registry`` registers the model of
    the run identified by ``mlflow_run_id`` and emits ``mlflow_model_version``,
    which ``stage_model`` then consumes to stage that version. Both read the
    registry name from the ``mlflow_model_registry`` parameter.
    """
    registry_param = "params:mlflow_model_registry"

    register_node = node(
        func=push_to_model_registry,
        inputs=[registry_param, "mlflow_run_id"],
        outputs="mlflow_model_version",
        name="push_to_model_registry",
    )
    promote_node = node(
        func=stage_model,
        inputs=[registry_param, "mlflow_model_version"],
        outputs=None,
        name="stage_model",
    )

    return Pipeline([register_node, promote_node])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This is a boilerplate pipeline 'training' | |
generated using Kedro 0.19.10 | |
""" | |
import os | |
import numpy as np | |
import pandas as pd | |
import mlflow | |
from typing import Callable, Tuple, Any, Dict | |
from matplotlib import pyplot as plt | |
import matplotlib.ticker as mtick | |
from mlflow.models import infer_signature | |
from sklearn.base import BaseEstimator | |
from sklearn.metrics import f1_score, precision_recall_curve, PrecisionRecallDisplay | |
from sklearn.model_selection import RepeatedKFold | |
from lightgbm.sklearn import LGBMClassifier | |
from hyperopt import hp, tpe, fmin | |
import warnings | |
def save_pr_curve(X, y, model, path: str = "data/08_reporting/pr_curve.png") -> str:
    """
    Plot the precision-recall curve of a fitted binary classifier and save it to disk.

    Args:
        X: Features to score.
        y: Ground-truth labels for ``X``; label ``1`` is the positive class.
        model: Fitted estimator exposing ``predict_proba``; column ``1`` of its
            output is used as the positive-class score.
        path: Destination image file (``~`` is expanded). The default is the file
            ``auto_ml`` uploads to MLflow. Missing parent directories are created.

    Returns:
        The (user-expanded) path the figure was written to.
    """
    path = os.path.expanduser(path)
    scores = model.predict_proba(X)[:, 1]
    precision, recall, _ = precision_recall_curve(y, scores, pos_label=1)

    directory = os.path.dirname(path)
    if directory:
        # savefig does not create directories; a fresh checkout may lack this one.
        os.makedirs(directory, exist_ok=True)

    fig, ax = plt.subplots(figsize=(16, 11))
    try:
        PrecisionRecallDisplay(precision=precision, recall=recall).plot(ax=ax)
        ax.set_title("PR Curve", fontsize=16)
        # Show both axes as whole percentages (values are fractions in [0, 1]).
        ax.xaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(1, 0))
        fig.savefig(path)
    finally:
        # Release the figure even if plotting or saving fails; pyplot keeps every
        # open figure alive in its global registry otherwise.
        plt.close(fig)
    return path
def auto_ml(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    X_test: pd.DataFrame,
    y_test: pd.Series,
    max_evals: int = 40,
    log_to_mlflow: bool = False,
    experiment_id: int = -1,
) -> Dict[str, Any]:
    """
    Tune and train every candidate in ``MODELS`` and keep the one with the best test F1.

    For each entry of ``MODELS`` (read here through its ``"class"``, ``"params"`` and
    ``"name"`` keys), hyper-parameters are searched with ``optimize_hyp`` using
    negated F1 as the objective. A model is then fitted on the training split with
    the optimum and scored by F1 on the test split. When ``log_to_mlflow`` is true,
    the winner's F1, parameters, PR curve and the model itself are logged to a new
    MLflow run.

    Args:
        X_train, y_train: Training features and binary labels. They are combined
            with ``pd.concat``, so pandas objects are expected (the exact
            DataFrame/Series split is an assumption; TODO confirm against the catalog).
        X_test, y_test: Held-out features and labels.
        max_evals: Number of hyper-parameter evaluations per candidate.
        log_to_mlflow: Whether to record the best model in MLflow. The tracking
            server is read from the ``MLFLOW_SERVER`` environment variable.
        experiment_id: MLflow experiment receiving the run; only used when
            ``log_to_mlflow`` is true.

    Returns:
        ``{"model": best, "mlflow_run_id": run_id}``. ``best`` is the winning
        candidate's record, a dict with keys ``model``, ``name``, ``params`` and
        ``score``. ``run_id`` is the MLflow run ID, or ``""`` when logging is
        disabled.

    Raises:
        ValueError: If ``MODELS`` yields no candidate.
    """
    # NOTE(review): hyper-parameters are searched on train + test combined, and the
    # winner is then scored on that same test split, so the reported F1 is likely
    # optimistic (test data leaks into model selection). Confirm this is intended.
    X = pd.concat((X_train, X_test))
    y = pd.concat((y_train, y_test))

    run_id = ""
    if log_to_mlflow:
        mlflow.set_tracking_uri(os.getenv("MLFLOW_SERVER"))
        run = mlflow.start_run(experiment_id=str(experiment_id))
        run_id = run.info.run_id

    # The run is closed in every case, and marked FAILED unless the whole body
    # succeeds. A run left active would make the next mlflow.start_run() in this
    # process fail.
    run_status = "FAILED"
    try:
        opt_models = []
        for model_specs in MODELS:
            # Bayesian optimisation of the hyper-parameters. The objective is the
            # negated F1, so optimize_hyp presumably minimises it.
            optimum_params = optimize_hyp(
                model_specs["class"],
                dataset=(X, y),
                search_space=model_specs["params"],
                metric=lambda x, y: -f1_score(x, y),
                max_evals=max_evals,
            )
            # Refit the supposedly best configuration on the training split only.
            model = train_model(
                model_specs["class"],
                training_set=(X_train, y_train),
                params=optimum_params,
            )
            opt_models.append(
                {
                    "model": model,
                    "name": model_specs["name"],
                    "params": optimum_params,
                    "score": f1_score(y_test, model.predict(X_test)),
                }
            )

        if not opt_models:
            raise ValueError("MODELS is empty: there is no candidate model to train.")
        best_model = max(opt_models, key=lambda candidate: candidate["score"])

        if log_to_mlflow:
            _log_best_model(best_model, X_train, X_test, y_test)
        run_status = "FINISHED"
    finally:
        if log_to_mlflow:
            mlflow.end_run(status=run_status)

    return dict(model=best_model, mlflow_run_id=run_id)


def _log_best_model(best_model, X_train, X_test, y_test):
    """Log the winning candidate's F1, parameters, PR curve and model to the active MLflow run."""
    signature = infer_signature(X_train, best_model["model"].predict(X_train))
    # save_pr_curve writes its figure to data/08_reporting/pr_curve.png, uploaded below.
    save_pr_curve(X_test, y_test, best_model["model"])
    mlflow.log_metrics({"f1": best_model["score"]})
    mlflow.log_params(best_model["params"])
    mlflow.log_artifact("data/08_reporting/pr_curve.png", artifact_path="plots")
    mlflow.sklearn.log_model(best_model["model"], "model", signature=signature)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This is a boilerplate pipeline 'training' | |
generated using Kedro 0.19.10 | |
""" | |
from kedro.pipeline import Pipeline, node | |
from .nodes import auto_ml | |
def create_pipeline(**kwargs) -> Pipeline:
    """
    Assemble the training pipeline, made of a single ``auto_ml`` node.

    The node reads the four train/test datasets followed by three parameters
    (evaluation budget, MLflow toggle, MLflow experiment ID). Its two results are
    published as separate datasets: ``model`` and ``mlflow_run_id``.
    """
    split_datasets = ["X_train", "y_train", "X_test", "y_test"]
    run_parameters = [
        "params:automl_max_evals",
        "params:mlflow_enabled",
        "params:mlflow_experiment_id",
    ]

    training_node = node(
        func=auto_ml,
        inputs=split_datasets + run_parameters,
        outputs={"model": "model", "mlflow_run_id": "mlflow_run_id"},
        name="auto_ml",
    )
    return Pipeline([training_node])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment