# Distributed training loop: open a monitored session against the cluster
# `server`, pull batches from the data feed, and run one optimizer step per
# non-empty batch until a stop condition is reached.
with tf.train.MonitoredTrainingSession(
    master=server.target,
    is_chief=is_chiefing,
    # Fixed typo: the original read `arsg['save_dir']`, but the configuration
    # mapping used everywhere else in this block is `args`.
    checkpoint_dir=args['save_dir'],
    hooks=hooks,
    save_checkpoint_secs=600.,
) as mon_sess:
    tf_feed = ctx.get_data_feed(train_mode=True)
    step = 0
    # Keep going until the session asks to stop, the feed asks to stop, or the
    # global step reaches the configured limit.
    while (
        not mon_sess.should_stop()
        and not tf_feed.should_stop()
        and step < args['steps']
    ):
        batch_data, batch_labels = get_next_batch(
            tf_feed.next_batch(args['batch_size'])
        )
        # Skip empty batches. An explicit length check is kept (instead of
        # truthiness) because the batch may be an array-like whose truth value
        # is ambiguous.
        if len(batch_data) > 0:
            feed = {model_input: batch_data, model_labels: batch_labels}
            # `step` is refreshed from the global-step tensor so the loop
            # condition follows the fetched global step, not a local counter.
            _, logloss, step = mon_sess.run(
                [tf_optimizer, tf_loss, tf_global_step],
                feed_dict=feed,
            )
    # NOTE(review): the original indentation was lost; this check is assumed
    # to run once after the loop, still inside the session — confirm.
    # Training finished on the session/step side, so tell the feed to stop.
    if mon_sess.should_stop() or step >= args['steps']:
        tf_feed.terminate()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment