Skip to content

Instantly share code, notes, and snippets.

@hongthaiphi
Created September 21, 2016 15:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hongthaiphi/2ef226bb86be1a862f2fe2762be30b95 to your computer and use it in GitHub Desktop.
Save hongthaiphi/2ef226bb86be1a862f2fe2762be30b95 to your computer and use it in GitHub Desktop.
# Estimator built from model_fn; model_dir=MODEL_DIR is where it keeps its
# checkpoints/summaries (per the tf.contrib.learn Estimator API).
estimator = tf.contrib.learn.Estimator(
    model_fn=model_fn,
    model_dir=MODEL_DIR,
    config=tf.contrib.learn.RunConfig(),
)

# Training input pipeline over TRAIN_FILE, batched by hparams.batch_size.
input_fn_train = udc_inputs.create_input_fn(
    mode=tf.contrib.learn.ModeKeys.TRAIN,
    input_files=[TRAIN_FILE],
    batch_size=hparams.batch_size,
)

# Evaluation input pipeline over VALIDATION_FILE. num_epochs=1 presumably
# limits it to a single pass over the validation data (confirm in udc_inputs).
input_fn_eval = udc_inputs.create_input_fn(
    mode=tf.contrib.learn.ModeKeys.EVAL,
    input_files=[VALIDATION_FILE],
    batch_size=hparams.eval_batch_size,
    num_epochs=1,
)

# Metrics handed to estimator.evaluate() during periodic evaluation.
eval_metrics = udc_metrics.create_evaluation_metrics()
# We need to subclass this manually for now. The next TF version will
# support ValidationMonitor with built-in metrics; that support is already
# on the master branch.
class EvaluationMonitor(tf.contrib.learn.monitors.EveryN):
    """Training monitor that runs a full evaluation every N training steps.

    Each time the ``EveryN`` base class fires, this calls ``evaluate`` on the
    estimator the monitor is attached to, using the module-level
    ``input_fn_eval`` and ``eval_metrics``.
    """

    def every_n_step_end(self, step, outputs):
        """Evaluate the model after every N-th training step.

        Args:
            step: The training step that just finished (unused here).
            outputs: Outputs of that training step (unused here).

        Returns:
            False, so this monitor never requests that training stop.
        """
        # self._estimator is presumably set by the Monitor base class when
        # fit() attaches this monitor. steps=None means no step cap: the
        # evaluation runs until input_fn_eval runs out of data.
        self._estimator.evaluate(
            input_fn=input_fn_eval,
            metrics=eval_metrics,
            steps=None,
        )
        # EveryN treats a truthy return value as "stop training"; return
        # False explicitly rather than relying on an implicit None.
        return False
# Periodic evaluator: fires once every FLAGS.eval_every training steps.
eval_monitor = EvaluationMonitor(
    every_n_steps=FLAGS.eval_every,
)

# Train with the evaluator attached. steps=None sets no explicit step cap,
# so training length is governed by the training input pipeline.
estimator.fit(
    input_fn=input_fn_train,
    steps=None,
    monitors=[eval_monitor],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment