Example of using shared counters to implement a barrier primitive
'''
Alex Walczak, 2017

Example of a barrier implementation using TensorFlow shared variables
across a multi-machine cluster.

All workers synchronize on the barrier, copy global parameters to local versions,
and increment the global parameter variable asynchronously.

Before launching, kill any stray Python processes on each worker machine:

$ killall python3

If you have a cluster of 4 machines, then on the first machine run:

$ python3 cluster_barrier_for_tensorflow.py --job_name=ps --task_index=0 --ps_hosts=... --worker_hosts=... &
(with the trailing &)
$ python3 cluster_barrier_for_tensorflow.py --job_name=worker --task_index=0 --ps_hosts=... --worker_hosts=...

And on the other 3 machines, run one of:

$ python3 cluster_barrier_for_tensorflow.py --job_name=worker --task_index=1 --ps_hosts=... --worker_hosts=...
$ python3 cluster_barrier_for_tensorflow.py --job_name=worker --task_index=2 --ps_hosts=... --worker_hosts=...
$ python3 cluster_barrier_for_tensorflow.py --job_name=worker --task_index=3 --ps_hosts=... --worker_hosts=...

You should see something like this for each worker k=0,1,2,3:

Worker k: local_param 0, global_param 1
Worker k: local_param 4, global_param 5
Worker k: local_param 8, global_param 9
Worker k: local_param 12, global_param 13
Worker k: local_param 16, global_param 17

(Tested with TensorFlow r1.5)

Thanks to Yaroslav Bulatov (yaroslavvb) for the original implementation,
which spawns multiple processes on a single machine:
https://gist.github.com/yaroslavvb/ef407a599f0f549f62d91c3a00dcfb6c
'''
import numpy as np
import os
import time
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # optionally suppress TF warnings in log files
tf.app.flags.DEFINE_string('ps_hosts', '', 'comma separated list of ps_host_ip:port')
tf.app.flags.DEFINE_string('worker_hosts', '', 'comma separated list of worker_host_ip:port')
tf.app.flags.DEFINE_string('job_name', '', 'the job: ps | worker')
tf.app.flags.DEFINE_integer('task_index', 0, 'which task number for the ps or worker')
tf.app.flags.DEFINE_float('sleep_interval', 0.1, 'how long to sleep in wait loop')
tf.app.flags.DEFINE_integer('iters', 10, 'maximum number of steps to run per worker')
FLAGS = tf.app.flags.FLAGS
def default_config():
  optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)
  config = tf.ConfigProto(
      graph_options=tf.GraphOptions(optimizer_options=optimizer_options))
  config.log_device_placement = False
  config.allow_soft_placement = False
  return config
def run_test():
  # TF DIST SETUP
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index,
                           config=default_config())

  if FLAGS.job_name == 'ps':
    server.join()
  elif FLAGS.job_name == 'worker':
    dtype = tf.int32
    num_workers = len(worker_hosts)

    # vars and ops
    init_op = None
    train_ops = []  # worker local train ops, read local params, update global
    counter_vars = []  # counters for barrier
    counter_adder_ops = []
    global_param_var = None
    local_param_vars = []
    local_param_sync_ops = []

    # all ps and worker tasks
    ps_device = '/job:ps/task:0/cpu:0'
    ps_host = ps_hosts[0]  # assume we are only launching a single ps host and many workers
    worker_devices = ['/job:worker/task:{}'.format(i) for i in range(num_workers)]

    # create global parameters
    with tf.device(ps_device):
      global_param_var = tf.get_variable('param', shape=(), dtype=dtype,
                                         initializer=tf.zeros_initializer)
      for i in range(2):
        counter_var = tf.get_variable('counter-{}'.format(i), (), dtype,
                                      initializer=tf.zeros_initializer)
        counter_vars.append(counter_var)
        counter_adder_ops.append(counter_var.assign_add(1, use_locking=True))

    # create local version of parameters
    for (i, device) in enumerate(worker_devices):
      with tf.device(device):
        local_param_var = tf.get_variable('local_param-{}'.format(i), (), dtype,
                                          initializer=tf.zeros_initializer)
        local_param_vars.append(local_param_var)
        local_param_sync_op = local_param_var.assign(global_param_var)
        local_param_sync_ops.append(local_param_sync_op)
        train_op = global_param_var.assign_add(1)
        train_ops.append(train_op)

    init_op = tf.global_variables_initializer()
    with tf.Session('grpc://{}'.format(ps_host), config=default_config()) as sess:  # every worker connects to the same master (the ps host)
      def barrier():
        # When this function returns, every worker will execute the following line next.
        # Two counters (two phases) are needed: with a single counter, a fast worker could leave
        # the barrier, re-enter it, and increment the counter again before a slow worker had seen
        # it reach a multiple of num_workers. The second phase holds everyone back until all
        # workers have observed that the first phase finished, so the first counter cannot be
        # touched again too early.
        for i in range(2):
          sess.run(counter_adder_ops[i])  # increment shared counter i once on this worker
          while sess.run(counter_vars[i]) % num_workers != 0:  # wait until every worker has incremented counter i
            time.sleep(FLAGS.sleep_interval)  # poll politely rather than hammering the parameter server
      sess.run(init_op)

      worker_id = FLAGS.task_index
      local_param_var = local_param_vars[worker_id]
      sync_op = local_param_sync_ops[worker_id]
      train_op = train_ops[worker_id]

      for i in range(FLAGS.iters):
        # 1. Wait for all workers to finish incrementing global_param_var,
        #    then copy the global parameter into this worker's local copy.
        barrier()
        sess.run(sync_op)
        # 2. Wait for all workers to finish assigning global_param_var to local_param_var,
        #    then increment global_param_var.
        barrier()
        local_val, global_val = sess.run([local_param_var, train_op])  # increment global_param_var
        print('Worker {}: local_param {}, global_param {}'.format(worker_id, local_val, global_val))

      barrier()
      sess.run(sync_op)
      local_val, global_val = sess.run([local_param_var, global_param_var])
      print('+++ Final value for worker {}: local_param {}, global_param {}'.format(worker_id, local_val, global_val))


if __name__ == '__main__':
  run_test()
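For readers puzzling over the two-counter design (see the question below), here is a minimal, hypothetical sketch of the same barrier idea using plain Python threads instead of a TensorFlow cluster; ThreadBarrier, worker, and the counts used are illustrative names and values, not part of the gist. Each wait() call runs the same two phases as barrier() above: bump a shared counter, then poll until it reaches a multiple of the worker count.

import threading
import time

class ThreadBarrier:
  '''Two-counter barrier, mirroring the two-phase loop in barrier() above.'''

  def __init__(self, num_workers, sleep_interval=0.01):
    self.num_workers = num_workers
    self.sleep_interval = sleep_interval
    self.counters = [0, 0]        # analogue of the two counter_vars on the ps
    self.lock = threading.Lock()  # analogue of use_locking=True on assign_add

  def wait(self):
    for i in range(2):
      with self.lock:
        self.counters[i] += 1     # analogue of counter_adder_ops[i]
      while True:                 # wait until every worker has incremented counter i
        with self.lock:
          arrived = (self.counters[i] % self.num_workers == 0)
        if arrived:
          break
        time.sleep(self.sleep_interval)

def worker(barrier, worker_id, iters=3):
  for step in range(iters):
    barrier.wait()
    print('worker {} passed the barrier at step {}'.format(worker_id, step))

if __name__ == '__main__':
  n = 4
  b = ThreadBarrier(n)
  threads = [threading.Thread(target=worker, args=(b, k)) for k in range(n)]
  for t in threads:
    t.start()
  for t in threads:
    t.join()

Dropping the second phase (range(1) instead of range(2)) lets a fast thread re-enter the barrier and bump the counter again before the slow threads have checked it, which can leave them polling forever; that race is what the second counter prevents.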
formath commented Apr 28, 2018

Why use two counters in barrier?
