Skip to content

Instantly share code, notes, and snippets.

@jonnor
Created November 14, 2021 17:59
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jonnor/c107f3ca24a36c91d8ff94029a0cd357 to your computer and use it in GitHub Desktop.
Save jonnor/c107f3ca24a36c91d8ff94029a0cd357 to your computer and use it in GitHub Desktop.
MLflow integration for KerasTuner
"""mlflow integration for KerasTuner
Copyright: Soundsensing AS, 2021
License: MIT
"""
import uuid
import structlog
log = structlog.get_logger()
import mlflow
import keras_tuner
def get_run_id(run):
    """Return the id of an mlflow run, or None when ``run`` is None.

    Handy for logging the result of ``mlflow.active_run()``, which is None
    when no run is active.
    """
    return None if run is None else run.info.run_id
class MlflowLogger(object):
    """KerasTuner logger that records a search as nested mlflow runs.

    Run hierarchy:

    - search: one top-level run, opened by ``register_tuner`` and closed
      by ``exit``;
    - trial: one run per KerasTuner trial, nested under the search run,
      opened by ``register_trial`` and closed by ``report_trial_state``;
    - execution: one run per model fit within a trial, nested under the
      trial run, opened by ``register_execution`` and closed by
      ``report_execution_state``. The trial's hyperparameters are logged
      on each execution run.

    ``register_execution`` and ``report_execution_state`` are not part of
    the KerasTuner logger interface; ``LoggerTunerMixin`` calls them around
    each model fit.

    XXX: assumes that executions are done sequentially and non-concurrently.
    ``mlflow.end_run()`` takes no run id and ends the current active run, so
    runs must be closed in the reverse order they were opened.
    """

    def __init__(self):
        self.search_run = None      # top-level mlflow run for the search
        self.search_id = None       # uuid4 string, used in run names
        self.trial_run = None       # mlflow run of the current trial
        self.trial_id = None        # KerasTuner id of the current trial
        self.trial_state = None     # state dict of the current trial
        self.execution_run = None   # mlflow run of the current execution
        self.execution_id = 0       # index of the next execution in the trial

    def register_tuner(self, tuner_state):
        """Called at start of search: open the top-level search run.

        ``tuner_state`` is accepted for interface compatibility and unused.
        """
        log.debug('mlflow-logger-search-start')
        self.search_id = str(uuid.uuid4())
        # Register a top-level (non-nested) run for the whole search
        self.search_run = mlflow.start_run(nested=False, run_name=f'search-{self.search_id[0:8]}')

    def exit(self):
        """Called at end of a search: close every run this logger still has open."""
        log.debug('mlflow-logger-search-end')
        # mlflow.end_run() ends the current active run, so close leftovers
        # innermost-first. Normally only the search run is still open here.
        for attr in ('execution_run', 'trial_run', 'search_run'):
            if getattr(self, attr) is not None:
                mlflow.end_run()
                setattr(self, attr, None)
        self.search_id = None
        self.trial_id = None
        self.trial_state = None
        self.execution_id = 0

    def register_trial(self, trial_id, trial_state):
        """Called at beginning of trial: open a trial run under the search run.

        Args:
            trial_id: KerasTuner trial id; its first 8 characters go in the run name.
            trial_state: KerasTuner trial state. Its
                ``['hyperparameters']['values']`` are logged on each execution run.
        """
        log.debug('mlflow-logger-trial-start',
                  trial_id=trial_id,
                  active_run_id=get_run_id(mlflow.active_run()),
                  )
        assert self.search_run is not None
        assert self.trial_run is None
        assert self.execution_run is None
        assert self.execution_id == 0
        self.trial_id = trial_id
        self.trial_state = trial_state
        # Start a new run, under the search run
        self.trial_run = mlflow.start_run(nested=True,
                                          run_name=f'trial-{self.trial_id[0:8]}-{self.search_id[0:8]}'
                                          )
        # Hyperparameters are deliberately logged per execution run (see
        # register_execution), not on the trial run.

    def report_trial_state(self, trial_id, trial_state):
        """Called at end of trial: close the trial run and reset trial state.

        ``trial_state`` is accepted for interface compatibility and unused.
        """
        log.debug('mlflow-logger-trial-end',
                  trial_id=trial_id,
                  active_run_id=get_run_id(mlflow.active_run()),
                  )
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is None
        # End the trial run. No execution run is open (asserted above), so
        # the trial run is the current active run.
        mlflow.end_run()  # XXX: mlflow.end_run() cannot target a run by id
        self.trial_run = None
        self.trial_id = None
        self.trial_state = None
        self.execution_id = 0

    def register_execution(self):
        """Open an execution run under the current trial run and log its hyperparameters."""
        log.debug('mlflow-logger-execution-start',
                  active_run_id=get_run_id(mlflow.active_run()),
                  )
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is None
        self.execution_run = mlflow.start_run(nested=True,
                                              run_name=f'exec-{self.execution_id}-{self.trial_id[0:8]}-{self.search_id[0:8]}',
                                              )
        self.execution_id += 1
        # register hyperparameters from the trial
        hyperparams = self.trial_state['hyperparameters']['values']
        mlflow.log_params(hyperparams)

    def report_execution_state(self, histories):
        """Close the current execution run.

        ``histories`` (the fit result, possibly None) is accepted but unused.
        """
        log.debug('mlflow-logger-execution-end',
                  active_run_id=get_run_id(mlflow.active_run()),
                  )
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is not None
        mlflow.end_run()  # XXX: mlflow.end_run() cannot target a run by id
        self.execution_run = None
class FakeHistories():
    """Minimal stand-in for a Keras ``History``; exposes only ``.history``.

    Returned in place of a real fit result when a model execution fails
    and the failure is being tolerated.

    Args:
        metrics: mapping of metric name to value(s), stored as ``.history``.
            Defaults to a fresh empty dict for each instance.
    """

    def __init__(self, metrics=None):
        # Fresh dict per instance: the previous mutable default ({}) was one
        # dict shared by every instance created without arguments.
        self.history = {} if metrics is None else metrics
class LoggerTunerMixin():
    """Mixin that adds mlflow logging to a KerasTuner tuner class.

    List it before the tuner in the bases, e.g.
    ``class RandomSearch(LoggerTunerMixin, keras_tuner.RandomSearch)``.

    If ``logger`` is missing or None, an ``MlflowLogger`` is created.

    Extra keyword argument, consumed here and not forwarded to the tuner:
        on_exception: ``'pass'`` (default) logs a failed model execution,
            reports the worst possible objective value for it and lets the
            search continue; any other value re-raises the exception.
    """

    def __init__(self, *args, **kwargs):
        if kwargs.get('logger') is None:
            kwargs['logger'] = MlflowLogger()
        # Pop, not get: ``on_exception`` belongs to this mixin only and must
        # not be forwarded to the KerasTuner base class constructor.
        self.on_exception = kwargs.pop('on_exception', 'pass')
        super().__init__(*args, **kwargs)

    # Hack in registration for each model training "execution" by overriding
    # KerasTuner's private _build_and_fit_model.
    def _build_and_fit_model(self, trial, *args, **kwargs):
        """Fit one model, wrapped in an execution run on ``self.logger``.

        Returns the fit histories from the tuner. If ``on_exception == 'pass'``
        and the fit raised, returns a ``FakeHistories`` holding the worst
        value for the oracle objective instead.
        """
        # log start of execution
        if self.logger:
            self.logger.register_execution()
        histories = None
        try:
            # call the original function
            histories = super()._build_and_fit_model(trial, *args, **kwargs)
        except Exception:
            if self.on_exception != 'pass':
                raise
            # Tolerated failure: record it instead of swallowing it silently.
            log.exception('mlflow-logger-execution-failed')
            objective = self.oracle.objective
            worst = float('inf') if objective.direction == 'min' else float('-inf')
            histories = FakeHistories({objective.name: worst})
        finally:
            # Always close the execution run, even when re-raising, so the
            # logger's run stack stays consistent.
            if self.logger:
                self.logger.report_execution_state(histories)
        return histories
# KerasTuner tuners with mlflow logging mixed in. LoggerTunerMixin comes first
# in the bases so that, by Python's MRO, its __init__ and _build_and_fit_model
# take precedence over the tuner's own.
# NOTE(review): keras_tuner.SklearnTuner may not call _build_and_fit_model;
# if so, it gets search/trial runs but no execution runs (and therefore no
# logged hyperparameters) — confirm against the installed KerasTuner.
class RandomSearch(LoggerTunerMixin, keras_tuner.RandomSearch):
    """``keras_tuner.RandomSearch`` with mlflow logging."""


class BayesianOptimization(LoggerTunerMixin, keras_tuner.BayesianOptimization):
    """``keras_tuner.BayesianOptimization`` with mlflow logging."""


class SklearnTuner(LoggerTunerMixin, keras_tuner.SklearnTuner):
    """``keras_tuner.SklearnTuner`` with mlflow logging."""


class Hyperband(LoggerTunerMixin, keras_tuner.Hyperband):
    """``keras_tuner.Hyperband`` with mlflow logging."""
@jonnor
Copy link
Author

jonnor commented Nov 23, 2021

Import these overridden classes and use them in place of the KerasTuner ones to create the tuner:

    from mlflow_keras_tuner import RandomSearch
    tuner = RandomSearch(
        .... ordinary parameters ...
    )

@metalglove
Copy link

I have to add mlflow.tensorflow.autolog() for it to log my loss as well.

@jonnor
Copy link
Author

jonnor commented Nov 23, 2021

@metalglove ah yes, good point.
And anything you want logged that autolog does not capture should be logged with the MLflow tracking API (mlflow.log_*).
https://mlflow.org/docs/latest/tracking.html#logging-functions

@abdulbasitds
Copy link

@jonnor I am getting this error ModuleNotFoundError: No module named 'structlog', am I missing something?

@jonnor
Copy link
Author

jonnor commented Dec 4, 2021 via email

@SteefanContractor
Copy link

@metalglove ah yes, good point. And anything one want logged that is not part of the autolog, one should use the mlflow tracking API for (mlflow.log_*). https://mlflow.org/docs/latest/tracking.html#logging-functions

I'm trying to log a plot of a confusion matrix as an mlflow artifact. Where should I add that bit of code? And any suggestions for how I can retrieve the model from the active trial to run model.predict?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment