Skip to content

Instantly share code, notes, and snippets.

@angeligareta
Created June 22, 2021 23:11
Show Gist options
  • Save angeligareta/77ea24e08f46a124c4761a908b6cfdb9 to your computer and use it in GitHub Desktop.
Save angeligareta/77ea24e08f46a124c4761a908b6cfdb9 to your computer and use it in GitHub Desktop.
Custom Early Stopping callback to monitor multiple metrics by combining them using a harmonic mean calculation.
import tensorflow as tf
import numpy as np
class CustomEarlyStopping(tf.keras.callbacks.Callback):
    """Early stopping on the harmonic mean of several monitored metrics.

    The metrics named in ``metrics_names`` are read from the epoch logs,
    combined via their harmonic mean, and training stops once that combined
    value has not improved for ``patience`` consecutive epochs.

    Adapted from the TensorFlow EarlyStopping source:
    https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/keras/callbacks.py#L1683-L1823
    Author: Angel Igareta (angel@igareta.com)

    Args:
        metrics_names: Iterable of log keys to monitor (e.g. ``["val_precision",
            "val_recall"]``). Default monitors ``"loss"`` only.
        mode: ``"min"`` if a smaller combined metric is better, ``"max"`` otherwise.
        patience: Number of non-improving epochs to tolerate before stopping.
        restore_weights: If True, restore the weights from the best epoch when
            training is stopped early.
        logdir: Optional TensorBoard log directory for the combined metric.
    """

    def __init__(
        self,
        metrics_names=("loss",),  # tuple default: avoids the mutable-default-argument pitfall
        mode="min",
        patience=0,
        restore_weights=False,
        logdir=None,
    ):
        super().__init__()
        self.metrics_names = metrics_names
        self.mode = mode
        self.patience = patience
        self.restore_weights = restore_weights
        self.logdir = logdir
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # Number of epochs waited since the combined metric last improved.
        self.wait = 0
        # Epoch at which training stopped (0 means training was never stopped early).
        self.stopped_epoch = 0
        # Seed with +/- infinity so the first observed value always counts as an
        # improvement. (np.inf, not np.Inf — the latter was removed in NumPy 2.0;
        # "==" replaces the original's fragile `is` comparison on string literals.)
        self.best_combined_metric = np.inf if self.mode == "min" else -np.inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        metrics = [logs.get(name) for name in self.metrics_names]
        # Skip the epoch if any monitored metric is absent from the logs
        # (tf.cast would otherwise crash on a None entry).
        if any(value is None for value in metrics):
            return
        metrics = tf.cast(metrics, dtype=tf.float32)
        metrics_count = tf.cast(tf.size(metrics), dtype=tf.float32)
        # Harmonic mean: n / sum(1/m_i). reciprocal_no_nan maps 0 -> 0, so a
        # zero-valued metric does not produce inf/NaN.
        combined_metric = tf.math.divide(
            metrics_count, tf.math.reduce_sum(tf.math.reciprocal_no_nan(metrics))
        )
        # Optionally log the combined metric for TensorBoard inspection.
        if self.logdir:
            with tf.summary.create_file_writer(self.logdir).as_default():
                tf.summary.scalar("combined_metric", data=combined_metric, step=epoch)
        # Improvement direction depends on min/max mode.
        if self.mode == "min":
            improved = np.less(combined_metric, self.best_combined_metric)
        else:
            improved = np.greater(combined_metric, self.best_combined_metric)
        if improved:
            self.best_combined_metric = combined_metric
            self.wait = 0
            # Snapshot the weights of the best epoch so far.
            self.best_weights = self.model.get_weights()
        else:
            # BUG FIX: the original assigned `self.wait = 1`, so the counter never
            # advanced and any patience > 1 could never trigger early stopping.
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                # Restore the best weights if requested and one was ever recorded.
                if self.restore_weights and self.best_weights is not None:
                    self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
# Drop-in usage as a standard Keras callback, e.g.
# model.fit(..., callbacks=[early_stopping_callback]).
early_stopping_callback = CustomEarlyStopping(
    metrics_names=["val_precision", "val_recall"],
    mode="max",
    patience=10,
    restore_weights=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment