Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
MLflow Log Model decorator
import functools
import logging
import subprocess
import mlflow # type: ignore
import mlflow.sklearn # type: ignore
logger = logging.getLogger(__name__)
def _git_output(pretty_format):
    """Return the output of ``git log -1 --pretty=format:<pretty_format>``.

    Best effort: returns ``None`` (and logs a warning) when git is not
    installed or the current directory is not a git repository. It also
    returns ``None`` when the command prints nothing, so missing commit
    metadata never prevents the model from being logged.
    """
    try:
        output = subprocess.check_output(
            ["git", "log", "-1", f"--pretty=format:{pretty_format}"],
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        logger.warning("Could not read git metadata (%s): %s", pretty_format, err)
        return None
    # check_output returns bytes; decode so MLflow tags hold readable text
    # rather than a "b'...'" representation.
    text = output.decode("utf-8", errors="replace").strip()
    return text or None


def log_skl_model(func):
    """Log a scikit-learn pipeline's parameters, statistics and model to MLflow.

    Call the decorated function with a ``model_reporting`` keyword argument,
    for example loaded from parameters.yml. It is a mapping with at least the
    ``tracking_uri`` and ``default_experiment`` keys. An optional
    ``experiment`` keyword argument overrides ``default_experiment``.

    The decorated function must return a tuple of the fitted
    ``sklearn.pipeline.Pipeline`` object and a mapping of statistic names to
    values to log as MLflow metrics.

    Example::

        from util.mlflow import log_skl_model

        @log_skl_model
        def my_model_run(arg1, arg2, arg3, model_reporting=model_reporting):
            ...
            return skl_pipeline, {"statistic1": value1, "statistic2": value2}

    Returns:
        The wrapped function. Calling it returns only the pipeline; the
        statistics are logged to MLflow, not returned.

    Raises:
        ValueError: if ``model_reporting`` is missing or empty, or if no
            tracking URI or experiment name can be resolved.
    """

    @functools.wraps(func)
    def wrapper_mlflow(*args, **kwargs):
        # Resolve the tracking URI and experiment before running ``func`` so
        # a misconfiguration fails fast instead of after a (possibly long) fit.
        model_reporting = kwargs.get("model_reporting")
        if not model_reporting:
            raise ValueError(
                "You must set `model_reporting` kwarg to use the log_model util. "
                "Pass from parameters.yml"
            )
        uri = model_reporting.get("tracking_uri")
        experiment = kwargs.get("experiment") or model_reporting.get(
            "default_experiment"
        )
        # Both values are required: raise if either one is missing.
        if not (uri and experiment):
            raise ValueError(
                "The `model_reporting` must be a dict containing tracking_uri "
                "and default_experiment keys"
            )
        logger.info(
            "Logging model performance using MLflow. URI: %s Experiment: %s",
            uri,
            experiment,
        )

        # Debug: record exactly how the decorated function was called.
        args_repr = [repr(a) for a in args]
        kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
        signature = ", ".join(args_repr + kwargs_repr)
        logger.debug("Calling %s(%s)", func.__name__, signature)

        # Run decorated function and get model/statistics tuple.
        skl_pipeline, statistics = func(*args, **kwargs)

        mlflow.set_tracking_uri(uri)
        # create_experiment raises if the name already exists, so reuse an
        # existing experiment and only create it on first use.
        existing = mlflow.get_experiment_by_name(experiment)
        if existing is not None:
            experiment_id = existing.experiment_id
        else:
            experiment_id = mlflow.create_experiment(experiment)

        commit_hash = _git_output("%h")
        commit_msg = _git_output("%B")
        # Tag keys (including the "funtion" spelling) are kept unchanged so
        # existing MLflow queries/dashboards keep matching. Unknown git values
        # are dropped instead of being stored as the string "None".
        tags = {
            "mlflow.runName": commit_hash,
            "git.hash": commit_hash,
            "git.msg": commit_msg,
            "calling_funtion.name": func.__name__,
            "calling_funtion.args": args_repr,
            "calling_funtion.kwargs": kwargs_repr,
        }
        tags = {key: value for key, value in tags.items() if value is not None}

        # Everything is logged inside the run. Calling mlflow.set_tags()
        # with no active run implicitly starts one, which would make
        # start_run() fail. The context manager also ends the run even if a
        # logging call raises.
        with mlflow.start_run(experiment_id=experiment_id):
            mlflow.set_tags(tags)

            # Log only parameters that convert to numbers; None means "unset"
            # in scikit-learn and is skipped rather than logged as 0.0.
            for parameter_name, value in skl_pipeline.get_params().items():
                if value is None:
                    continue
                try:
                    numeric_value = float(value)
                except (ValueError, TypeError):
                    continue
                mlflow.log_param(parameter_name, numeric_value)

            for key, metric_value in statistics.items():
                mlflow.log_metric(key, metric_value)

            # NOTE(review): only the final pipeline step is logged, so the
            # stored model expects already-transformed features. Log
            # `skl_pipeline` instead if the whole pipeline should be served.
            mlflow.sklearn.log_model(skl_pipeline.steps[-1][1], "model")

        return skl_pipeline

    return wrapper_mlflow
"""Test the model_tracking utility built around mlflow."""
# pylint: disable=too-few-public-methods
# pylint: skip-file
# flake8: noqa
import pytest
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.metrics import f1_score
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from project_utsugi.nodes.common.utils.model_tracking import log_skl_model
def test_log_skl_model(dummy_string, mock_model_reporting):
    """The decorator logs a fitted pipeline to MLflow and returns the pipeline.

    ``dummy_string`` and ``mock_model_reporting`` are pytest fixtures defined
    outside this module. ``fetch_20newsgroups`` needs network access or a
    populated scikit-learn data cache.
    """

    @log_skl_model
    def fit_skl_pipeline(dummy_arg, model_reporting):
        # The arguments are unused by the training itself; they exercise the
        # decorator's argument handling (``model_reporting`` must be a kwarg).
        _, _ = dummy_arg, model_reporting
        cats = ["alt.atheism", "sci.space"]
        newsgroups_train = fetch_20newsgroups(subset="train", categories=cats)
        newsgroups_test = fetch_20newsgroups(subset="test", categories=cats)
        x_train, x_test = newsgroups_train.data, newsgroups_test.data
        y_train, y_test = newsgroups_train.target, newsgroups_test.target
        pipeline = Pipeline(
            [
                ("vect", CountVectorizer()),
                ("tfidf", TfidfTransformer()),
                ("clf", LinearSVC()),
            ]
        )
        # now train and predict test instances
        pipeline.fit(x_train, y_train)
        y_pred = pipeline.predict(x_test)
        # get scores
        cross = cross_val_score(pipeline, x_train, y_train, cv=3, scoring="f1_micro")
        f1 = f1_score(y_test, y_pred, average="micro")
        return pipeline, {"mean_cross_score": cross.mean(), "f1_score": f1}

    try:
        result = fit_skl_pipeline(dummy_string, model_reporting=mock_model_reporting)
    except Exception as e:
        # pytest.fail raises by itself; the original exception stays chained
        # as the context so its traceback is still reported.
        pytest.fail("log_skl_model decorator raised {0}".format(e))

    # The wrapper returns only the pipeline (statistics go to MLflow).
    assert isinstance(result, Pipeline)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment