import tensorflow as tf

def map_fun(args, ctx):
    worker_num = ctx.worker_num
    job_name = ctx.job_name
    task_index = ctx.task_index

    # Start the distributed TensorFlow server for this executor (1 GPU).
    cluster, server = ctx.start_cluster_server(1)

    if job_name == "ps":
        # Parameter servers block here and serve variables to the workers.
        server.join()
    elif job_name == "worker":
        is_chiefing = (task_index == 0)

        # Pin variables to the parameter servers and ops to this worker.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % task_index,
                cluster=cluster)):
            def build_model():
                pass

        hooks = [...]
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=is_chiefing,
                                               checkpoint_dir=args['save_dir'],
                                               hooks=hooks,
                                               save_checkpoint_secs=600) as mon_sess:
            # Feed batches pushed in from the Spark RDD.
            tf_feed = ctx.get_data_feed(train_mode=True)
            while not mon_sess.should_stop() and not tf_feed.should_stop():
                batch_data = tf_feed.next_batch(args['batch_size'])
                # Apply whatever training ops you need here.
                _ = mon_sess.run(...)
            if mon_sess.should_stop():
                tf_feed.terminate()
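For context, a map_fun like this runs on the executors; the driver launches it through TensorFlowOnSpark's TFCluster API. A minimal driver-side sketch follows, assuming Spark input mode (since the worker pulls data via ctx.get_data_feed); sc, args, dataRDD, num_executors, and num_ps are illustrative placeholders, not part of the gist.

# Minimal launch sketch, assuming TensorFlowOnSpark in Spark input mode;
# sc, args, dataRDD, num_executors, and num_ps are illustrative placeholders.
from tensorflowonspark import TFCluster

cluster = TFCluster.run(sc, map_fun, args, num_executors, num_ps,
                        tensorboard=False, input_mode=TFCluster.InputMode.SPARK)
cluster.train(dataRDD, num_epochs=1)  # pushes RDD partitions to ctx.get_data_feed
cluster.shutdown()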