Skip to content

Instantly share code, notes, and snippets.

@reuben
Last active February 26, 2019 14:28
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save reuben/dcc2deaf85568591e34ce363bc3bac2a to your computer and use it in GitHub Desktop.
Save reuben/dcc2deaf85568591e34ce363bc3bac2a to your computer and use it in GitHub Desktop.
diff --git a/DeepSpeech.py b/DeepSpeech.py
index 006e1a6..a80200a 100755
--- a/DeepSpeech.py
+++ b/DeepSpeech.py
@@ -484,6 +484,8 @@ def train(server=None):
saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir, save_secs=FLAGS.checkpoint_secs, saver=saver))
+ Config.best_validation_saver = tf.train.Saver(max_to_keep=1)
+
no_dropout_feed_dict = {
dropout_rates[0]: 0.,
dropout_rates[1]: 0.,
@@ -543,6 +545,8 @@ def train(server=None):
config=Config.session_config) as session:
tf.get_default_graph().finalize()
+ Config.train_session = session
+
try:
if Config.is_chief:
# Retrieving global_step from the (potentially restored) model
diff --git a/util/config.py b/util/config.py
index ded9dd6..b53b95b 100644
--- a/util/config.py
+++ b/util/config.py
@@ -140,4 +140,9 @@ def initialize_globals():
# Determine, if we are the chief worker
c.is_chief = len(FLAGS.worker_hosts) == 0 or (FLAGS.task_index == 0 and FLAGS.job_name == 'worker')
+ c.best_validation_save_path = os.path.join(FLAGS.checkpoint_dir, 'best_validation')
+ c.best_validation_loss = float('inf')
+ c.best_validation_saver = None
+ c.train_session = None
+
ConfigSingleton._config = c
diff --git a/util/coordinator.py b/util/coordinator.py
index e26eb37..76bba6d 100644
--- a/util/coordinator.py
+++ b/util/coordinator.py
@@ -69,6 +69,16 @@ def new_id():
id_counter += 1
return id_counter
+
+def get_session(sess):
+    '''Follow each wrapper's private ``_sess`` attribute until an object whose class is
+    named ``Session`` is reached and return it; return None if the chain ends first.'''
+    session = sess
+    while session is not None and type(session).__name__ != 'Session':
+        session = getattr(session, '_sess', None)  # a wrapper without _sess ends the walk
+    return session
+
+
class WorkerJob(object):
'''Represents a job that should be executed by a worker.
@@ -183,6 +193,11 @@ class Epoch(object):
self.loss = agg_loss / num_jobs
+ if self.set_name == 'dev' and self.loss < Config.best_validation_loss:
+ Config.best_validation_loss = self.loss
+ path = Config.best_validation_saver.save(sess=get_session(Config.train_session), save_path=Config.best_validation_save_path, write_state=False)
+ log_info('Saving model with best validation loss ({}) at {}...'.format(Config.best_validation_loss, path))
+
# if the job was for validation dataset then append it to the COORD's _loss for early stop verification
if (FLAGS.early_stop is True) and (self.set_name == 'dev'):
self.coord._dev_losses.append(self.loss)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment