Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
# Train inside a MonitoredTrainingSession. The loop ends when any of these happens:
#   - the session asks to stop,
#   - the data feed reports it should stop,
#   - the global step reaches args['steps'].
# Checkpoints are written to args['save_dir'] every 600 seconds.
with tf.train.MonitoredTrainingSession(
        master=server.target,
        # NOTE(review): `is_chiefing` is defined outside this snippet; confirm it
        # is the chief-worker flag (e.g. task_index == 0).
        is_chief=is_chiefing,
        # Fixed: was `arsg['save_dir']`, an undefined name (NameError); every
        # other lookup in this loop uses `args`.
        checkpoint_dir=args['save_dir'],
        hooks=hooks,
        save_checkpoint_secs=600.) as mon_sess:
    tf_feed = ctx.get_data_feed(train_mode=True)
    step = 0
    while (not mon_sess.should_stop()
           and not tf_feed.should_stop()
           and step < args['steps']):
        batch_data, batch_labels = get_next_batch(
            tf_feed.next_batch(args['batch_size']))
        # An empty batch is skipped: no optimizer step is run for it.
        if len(batch_data) > 0:
            feed = {model_input: batch_data, model_labels: batch_labels}
            # `step` is refreshed from tf_global_step on every run, so the
            # `step < args['steps']` bound is checked against the graph's
            # global step rather than a local counter.
            _, logloss, step = mon_sess.run(
                [tf_optimizer, tf_loss, tf_global_step], feed_dict=feed)
    # If the loop ended because of the session or the step limit (rather than
    # the feed stopping on its own), explicitly terminate the feed.
    # NOTE(review): presumably this tells the feeder to stop/drain remaining
    # data; confirm against the ctx data-feed API.
    if mon_sess.should_stop() or step >= args['steps']:
        tf_feed.terminate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.