Skip to content

Instantly share code, notes, and snippets.

@mercy0387
Created December 6, 2017 09:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mercy0387/e6ffbe7d7420ea32a69bff17da1f268a to your computer and use it in GitHub Desktop.
Save mercy0387/e6ffbe7d7420ea32a69bff17da1f268a to your computer and use it in GitHub Desktop.
Custom Keras checkpoint callback that saves the model only when chosen metrics meet given thresholds
import warnings

import numpy as np
from keras.callbacks import Callback
class CustomModelCheckpoint(Callback):
    """Checkpoint that saves the model only when every thresholded metric passes.

    Args:
        filepath: Destination path for the saved model. It is expanded with
            ``str.format`` using ``epoch`` (the index passed to
            ``on_epoch_end``, used as-is) plus every key in that epoch's
            ``logs``, e.g. ``'model-{epoch}-{val_acc:.4f}.h5'``.
        thresholds: Dictionary whose keys are metric names as they appear in
            ``logs`` (validation metrics carry a ``val_`` prefix) and whose
            values are the thresholds those metrics must meet.
        inverse: If False (default), a metric passes when its value is
            ``>= threshold`` (higher is better). If True, a metric passes when
            its value is ``<= threshold`` (lower is better, e.g. a loss).
    """

    def __init__(self, filepath, thresholds, inverse=False):
        super(CustomModelCheckpoint, self).__init__()
        self.filepath = filepath
        self.thresholds = thresholds
        self.inverse = inverse

    def _fails_threshold(self, value, threshold):
        """Return True if ``value`` is NaN or on the wrong side of ``threshold``."""
        # NaN compares False against everything, so it must be rejected explicitly.
        if np.isnan(value):
            return True
        if self.inverse:
            return value > threshold
        return value < threshold

    def on_epoch_end(self, epoch, logs=None):
        """Save the model if all thresholded metrics pass for this epoch.

        ``logs`` holds the metric values recorded for the finished epoch
        (validation metrics prefixed with ``val_``). Saving is skipped if any
        thresholded metric is missing from ``logs``, is NaN, or fails its
        threshold.
        """
        logs = logs or {}
        for name, threshold in self.thresholds.items():
            if name not in logs:
                # The original code evaluated logs[name] even for missing keys
                # and crashed with KeyError; warn and skip saving instead.
                warnings.warn(
                    'Can save model only with %s available, skipping.' % name,
                    RuntimeWarning)
                return
            if self._fails_threshold(logs[name], threshold):
                return
        # Only build the path once we know the model will actually be saved.
        filepath = self.filepath.format(epoch=epoch, **logs)
        self.model.save(filepath, overwrite=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment