Skip to content

Instantly share code, notes, and snippets.

@takuseno
Last active February 18, 2020 04:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save takuseno/e3cb1cfa5a7aa0d4f188dc7cb3eb29af to your computer and use it in GitHub Desktop.
Save takuseno/e3cb1cfa5a7aa0d4f188dc7cb3eb29af to your computer and use it in GitHub Desktop.
MLflow autologging script for nnabla monitors
import numpy as np
import mlflow
import gorilla
import time
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
from mlflow.utils.autologging_utils import try_mlflow_log
def _check_interval(index, flush_at, interval):
return (index - flush_at) >= interval
def autolog():
    """Patch nnabla's ``MonitorSeries.add`` and ``MonitorTimeElapsed.add`` to log to MLflow.

    After this call, each patched ``add`` checks the monitor's
    ``flush_at``/``interval`` attributes with ``_check_interval``. When the
    check passes, it logs a metric named ``self.name`` through
    ``try_mlflow_log(mlflow.log_metric, ...)`` with ``step=index``. In every
    case it then calls the original ``add`` method (fetched with
    ``gorilla.get_original_attribute``) with the arguments it received.
    """

    def add_series(self, index, value):
        # Reads the same `flush_at`/`interval` attributes the monitor keeps;
        # NOTE(review): presumably matches MonitorSeries' own flush test —
        # confirm against nnabla.monitor.
        if _check_interval(index, self.flush_at, self.interval):
            # Average of the buffered values plus the incoming sample.
            # Kept in its own name so `value` reaches the original `add`
            # unchanged; the previous version rebound `value` to this mean
            # and so fed an already-averaged number back into the monitor.
            mean_value = np.mean(self.buf + [value])
            try_mlflow_log(mlflow.log_metric, self.name, mean_value,
                           step=index)
        original = gorilla.get_original_attribute(MonitorSeries, 'add')
        original(self, index, value)

    def add_time_elapsed(self, index):
        if _check_interval(index, self.flush_at, self.interval):
            # Time since `self.lap`; assumes `self.lap` is a time.time()
            # timestamp (seconds) — TODO confirm against nnabla.monitor.
            elapsed = time.time() - self.lap
            try_mlflow_log(mlflow.log_metric, self.name, elapsed, step=index)
        original = gorilla.get_original_attribute(MonitorTimeElapsed, 'add')
        original(self, index)

    # allow_hit: the destination already defines `add`, so overwriting must
    # be permitted. store_hit: keep the replaced method so the wrappers
    # above can fetch it with gorilla.get_original_attribute.
    settings = gorilla.Settings(allow_hit=True, store_hit=True)
    patches = [
        gorilla.Patch(MonitorSeries, 'add', add_series, settings=settings),
        gorilla.Patch(MonitorTimeElapsed, 'add', add_time_elapsed,
                      settings=settings),
    ]
    for patch in patches:
        gorilla.apply(patch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment