Skip to content

Instantly share code, notes, and snippets.

@smurching
Last active May 7, 2020 21:02
Show Gist options
  • Save smurching/181ba02995e15a2b2a00bf1c3cf64f44 to your computer and use it in GitHub Desktop.
Save smurching/181ba02995e15a2b2a00bf1c3cf64f44 to your computer and use it in GitHub Desktop.
OSS MLflow post-run-creation hook
from mlflow.tracking.context.abstract_context import RunContextProvider
from mlflow.utils import databricks_utils
from mlflow.entities import SourceType
from mlflow.utils.mlflow_tags import (
MLFLOW_SOURCE_TYPE,
MLFLOW_SOURCE_NAME,
MLFLOW_DATABRICKS_WEBAPP_URL,
MLFLOW_DATABRICKS_NOTEBOOK_PATH,
MLFLOW_DATABRICKS_NOTEBOOK_ID
)
class MlflowCreateRunHook(object):
    """
    IPython event hook that counts the MLflow runs created during the current cell
    execution and submits that count to the Databricks frontend after the cell finishes.

    This hook does not detect run creation or increment the counter itself; detecting
    run creation is the responsibility of ``DatabricksNotebookRunContext``, which calls
    ``increment_runs_created``.

    For more info, see the IPython event API:
    https://ipython.readthedocs.io/en/stable/config/callbacks.html#ipython-events
    """

    def __init__(self):
        # Number of runs created during the currently-executing cell.
        self._mlflow_runs_created = 0
        # Namespace object taken from the IPython shell in ``register``; None until then.
        self._user_ns = None

    def pre_execute(self):
        """Reset the count of MLflow runs created before each cell executes."""
        self._mlflow_runs_created = 0

    def post_execute(self):
        """Submit the run count to the frontend if any runs were created in this cell."""
        if self._mlflow_runs_created > 0 and self._user_ns is not None:
            # NOTE(review): IPython's ``user_ns`` is normally a plain dict, which has no
            # ``add_frontend_message`` method; presumably the Databricks shell provides a
            # namespace object that does - confirm against the Databricks runtime.
            self._user_ns.add_frontend_message({"mlflowRunsCreated": self._mlflow_runs_created})

    def increment_runs_created(self):
        """Record that one more MLflow run was created during the current cell."""
        self._mlflow_runs_created += 1

    def register(self):
        """
        Register ``pre_execute`` and ``post_execute`` with the active IPython shell.

        Does nothing when no IPython shell is active (``get_ipython()`` returns None).
        """
        import IPython
        ipython = IPython.get_ipython()
        if ipython is None:
            # Not running under IPython: there are no cell-execution events to hook.
            return
        self._user_ns = ipython.user_ns
        ipython.events.register('pre_execute', self.pre_execute)
        # Must be post_execute (not pre_execute) so the count accumulated while the
        # cell ran is reported instead of being reset again.
        ipython.events.register('post_execute', self.post_execute)
class DatabricksNotebookRunContext(RunContextProvider):
    """
    Run context provider for Databricks notebooks.

    Besides the standard provider methods (``in_context``, ``tags``), this defines a
    callback executed on run creation via the ``MlflowClient.create_run`` API. The
    callback increments the count of runs created by the currently-running cell.
    """

    def __init__(self):
        self.hook = None
        if self.in_context():
            # When running in Databricks, register an IPython hook that submits the
            # number of created runs to the frontend after each cell.
            self.hook = MlflowCreateRunHook()
            self.hook.register()

    def in_context(self):
        """Return whether the code is running inside a Databricks notebook."""
        return databricks_utils.is_in_databricks_notebook()

    def tags(self):
        """
        Return tags identifying the notebook that created the run.

        ``MLFLOW_SOURCE_NAME`` and ``MLFLOW_SOURCE_TYPE`` are always set. The notebook
        ID, notebook path and webapp URL tags are included only when
        ``databricks_utils`` returns a non-None value for them.
        """
        notebook_id = databricks_utils.get_notebook_id()
        notebook_path = databricks_utils.get_notebook_path()
        webapp_url = databricks_utils.get_webapp_url()
        tags = {
            MLFLOW_SOURCE_NAME: notebook_path,
            MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.NOTEBOOK),
        }
        if notebook_id is not None:
            tags[MLFLOW_DATABRICKS_NOTEBOOK_ID] = notebook_id
        if notebook_path is not None:
            tags[MLFLOW_DATABRICKS_NOTEBOOK_PATH] = notebook_path
        if webapp_url is not None:
            tags[MLFLOW_DATABRICKS_WEBAPP_URL] = webapp_url
        return tags

    def post_create_run_hook(self, run):
        """
        Hook that executes after a run is created via the ``MlflowClient.create_run`` API.
        Fluent APIs such as ``mlflow.start_run()`` ultimately call
        ``MlflowClient.create_run``.

        The run is counted toward the current cell only if its experiment ID equals the
        current notebook ID.

        :param run: The newly created run; only ``run.info.experiment_id`` is read.
        :return: None
        """
        if self.hook is None:
            # No IPython hook was registered (provider created outside a Databricks
            # notebook), so there is no counter to increment.
            return
        # Presumably a notebook's own experiment shares the notebook's ID - confirm.
        if run.info.experiment_id == databricks_utils.get_notebook_id():
            self.hook.increment_runs_created()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment