Skip to content

Instantly share code, notes, and snippets.

@himanshurawlani
Last active September 10, 2020 18:45
Show Gist options
  • Save himanshurawlani/b42addff1ccf697066d64d3ca3c6865d to your computer and use it in GitHub Desktop.
An example Keras callback that reports metrics to Ray Tune after every epoch.
class TuneReporter(tf.keras.callbacks.Callback):
    """Keras callback that reports metrics to Ray Tune after every epoch.

    Each report passes:
      * ``keras_info``: the raw Keras ``logs`` dict for the epoch,
      * ``mean_accuracy``: ``logs["acc"]`` if present, otherwise
        ``logs.get("accuracy")`` (``None`` if neither is logged),
      * ``val_loss``: ``logs["val_loss"]``, only when the logs contain it.
    """

    def __init__(self, reporter=None, freq="epoch", logs=None):
        """Initializer.

        Args:
            reporter: Unused; accepted so existing call sites keep working.
            freq (str): Frequency of reporting intermediate results. Only
                ``"epoch"`` is supported.
            logs: Unused; accepted so existing call sites keep working.

        Raises:
            ValueError: If ``freq`` is anything other than ``"epoch"``.
        """
        super(TuneReporter, self).__init__()
        # Only on_epoch_end is implemented, so any other value would silently
        # disable all reporting; fail loudly instead.
        if freq != "epoch":
            raise ValueError('freq must be "epoch", got {!r}'.format(freq))
        # Number of results reported to Tune so far.
        self.iteration = 0
        self.freq = freq

    def on_epoch_end(self, epoch, logs=None):
        """Report the finished epoch's metrics to Ray Tune.

        Args:
            epoch: Passed by Keras; unused here.
            logs (dict): Metric values Keras logged for the epoch; ``None``
                is treated as an empty dict.
        """
        # Imported lazily so defining/constructing the callback does not
        # require Ray to be importable.
        from ray import tune

        logs = logs or {}
        # Guard kept because ``freq`` is a public attribute that may be
        # reassigned after construction.
        if self.freq != "epoch":
            return
        self.iteration += 1
        metrics = {
            "keras_info": logs,
            # Older Keras versions log accuracy as "acc", newer ones as
            # "accuracy".
            "mean_accuracy": logs["acc"] if "acc" in logs else logs.get("accuracy"),
        }
        # "val_loss" is only logged when fit() receives validation data;
        # indexing it unconditionally raised KeyError without it.
        if "val_loss" in logs:
            metrics["val_loss"] = logs["val_loss"]
        tune.report(**metrics)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment