Skip to content

Instantly share code, notes, and snippets.

@TuranTimur
Created November 3, 2017 07:50
Show Gist options
  • Save TuranTimur/5a2eb8c01409a63f64fa64ab2af2bdc2 to your computer and use it in GitHub Desktop.
Save TuranTimur/5a2eb8c01409a63f64fa64ab2af2bdc2 to your computer and use it in GitHub Desktop.
In-graph replication in TensorFlow
# In-graph replication (TensorFlow 1.x API): one client builds a single graph
# holding the shared parameters on a parameter server plus one copy ("tower")
# of the model per worker GPU, then drives every tower from one session.
#
# NOTE(review): assumes TensorFlow 1.x is imported as ``tf`` elsewhere; this
# snippet never imports it. The ``...`` placeholders are elisions from the
# sketch (initializers, input pipeline, per-tower training op) and must be
# filled in before this can run.

NUM_WORKERS = 8  # worker tasks 0..7

# Shared parameters live on the parameter server so every tower reads and
# updates the same variables.
with tf.device("/job:ps/task:0"):
    weights_1 = tf.Variable(...)
    biases_1 = tf.Variable(...)
    # Second-layer parameters are used by every tower below, so they must be
    # defined alongside the first layer's.
    weights_2 = tf.Variable(...)
    biases_2 = tf.Variable(...)

# One GPU per worker task. Built explicitly: a literal ``...`` element inside
# the list would be handed to tf.device() as the Ellipsis object.
worker_devices = ["/job:worker/task:%d/gpu:0" % i for i in range(NUM_WORKERS)]

tower_train_ops = []
for worker_device in worker_devices:
    # Pin this tower to its own worker GPU. Do not nest another
    # tf.device("/job:worker/task:N") in here: nested device scopes merge,
    # so its task field would override worker_device's and put every
    # replica on the same worker.
    with tf.device(worker_device):
        inputs, labels = ...
        layer_1 = tf.nn.relu(tf.matmul(inputs, weights_1) + biases_1)
        logits = tf.nn.relu(tf.matmul(layer_1, weights_2) + biases_2)
        # Per-tower training op (elided). Collect it rather than overwriting
        # a single variable so every tower participates in each step.
        tower_train_ops.append(...)

# A single op that runs every tower's training step together.
train_op = tf.group(*tower_train_ops)

with tf.Session("grpc://worker7.example.com:2222") as sess:
    for _ in range(10000):
        sess.run(train_op)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment