Example of a local cluster with multiple workers/training loops and a sharded parameter server
#!/usr/bin/env python
# Benchmark transferring data, part of troubleshooting https://github.com/tensorflow/tensorflow/issues/6116
#
# Takes `a` independent workers communicating with `b` parameter-server shards.
# Each worker tries to add to variables stored on the parameter server as fast
# as possible.
#
# macbook
# ps=1: 1.6 GB/s
# ps=2: 2.6 GB/s
#
# xeon:
# ps=1: 0.5-0.6 GB/s
# ps=2: 1.1-1.3 GB/s
# ps=4: 1.7-1.9 GB/s
# ps=8: 2.6-3.1 GB/s
# ps=16: 2.3 GB/s
#
# There is a significant slowdown when using larger sizes. For instance,
# transferring 128MB chunks gives about 446 MB/second; with 1024MB chunks
# the rate drops to 102 MB/second.
#
# to run with tcmalloc, set
# export LD_PRELOAD="/usr/lib/libtcmalloc.so.4"
#
# reduce spurious logging with TF_CPP_MIN_LOG_LEVEL=2
# Problems:
# - sometimes prints a scary message at the end, possibly because a ps shard
#   quits while a session is still connected to it
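#
# Example invocation (illustrative, not from the original header; the flag
# names match the DEFINE_* calls below, and the file name is whatever this
# script is saved as):
#   python local_benchmark.py --workers=2 --ps=4 --data_mb=128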
import os
import subprocess
import sys
import tensorflow as tf
import threading
import time
flags = tf.flags
flags.DEFINE_integer("iters", 10, "Maximum number of additions")
flags.DEFINE_integer("data_mb", 128, "size of vector in MBs")
flags.DEFINE_integer("workers", 1, "number of workers")
flags.DEFINE_string("strategy", "push", "push to have workers update ps, pull "
"to have them pull data from ps, both to do both")
flags.DEFINE_integer("ps", 1, "number of ps shards")
flags.DEFINE_integer("starting_port", 12222, "first port to use")
flags.DEFINE_boolean("verbose", False, "extra logging")
# internal flags, don't use
flags.DEFINE_string("job_name", "", "worker or ps")
flags.DEFINE_integer("task_index", -1, "# of current task")
FLAGS = flags.FLAGS
session_config = tf.ConfigProto(intra_op_parallelism_threads=10,
                                inter_op_parallelism_threads=10)
# setup local cluster from flags
host = "127.0.0.1"
ps_ports = range(FLAGS.starting_port, FLAGS.starting_port+FLAGS.ps)
worker_ports = range(FLAGS.starting_port+FLAGS.ps, FLAGS.starting_port+FLAGS.ps+FLAGS.workers)
cluster = {"ps": [host+":"+str(p) for p in ps_ports],
"worker": [host+":"+str(p) for p in worker_ports]}
clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def()
dtype = tf.int32
params_size = 250*1000*FLAGS.data_mb  # 1MB is 250k int32s
sharded_params_size = params_size // FLAGS.ps  # integer division keeps the shape an int

def log(s):
  if FLAGS.verbose:
    print(s)

def create_graph(worker):
  """Creates graph for worker `worker` and all ps shards"""
  params = []
  updates = []
  param_init_ops = []
  for i in range(FLAGS.ps):
    with tf.device("job:ps/task:"+str(i)):
      param = tf.get_variable(name="params"+str(i),
                              shape=[sharded_params_size],
                              dtype=dtype,
                              initializer=tf.zeros_initializer)
      params.append(param)
      param_init_ops.append(param.initializer)

  add_ops = []
  update_init_ops = []
  with tf.device("job:worker/task:"+str(worker)):
    for i in range(FLAGS.ps):
      update = tf.get_variable(name="update"+str(i),
                               shape=[sharded_params_size],
                               dtype=dtype,
                               initializer=tf.zeros_initializer)
      if FLAGS.strategy == "push":
        add_op = params[i].assign_add(update)
      elif FLAGS.strategy == "pull":
        add_op = update.assign_add(params[i])
      elif FLAGS.strategy == "both":
        local_update = tf.identity(params[i].read_value())
        add_op = params[i].assign_add(local_update)
      add_ops.append(add_op)
      update_init_ops.append(update.initializer)

  return update_init_ops, param_init_ops, add_ops

def create_done_queue(i):
  """Queue used to signal death for i'th ps shard. Intended to have
  all workers enqueue an item onto it to signal doneness."""
  with tf.device("/job:ps/task:%d" % (i)):
    return tf.FIFOQueue(FLAGS.workers, tf.int32, shared_name="done_queue"+
                        str(i))

def create_done_queues():
  return [create_done_queue(i) for i in range(FLAGS.ps)]

def run_ps():
  """Main loop for single ps server shard. Initializes variables on that
  shard."""
  log("ps %d: running"%(FLAGS.task_index))
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)
  sess = tf.Session(server.target, config=session_config)

  # run initialization for variables on this shard
  update_init_ops, param_init_ops, add_ops = create_graph(0)
  log("ps %d: initializing"%(FLAGS.task_index))
  sess.run(param_init_ops[FLAGS.task_index])

  queue = create_done_queue(FLAGS.task_index)

  # wait until all workers are done
  for i in range(FLAGS.workers):
    sess.run(queue.dequeue())
    log("ps %d received done %d" % (FLAGS.task_index, i))

  log("ps %d: quitting"%(FLAGS.task_index))

def run_worker():
  """Main loop for single worker."""
  log("worker %d: running"%(FLAGS.task_index))
  update_init_ops, param_init_ops, add_ops = create_graph(FLAGS.task_index)
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)
  sess = tf.Session(server.target, config=session_config)
  sess.run(update_init_ops)

  # wait for parameter server variables to be initialized
  uninited_op = tf.report_uninitialized_variables()
  while len(sess.run(uninited_op)) > 0:
    log("worker %d: ps uninitialized, sleeping" % FLAGS.task_index)
    time.sleep(1)

  for add_op in add_ops:
    sess.run(add_op.op)  # warm-up

  start_time = time.time()

  # communicate with parameter server in separate threads
  def create_worker_thread(add_op, iters):
    def worker_thread():
      for i in range(iters):
        sess.run(add_op.op)
    return worker_thread

  threads = []
  for i in range(FLAGS.ps):
    worker_thread_body = create_worker_thread(add_ops[i], FLAGS.iters)
    worker_thread = threading.Thread(target=worker_thread_body)
    worker_thread.start()
    threads.append(worker_thread)
  for thread in threads:
    thread.join()

  elapsed_time = time.time() - start_time
  rate = float(FLAGS.iters)*FLAGS.data_mb/elapsed_time
  print("worker %d done: %.2f MB per second" % (FLAGS.task_index, rate))

  # signal to ps shards that we are done
  for q in create_done_queues():
    sess.run(q.enqueue(1))

def launch_ps():
  for i in range(FLAGS.ps):
    cmd = "./" + " ".join(sys.argv) + " --job_name=ps --task="+str(i)
    my_env = os.environ.copy()
    my_env["CUDA_VISIBLE_DEVICES"] = ""
    subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT,
                     env=my_env)

def launch_workers():
  for i in range(FLAGS.workers):
    cmd = "./" + " ".join(sys.argv) + " --job_name=worker --task="+str(i)
    my_env = os.environ.copy()
    # turn off GPU for speed
    my_env["CUDA_VISIBLE_DEVICES"] = ""
    subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT,
                     env=my_env)

if __name__ == '__main__':
  if FLAGS.job_name == "ps":
    run_ps()
  elif FLAGS.job_name == "worker":
    run_worker()
  else:
    log("client: launching ps")
    launch_ps()
    log("client: launching workers")
    launch_workers()
hustcat commented Feb 17, 2017

@yaroslavvb, I want to stop the ps server gracefully, but failed with RuntimeError: Graph is finalized and cannot be modified. Can you give me some advice?

  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "mnist_dist2.py", line 138, in main
    for q in create_done_queues():
  File "mnist_dist2.py", line 39, in create_done_queues
    return [create_done_queue(i) for i in range(FLAGS.ps)]
  File "mnist_dist2.py", line 36, in create_done_queue
    str(i))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/data_flow_ops.py", line 670, in __init__
    shared_name=shared_name, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_data_flow_ops.py", line 522, in _fifo_queue
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 759, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2199, in create_op
    self._check_not_finalized()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1925, in _check_not_finalized
    raise RuntimeError("Graph is finalized and cannot be modified.")
RuntimeError: Graph is finalized and cannot be modified.

Code as follows:

import math
import sys
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Flags for defining the tf.train.ClusterSpec
tf.app.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")

# Flags for defining the tf.train.Server
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("hidden_units", 100,
                            "Number of units in the hidden layer of the NN")
tf.app.flags.DEFINE_string("data_dir", "/tmp/tensorflow/mnist/input_data/",
                           "Directory for storing mnist data")
tf.app.flags.DEFINE_integer("batch_size", 100, "Training batch size")
tf.app.flags.DEFINE_integer("workers", 2, "Number of workers")
tf.app.flags.DEFINE_integer("ps", 1, "Number of ps")
tf.app.flags.DEFINE_integer("max_step", 2000, "Number of max steps")

FLAGS = tf.app.flags.FLAGS

IMAGE_PIXELS = 28

def create_done_queue(i):
  """Queue used to signal death for i'th ps shard. Intended to have 
  all workers enqueue an item onto it to signal doneness."""
  
  with tf.device("/job:ps/task:%d" % (i)):
    return tf.FIFOQueue(FLAGS.workers, tf.int32, shared_name="done_queue"+
                        str(i))
  
def create_done_queues():
  return [create_done_queue(i) for i in range(FLAGS.ps)]

def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
    sess = tf.Session(server.target)
    queue = create_done_queue(FLAGS.task_index)
  
    # wait until all workers are done
    for i in range(FLAGS.workers):
      sess.run(queue.dequeue())
      print("ps %d received done %d" % (FLAGS.task_index, i))
     
    print("ps %d: quitting"%(FLAGS.task_index))
  elif FLAGS.job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      # Variables of the hidden layer
      hid_w = tf.Variable(
          tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
                              stddev=1.0 / IMAGE_PIXELS), name="hid_w")
      hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

      # Variables of the softmax layer
      sm_w = tf.Variable(
          tf.truncated_normal([FLAGS.hidden_units, 10],
                              stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
          name="sm_w")
      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
      y_ = tf.placeholder(tf.float32, [None, 10])

      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
      hid = tf.nn.relu(hid_lin)

      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

      global_step = tf.Variable(0)

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)
       
      correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

      saver = tf.train.Saver()
      summary_op = tf.merge_all_summaries()
      init_op = tf.initialize_all_variables()

    # Create a "supervisor", which oversees the training process.
    sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                             logdir="./logs_%d" % FLAGS.task_index,
                             init_op=init_op,
                             summary_op=summary_op,
                             saver=saver,
                             global_step=global_step,
                             save_model_secs=60)

    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    begin_time = time.time()
    frequency = 100
    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target) as sess:
      # Loop until the supervisor shuts down or 100000 steps have completed.
      step = 0
      while not sv.should_stop() and step < FLAGS.max_step:
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.
        start_time = time.time()

        batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
        train_feed = {x: batch_xs, y_: batch_ys}

        _, step = sess.run([train_op, global_step], feed_dict=train_feed)
        elapsed_time = time.time() - start_time
        if step % frequency == 0: 
            print ("Done step %d" % step, " AvgTime: %3.2fms" % float(elapsed_time*1000/frequency))


      # signal to ps shards that we are done
      for q in create_done_queues():
        sess.run(q.enqueue(1))
      # Test trained model
      print("Test-Accuracy: %2.4f" % sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

    print("Total Time: %3.2fs" % float(time.time() - begin_time))
  

    # Ask for all the services to stop.
    sv.stop()

if __name__ == "__main__":
  tf.app.run()

@samwhitlock

I think zeros_initializer needs to be updated to zeros_initializer() for TF 1.0
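
A minimal sketch of that change against the variable creation in the gist, assuming TF 1.0 where tf.zeros_initializer is a callable that must be invoked to produce the initializer:

  param = tf.get_variable(name="params"+str(i),
                          shape=[sharded_params_size],
                          dtype=dtype,
                          initializer=tf.zeros_initializer())  # note the call parentheses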

@Flamefire

@hustcat The problem is that the supervisor finalizes the graph, after which q.enqueue is no longer possible.

However, q.enqueue(1) returns an operation which can be run later by the session object. The solution is simple:
add finalize_ops = [q.enqueue(1) for q in create_done_queues()] before creating the supervisor, and instead of the original loop at the bottom, run for op in finalize_ops: sess.run(op).
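
For reference, a minimal sketch of that ordering, using the create_done_queues() helper from the code above (the name finalize_ops is just illustrative; the saver/summary arguments are dropped for brevity):

  # build the enqueue ops while the graph is still mutable,
  # i.e. before the Supervisor is created and finalizes it
  finalize_ops = [q.enqueue(1) for q in create_done_queues()]

  sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                           logdir="./logs_%d" % FLAGS.task_index,
                           init_op=init_op,
                           global_step=global_step)

  with sv.managed_session(server.target) as sess:
    # ... training loop as before ...

    # signal the ps shards by running the pre-built ops instead of
    # constructing new ones inside the finalized graph
    for op in finalize_ops:
      sess.run(op)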
