Skip to content

Instantly share code, notes, and snippets.

@raven4752
Created July 11, 2018 09:56
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save raven4752/3669ac1cf4aa7f9faf63d3328cd507f7 to your computer and use it in GitHub Desktop.
Save raven4752/3669ac1cf4aa7f9faf63d3328cd507f7 to your computer and use it in GitHub Desktop.
Keras callbacks that save the best model and apply early stopping for multi-input/multi-output models, using custom score functions.
import gc

import numpy as np
import pandas as pd
from keras.callbacks import Callback
from keras.models import load_model
class ScoreMetric(Callback):
    """Keras callback that scores validation predictions after every epoch.

    At the end of each epoch the model predicts on the inputs held in
    ``self.validation_data``; the matching targets and the predictions are
    handed to ``score_func`` and the result is appended to
    ``self.custom_val_scores``.

    Args:
        score_func: Callable invoked as ``score_func(targets, predictions)``.
            Its return value is stored and printed as-is.
        num_input: Number of leading entries of ``validation_data`` that are
            model inputs.
        num_target: Stored on the instance for callers/subclasses. The number
            of targets actually read is derived from the prediction itself
            (see ``on_epoch_end``).
    """

    def __init__(self, score_func, num_input=1, num_target=1):
        super(ScoreMetric, self).__init__()
        self.num_input = num_input
        self.num_target = num_target
        self.score_func = score_func

    def on_train_begin(self, logs=None):
        """Start a fresh score history for this training run."""
        self.custom_val_scores = []

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the validation inputs, score them and record the score.

        Args:
            epoch: Index of the epoch that just finished (unused here).
            logs: Metric dictionary supplied by the training loop (unused here).
        """
        # NOTE(review): `self.validation_data` is never assigned in this
        # class; it is expected to be populated by the training loop.
        # Confirm that the Keras version in use still sets it on callbacks.
        if self.num_input == 1:
            val_inputs = self.validation_data[0]
        else:
            val_inputs = self.validation_data[0:self.num_input]
        val_predict = self.model.predict(val_inputs)

        first_target = self.num_input
        if isinstance(val_predict, (list, tuple)):
            # Multi-output model: one prediction array per output, so take
            # the same number of target arrays.
            last_target = first_target + len(val_predict)
            val_targ = self.validation_data[first_target:last_target]
        else:
            # Single-output model: the prediction is one array whose length
            # is the number of samples, not the number of outputs. Using that
            # length to slice `validation_data` would drag sample weights
            # (and anything stored after them) in as "targets", so pass the
            # single target array instead.
            val_targ = self.validation_data[first_target]

        _val_score = self.score_func(val_targ, val_predict)
        self.custom_val_scores.append(_val_score)
        print('— val_score ' + str(_val_score))
class SaveBestModelCallBack(ScoreMetric):
    """Save the best-scoring model to disk and stop early on stagnation.

    After every epoch the custom validation score computed by
    ``ScoreMetric`` is compared with the best score seen so far using
    ``self.monitor_op`` (``np.greater``, i.e. higher is better). An improved
    score triggers ``self.model.save(model_path)``; ``patience`` consecutive
    epochs without improvement set ``self.model.stop_training``.

    Args:
        score_func: Callable invoked as ``score_func(targets, predictions)``.
        model_path: File path the best model is saved to and reloaded from.
        num_input: Number of model inputs, forwarded to ``ScoreMetric``.
        num_target: Number of model targets, forwarded to ``ScoreMetric``.
        patience: Number of consecutive non-improving epochs tolerated before
            training is stopped.
        verbose: When greater than 0, print a message if training stopped
            early.
    """

    def __init__(self, score_func, model_path, num_input=1, num_target=1, patience=5, verbose=1):
        super(SaveBestModelCallBack, self).__init__(score_func, num_input, num_target)
        self.model_path = model_path
        self.monitor_op = np.greater
        self.best_score = self._initial_best_score()
        self.patience = patience
        self.verbose = verbose
        self.wait = 0
        self.stopped_epoch = 0
        # True once a checkpoint has been written during the current run.
        self._checkpoint_saved = False

    def _initial_best_score(self):
        """Return the worst possible score for the configured comparison."""
        # `np.inf` rather than the `np.Inf` alias, which NumPy 2.0 removed.
        return np.inf if self.monitor_op is np.less else -np.inf

    def on_train_begin(self, logs=None):
        """Reset all per-run state so the instance can be re-used."""
        super(SaveBestModelCallBack, self).on_train_begin(logs)
        self.wait = 0
        self.stopped_epoch = 0
        self.best_score = self._initial_best_score()
        self._checkpoint_saved = False

    def on_epoch_end(self, epoch, logs=None):
        """Score the epoch, checkpoint on improvement, stop when out of patience.

        Args:
            epoch: Index of the epoch that just finished.
            logs: Metric dictionary supplied by the training loop.
        """
        super(SaveBestModelCallBack, self).on_epoch_end(epoch, logs)
        current = self.custom_val_scores[-1]
        # One comparison drives both the patience counter and the checkpoint,
        # so the two can never disagree about what "improved" means.
        if self.monitor_op(current, self.best_score):
            self.wait = 0
            self.best_score = current
            self.model.save(self.model_path)
            self._checkpoint_saved = True
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        """Reload the best checkpoint and report early stopping if it happened.

        Args:
            logs: Metric dictionary supplied by the training loop.
        """
        super(SaveBestModelCallBack, self).on_train_end(logs)
        # Only reload when this run actually wrote a checkpoint; otherwise
        # the file may be missing or left over from an earlier run.
        if self._checkpoint_saved:
            # This rebinds the callback's own `model` attribute; references
            # to the model held elsewhere are not changed by the assignment.
            self.model = load_model(self.model_path)
        gc.collect()
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
class BestModelCallBack(ScoreMetric):
    """Keep the best-scoring weights in memory and stop early on stagnation.

    After every epoch the custom validation score computed by
    ``ScoreMetric`` is compared with the best score seen so far using
    ``self.monitor_op`` (``np.greater``, i.e. higher is better). An improved
    score stores ``self.model.get_weights()``; ``patience`` consecutive
    epochs without improvement set ``self.model.stop_training``. When
    training ends the stored weights are written back into the model.

    Args:
        score_func: Callable invoked as ``score_func(targets, predictions)``.
        num_input: Number of model inputs, forwarded to ``ScoreMetric``.
        num_target: Number of model targets, forwarded to ``ScoreMetric``.
        patience: Number of consecutive non-improving epochs tolerated before
            training is stopped.
        verbose: When greater than 0, print a message if training stopped
            early.
    """

    def __init__(self, score_func, num_input=1, num_target=1, patience=5, verbose=1):
        super(BestModelCallBack, self).__init__(score_func, num_input, num_target)
        self.monitor_op = np.greater
        self.best_score = self._initial_best_score()
        self.patience = patience
        self.verbose = verbose
        self.wait = 0
        self.stopped_epoch = 0
        self.best_weights = None

    def _initial_best_score(self):
        """Return the worst possible score for the configured comparison."""
        # `np.inf` rather than the `np.Inf` alias, which NumPy 2.0 removed.
        return np.inf if self.monitor_op is np.less else -np.inf

    def on_train_begin(self, logs=None):
        """Reset all per-run state so the instance can be re-used."""
        super(BestModelCallBack, self).on_train_begin(logs)
        self.wait = 0
        self.stopped_epoch = 0
        self.best_score = self._initial_best_score()
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        """Score the epoch, remember weights on improvement, stop when out of patience.

        Args:
            epoch: Index of the epoch that just finished.
            logs: Metric dictionary supplied by the training loop.
        """
        super(BestModelCallBack, self).on_epoch_end(epoch, logs)
        current = self.custom_val_scores[-1]
        # One comparison drives both the patience counter and the snapshot,
        # so the two can never disagree about what "improved" means.
        if self.monitor_op(current, self.best_score):
            self.wait = 0
            self.best_score = current
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        """Restore the best weights and report early stopping if it happened.

        Args:
            logs: Metric dictionary supplied by the training loop.
        """
        super(BestModelCallBack, self).on_train_end(logs)
        # No epoch may have improved on the initial score (for example when
        # every score is NaN); in that case there is nothing to restore and
        # the model keeps its current weights.
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
        # Drop the snapshot so the weight copy can be reclaimed.
        self.best_weights = None
        gc.collect()
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment