Created
September 21, 2016 15:51
-
-
Save hongthaiphi/2ef226bb86be1a862f2fe2762be30b95 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Estimator built from the project's model_fn; model_dir is the directory
# the Estimator uses for its saved model state (checkpoints).
estimator = tf.contrib.learn.Estimator(
    model_fn=model_fn,
    model_dir=MODEL_DIR,
    config=tf.contrib.learn.RunConfig())

# Training input: no num_epochs is passed, so the epoch limit is whatever
# udc_inputs.create_input_fn defaults to (presumably unlimited -- confirm).
input_fn_train = udc_inputs.create_input_fn(
    mode=tf.contrib.learn.ModeKeys.TRAIN,
    input_files=[TRAIN_FILE],
    batch_size=hparams.batch_size)

# Evaluation input: num_epochs=1, presumably a single pass over the
# validation file (depends on create_input_fn -- confirm).
input_fn_eval = udc_inputs.create_input_fn(
    mode=tf.contrib.learn.ModeKeys.EVAL,
    input_files=[VALIDATION_FILE],
    batch_size=hparams.eval_batch_size,
    num_epochs=1)

# Metric definitions passed to estimator.evaluate during evaluation.
eval_metrics = udc_metrics.create_evaluation_metrics()
# We need to subclass this manually for now. The next TF version will
# have support for ValidationMonitors with metrics built-in.
# It's already on the master branch.
class EvaluationMonitor(tf.contrib.learn.monitors.EveryN):
    """Run an evaluation pass every N training steps.

    Evaluates the attached estimator on the module-level ``input_fn_eval``
    with the module-level ``eval_metrics`` each time EveryN's
    ``every_n_steps`` interval elapses.
    """

    def every_n_step_end(self, step, outputs):
        """Evaluate the estimator after every N-th training step.

        Args:
          step: Training step that just finished (unused here).
          outputs: Outputs of that training step (unused here).

        Returns:
          False, so evaluation never requests that training stop.
        """
        # steps=None: no explicit step cap; evaluation runs until
        # input_fn_eval stops producing data (it was built with num_epochs=1).
        self._estimator.evaluate(
            input_fn=input_fn_eval, metrics=eval_metrics, steps=None)
        # Per the EveryN contract, a truthy return asks training to stop;
        # return False explicitly instead of relying on an implicit None.
        return False
# Run an evaluation pass every FLAGS.eval_every training steps.
eval_monitor = EvaluationMonitor(every_n_steps=FLAGS.eval_every)

# steps=None places no step limit on training: it continues until stopped
# externally or until input_fn_train stops producing data.
estimator.fit(input_fn=input_fn_train, steps=None, monitors=[eval_monitor])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment