@manuzhang
Forked from yaroslavvb/simple_barrier.py
Last active March 27, 2020 07:12
TensorFlow in-graph replication example
"""
This example is adapted from https://gist.github.com/yaroslavvb/ef407a599f0f549f62d91c3a00dcfb6c
Example of a barrier implementation using TensorFlow shared variables.
All workers synchronize on the barrier, copy the global parameter to their
local copies, and increment the global parameter variable asynchronously.
You should see something like this:
python simple_barrier.py --wk "node13-1:21393,node13-1:21395"
Creating session
worker 0, local_param 0 global_param 1
worker 1, local_param 0 global_param 2
worker 0, local_param 2 global_param 3
worker 1, local_param 2 global_param 4
worker 0, local_param 4 global_param 5
worker 1, local_param 4 global_param 6
worker 0, local_param 6 global_param 7
worker 1, local_param 6 global_param 8
worker 0, local_param 8 global_param 9
worker 1, local_param 8 global_param 10
worker 0, local_param 10 global_param 11
worker 1, local_param 10 global_param 12
worker 0, local_param 12 global_param 13
worker 1, local_param 12 global_param 14
worker 0, local_param 14 global_param 15
worker 1, local_param 14 global_param 16
worker 0, local_param 16 global_param 17
worker 1, local_param 16 global_param 18
worker 0, local_param 18 global_param 19
worker 1, local_param 18 global_param 20
"""
import numpy as np
import subprocess
import sys
import tensorflow as tf
import threading
import time
tf.app.flags.DEFINE_integer("iters", 10, "Maximum number of steps")
tf.app.flags.DEFINE_string("wk", "", "worker hosts")
tf.app.flags.DEFINE_float("sleep_interval", 0.1, "how long to sleep in wait loop")
FLAGS = tf.app.flags.FLAGS
worker_hosts = FLAGS.wk.split(',')
num_workers = len(worker_hosts)
# global ops
init_op = None
train_ops = [] # worker local train ops, read local params, update global
counter_vars = [] # counters for barrier
counter_adder_ops = []
global_param_var = None
local_param_vars = []
local_param_sync_ops = []


def default_config():
  optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)
  config = tf.ConfigProto(
      graph_options=tf.GraphOptions(optimizer_options=optimizer_options))
  config.log_device_placement = False
  config.allow_soft_placement = False
  return config


def create_graph(devices):
  """Create graph that keeps global params + counters on devices[0] and
  local params/train ops on devices[:]"""
  global train_ops, counter_vars, counter_adder_ops, global_param_var, local_param_vars, local_param_sync_ops

  dtype = tf.int32
  with tf.device(devices[0]):
    global_param_var = tf.get_variable("param", shape=(), dtype=dtype,
                                       initializer=tf.zeros_initializer)
    for i in range(2):
      counter_var = tf.get_variable("counter-"+str(i), (), tf.int32,
                                    initializer=tf.zeros_initializer)
      counter_vars.append(counter_var)
      counter_adder_ops.append(counter_var.assign_add(1, use_locking=True))

  # create local version of parameters
  for (i, device) in enumerate(devices):
    with tf.device(device):
      local_param_var = tf.get_variable("local_param-"+str(i), (), dtype,
                                        initializer=tf.zeros_initializer)
      local_param_vars.append(local_param_var)
      local_param_sync_op = local_param_var.assign(global_param_var)
      local_param_sync_ops.append(local_param_sync_op)
      train_op = global_param_var.assign_add(1)
      train_ops.append(train_op)

  init_op = tf.global_variables_initializer()
  return (init_op, train_ops)


def create_worker_threads(sess):
  """Creates a thread per worker, running that worker's ops iters times."""

  def barrier():
    # Double-counter barrier: each worker increments counter 0 and spins until
    # every worker has arrived, then increments counter 1 and spins again so
    # that no worker can loop around and re-enter the barrier before all
    # workers have left it.
    sess.run(counter_adder_ops[0])
    while sess.run(counter_vars[0]) % num_workers != 0:
      time.sleep(FLAGS.sleep_interval)
    sess.run(counter_adder_ops[1])
    while sess.run(counter_vars[1]) % num_workers != 0:
      time.sleep(FLAGS.sleep_interval)

  def create_run_method(worker_id):
    def _run():
      local_param_var = local_param_vars[worker_id]
      sync_op = local_param_sync_ops[worker_id]
      train_op = train_ops[worker_id]
      for i in range(FLAGS.iters):
        barrier()
        sess.run(sync_op)
        barrier()
        old_val, updated_val = sess.run([local_param_var, train_op])
        print("worker %2d, local_param %2d global_param %2d" % (
            worker_id, old_val, updated_val))
    return _run

  return [threading.Thread(target=create_run_method(i))
          for i in range(num_workers)]


def wait_for_threads_to_finish(threads):
  while any(t.is_alive() for t in threads):
    time.sleep(FLAGS.sleep_interval)


def run_client():
  tasks = ["/job:worker/task:%d" % (i) for i in range(num_workers)]
  (init_op, add_ops) = create_graph(tasks)

  # Need tf.Session.reset if there are worker servers launched from before.
  # However, tf.Session.reset can hang if workers are in the process of being
  # brought up, hence it is more robust to do "killall python".
  #   tf.Session.reset("grpc://" + worker_hosts[0])
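  # A hedged alternative (an assumption, not in the original gist): when the
  # workers are already known to be fully up, their resource containers can be
  # reset explicitly before creating the session below:
  #   tf.Session.reset("grpc://" + worker_hosts[0], config=default_config())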
print("Creating session")
sess = tf.Session("grpc://" + worker_hosts[0],
config=default_config())
sess.run(init_op)
worker_threads = create_worker_threads(sess)
[t.start() for t in worker_threads]
wait_for_threads_to_finish(worker_threads)
if __name__=='__main__':
run_client()
777ki commented Oct 26, 2018

This asynchronous in-graph mode feels pretty unnerving; great example.
