Last active
August 1, 2023 19:41
-
-
Save yaroslavvb/82a5b5302449530ca5ff59df520c369e to your computer and use it in GitHub Desktop.
Example of bringing down parameter server by using shared queue to signal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Example of launching distributed service and then bringing it down."""
import subprocess | |
import tensorflow as tf | |
import time | |
import sys | |
# Command-line flags: one port per worker, plus an internal "task" flag that
# the launcher passes to the server subprocesses it re-executes.
flags = tf.flags
flags.DEFINE_string("port1", "12222", "port of worker1")
flags.DEFINE_string("port2", "12223", "port of worker2")
flags.DEFINE_string("task", "", "internal use")
FLAGS = flags.FLAGS

# Describe a two-worker cluster on the loopback interface, built from the
# port flags above.
host = "127.0.0.1:"
cluster = {
    "worker": [
        "%s%s" % (host, FLAGS.port1),
        "%s%s" % (host, FLAGS.port2),
    ],
}
clusterspec = tf.train.ClusterSpec(cluster)
clusterspec = clusterspec.as_cluster_def()
if __name__ == '__main__':
    if not FLAGS.task:
        # Client mode: start both servers, run a computation on each, then
        # signal each one to shut down through its shared queue.

        def runcmd(cmd):
            """Start `cmd` (an argv list) in the background without waiting.

            The child's stderr is merged into its stdout, as before.
            Returns the Popen handle.
            """
            return subprocess.Popen(cmd, stderr=subprocess.STDOUT)

        # Re-run this same script with the same interpreter, once per task.
        # The port flags are forwarded so the children build exactly the
        # cluster this client is about to connect to; without that, a
        # non-default --port1/--port2 would leave them on the default ports.
        interpreter = sys.executable or "python"
        for task in ("0", "1"):
            runcmd([interpreter, sys.argv[0],
                    "--port1=" + FLAGS.port1,
                    "--port2=" + FLAGS.port2,
                    "--task=" + task])
        # Crude wait for the servers to come up before connecting.
        time.sleep(1)

        # bring down distributed service
        sess = tf.Session("grpc://" + host + FLAGS.port1)
        # Each shutdown queue is pinned to the task that dequeues from it, so
        # that the enqueue below reaches that server's queue rather than
        # depending on where default placement would put an unpinned queue.
        with tf.device("/job:worker/task:0"):
            queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0")
            add_op0 = tf.add(tf.ones(()), tf.ones(()))
        with tf.device("/job:worker/task:1"):
            queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1")
            add_op1 = tf.add(tf.ones(()), tf.ones(()))

        print("Running computation on server 0")
        print(sess.run(add_op0))
        print("Running computation on server 1")
        print(sess.run(add_op1))

        # Each server is blocked on a dequeue from its queue; enqueueing one
        # element lets it fall through to the end of the script.
        print("Bringing down server 0")
        sess.run(queue0.enqueue(1))
        print("Bringing down server 1")
        sess.run(queue1.enqueue(1))

    else:
        # Server mode: host one worker task and block until signalled.
        task_index = int(FLAGS.task)
        server = tf.train.Server(clusterspec, config=None,
                                 job_name="worker",
                                 task_index=task_index)
        print("Starting server " + FLAGS.task)
        sess = tf.Session(server.target)
        # Same shared_name and device as the client uses for this task, so
        # both sides refer to the same queue.
        with tf.device("/job:worker/task:%d" % task_index):
            queue = tf.FIFOQueue(1, tf.int32,
                                 shared_name="queue" + FLAGS.task)
        # Blocks until the client enqueues the shutdown signal.
        sess.run(queue.dequeue())
        print("Terminating server " + FLAGS.task)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment