

Forked from haje01/
Created Jan 14, 2017
Distributed TensorFlow


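At the time of writing (2017-01-11), there was too little documentation on distributed TensorFlow for me to confirm that training is actually distributed. Unfortunately, please treat this content as a reference only.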

Original: [] (

Concepts

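  • A cluster is the set of tasks that take part in the distributed execution of a TensorFlow graph.
  • Each task is associated with a TensorFlow server, which has a master that can create sessions and a worker that executes operations in the graph.
  • A cluster is divided into one or more jobs, each of which contains one or more tasks.
  • To create a cluster, start one TensorFlow server per task.
  • A task usually uses one machine, but several tasks can also run on a single machine (for example, a machine with multiple GPUs).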
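The hierarchy above (a cluster made of jobs, jobs made of tasks, one server per task) can be pictured with plain data. The stdlib-only sketch below is illustrative rather than TensorFlow code. It models a cluster as a dict from job name to the list of its tasks' `host:port` addresses and lists the `/job:<name>/task:<index>` names TensorFlow uses for tasks. The helper names and the 10.0.0.x addresses are hypothetical. The second cluster places three tasks on one machine, told apart only by port.

```python
from typing import Dict, List

# A cluster: job name -> ordered list of "host:port" addresses, one per task.
Cluster = Dict[str, List[str]]

def task_names(cluster: Cluster) -> List[str]:
    """Name every task the way TensorFlow device strings do: /job:<job>/task:<index>."""
    return ["/job:%s/task:%d" % (job, index)
            for job in sorted(cluster)
            for index in range(len(cluster[job]))]

def machines(cluster: Cluster) -> List[str]:
    """Distinct hosts in the cluster; several tasks may share one host."""
    return sorted({addr.rsplit(":", 1)[0] for addrs in cluster.values() for addr in addrs})

# Hypothetical multi-machine cluster: one ps task and two worker tasks.
multi = {"ps": ["10.0.0.1:2222"], "worker": ["10.0.0.2:2222", "10.0.0.3:2222"]}
# Three tasks on one machine, distinguished only by port (e.g. one task per GPU).
single = {"ps": ["localhost:2222"], "worker": ["localhost:2223", "localhost:2224"]}

for cluster in (multi, single):
    # One TensorFlow server is started per task name printed here.
    print(task_names(cluster), "on", machines(cluster))
```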

For each task:

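  1. Create a tf.train.ClusterSpec that describes every task in the cluster. It is identical for all tasks.
  2. Create a tf.train.Server by passing the tf.train.ClusterSpec to its constructor, together with a job name and a task index that identify the local task.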
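In mnist_replica.py these two steps are driven by command-line flags. The comma-separated host lists are split and handed to tf.train.ClusterSpec, and then tf.train.Server(cluster, job_name=..., task_index=...) is created. The TensorFlow-free sketch below reproduces only that bookkeeping, with hypothetical helper names and addresses: every task builds the same cluster description, and only the (job_name, task_index) pair differs between processes.

```python
from typing import Dict, List, Tuple

def build_cluster_def(ps_hosts: str, worker_hosts: str) -> Dict[str, List[str]]:
    """Split comma-separated host:port lists into the dict given to tf.train.ClusterSpec."""
    return {"ps": ps_hosts.split(","), "worker": worker_hosts.split(",")}

def local_task(cluster_def: Dict[str, List[str]], job_name: str,
               task_index: int) -> Tuple[str, int, str]:
    """Check that (job_name, task_index) names a task in the cluster; return its address."""
    if job_name not in cluster_def:
        raise ValueError("unknown job name: %r" % job_name)
    if not 0 <= task_index < len(cluster_def[job_name]):
        raise ValueError("task_index %d is out of range for job %r" % (task_index, job_name))
    return job_name, task_index, cluster_def[job_name][task_index]

# Hypothetical addresses for one ps task and two worker tasks.
PS_HOSTS = "10.0.0.1:2222"
WORKER_HOSTS = "10.0.0.2:2222,10.0.0.3:2222"

# Step 1: each process builds the same description from the same flags.
spec_on_ps = build_cluster_def(PS_HOSTS, WORKER_HOSTS)
spec_on_worker1 = build_cluster_def(PS_HOSTS, WORKER_HOSTS)
assert spec_on_ps == spec_on_worker1

# Step 2: each process identifies itself with its own job name and task index.
print(local_task(spec_on_ps, "ps", 0))
print(local_task(spec_on_worker1, "worker", 1))
```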

Creating a tf.train.Server instance

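A tf.train.Server object holds a set of local devices, a set of connections to the other tasks in its tf.train.ClusterSpec, and a session target that uses them to perform distributed computation. Each server is a member of a specific named job and has a task index within that job. A server can communicate with any other server in the cluster.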
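The two session-side values a worker uses can be rebuilt without TensorFlow. When --existing_servers is set, mnist_replica.py connects its session to the gRPC URL "grpc://" + worker_spec[task_index]. It also passes device_filters=["/job:ps", "/job:worker/task:<index>"] so that the session ignores devices outside the parameter servers and the worker's own task. The helper name and the addresses below are hypothetical; the sketch only reconstructs those two values.

```python
from typing import List, Tuple

def worker_session_settings(worker_hosts: List[str], task_index: int) -> Tuple[str, List[str]]:
    """Return (session target, device filters) for one worker task."""
    # The target is a gRPC URL pointing at this worker's own server.
    target = "grpc://" + worker_hosts[task_index]
    # Devices matching none of these filters are ignored by the session:
    # all parameter-server tasks plus this worker task.
    device_filters = ["/job:ps", "/job:worker/task:%d" % task_index]
    return target, device_filters

workers = ["10.0.0.2:2222", "10.0.0.3:2222"]  # hypothetical addresses
for index in range(len(workers)):
    print(worker_session_settings(workers, index))
```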

Deploying with Terraform

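Distributed processing needs more than one machine, and setting them up to work together is tedious. To avoid that, we use AWS through Terraform. The example used here is mnist_replica.py, which ships with the TensorFlow source.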


Configure the environment in a Terraform variable file (.tfvars); here, seoul.tfvars:

region = "ap-northeast-2"  # AWS Seoul region
ami = "ami-f293459c"  # AMI to use; here, Ubuntu in the Seoul region
key_name = "AWS key name"
key_file = "path to the SSH private key for AWS"
owner = ""  # owner
instance_type = "t2.small"  # EC2 instance type to use
ps_cnt = 1  # number of parameter servers
worker_cnt = 2  # number of workers
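Every variable declared in the Terraform recipe needs a value, so a small pre-flight check can catch one that was left out of the .tfvars file. The sketch below is a hypothetical helper, not part of Terraform. It only understands the flat `name = "string"` and `name = number` lines used here, not general HCL, and its REQUIRED list is copied from the `variable` blocks in the recipe.

```python
import re
from typing import Dict, List, Union

# Variables declared with `variable "..." {}` in the Terraform recipe.
REQUIRED = ["region", "ami", "key_name", "key_file", "owner",
            "instance_type", "ps_cnt", "worker_cnt"]

# name = "string"  or  name = 123, optionally followed by a # comment.
LINE = re.compile(r'^\s*(\w+)\s*=\s*(?:"([^"]*)"|(\d+))\s*(?:#.*)?$')

def parse_tfvars(text: str) -> Dict[str, Union[str, int]]:
    """Parse the flat subset of .tfvars syntax used in seoul.tfvars."""
    values: Dict[str, Union[str, int]] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and whole-line comments
        match = LINE.match(line)
        if match is None:
            raise ValueError("unsupported line: %r" % raw)
        name, text_value, int_value = match.groups()
        values[name] = text_value if text_value is not None else int(int_value)
    return values

def missing_variables(values: Dict[str, Union[str, int]]) -> List[str]:
    """Required variables that the file does not set."""
    return [name for name in REQUIRED if name not in values]

sample = '''
region = "ap-northeast-2"  # AWS Seoul region
ami = "ami-f293459c"
key_name = "AWS KEY NAME"
owner = "YOUR EMAIL"
instance_type = "t2.small"
ps_cnt = 1
worker_cnt = 2
'''
print(missing_variables(parse_tfvars(sample)))
```

Here the check reports key_file, the one declared variable the sample does not set.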

To use AWS as the provider, set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY as environment variables.
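A minimal pre-flight check for those two variables might look like the following. The function name is hypothetical. It only reports which names are missing and never reads out or prints the secret values. Because it accepts any mapping, it can be tried on plain dicts without touching the real environment.

```python
import os
from typing import List, Mapping

# Environment variables the AWS provider needs for this setup.
AWS_CREDENTIAL_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")

def missing_aws_credentials(env: Mapping[str, str] = os.environ) -> List[str]:
    """Names of the credential variables that are unset or empty in env."""
    return [name for name in AWS_CREDENTIAL_VARS if not env.get(name)]

# Try it on explicit dicts instead of the real environment.
print(missing_aws_credentials({"AWS_ACCESS_KEY_ID": "example-id"}))
print(missing_aws_credentials({"AWS_ACCESS_KEY_ID": "example-id",
                               "AWS_SECRET_ACCESS_KEY": "example-secret"}))
```

Calling missing_aws_credentials() with no argument checks os.environ; an empty list means both variables are set.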

Building and running the system

The following command builds the AWS environment automatically:

terraform apply -var-file=seoul.tfvars

If this completes without errors, run the script that starts the job on every node.
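The run script loops over the addresses printed by `terraform output ips` and starts the node script over SSH in the background, with one log file per node. A Python sketch of the same fan-out follows. The key path, the remote script path /home/ubuntu/run.sh and the IP addresses are hypothetical placeholders, and the example only builds and previews the commands; launch_all is defined but not called.

```python
import shlex
import subprocess
from typing import List

def ssh_commands(ips_output: str, key_path: str, remote_script: str) -> List[List[str]]:
    """One ssh argv per node, from the space-separated value of `terraform output ips`."""
    return [["ssh", "-i", key_path, "-o", "StrictHostKeyChecking no",
             "ubuntu@" + ip, remote_script]
            for ip in ips_output.split()]

def launch_all(commands: List[List[str]]) -> List[subprocess.Popen]:
    """Start every command without waiting, sending stdout and stderr to <ip>.log."""
    procs = []
    for argv in commands:
        ip = argv[-2].split("@", 1)[1]
        log = open(ip + ".log", "w")  # the child process keeps writing to this file
        procs.append(subprocess.Popen(argv, stdout=log, stderr=subprocess.STDOUT))
    return procs

# In practice: ips_output = subprocess.check_output(["terraform", "output", "ips"]).decode()
cmds = ssh_commands("203.0.113.10 203.0.113.11", "/path/to/key.pem", "/home/ubuntu/run.sh")
for argv in cmds:
    print(" ".join(shlex.quote(part) for part in argv))  # shell-equivalent preview
```

Passing an argument list to Popen avoids shell quoting issues with the space in "StrictHostKeyChecking no".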


Related files:

  • Installs TensorFlow and related modules
  • Terraform recipe
  • seoul.tfvars: Terraform variable settings file
  • Runs the script on every created node
variable "region" {}
variable "ami" {}
variable "key_name" {}
variable "key_file" {}
variable "instance_type" {}
variable "ps_cnt" {}
variable "worker_cnt" {}
variable "owner" {}

provider "aws" {
  region = "${var.region}"
}

# WARNING: Allow all traffic for simplicity. NOT SAFE!
resource "aws_security_group" "allow_all" {
  name = "allow_all"
  description = "Allow all inbound traffic"

  ingress {
    from_port = 22
    to_port = 22
    protocol = "tcp"
    cidr_blocks = ["0.0.0.0/0"]
  }

  ingress {
    from_port = 2222
    to_port = 2222
    protocol = "tcp"
    cidr_blocks = ["0.0.0.0/0"]
  }

  egress {
    from_port = 0
    to_port = 0
    protocol = "-1"
    cidr_blocks = ["0.0.0.0/0"]
  }
}

resource "aws_instance" "ps" {
  ami = "${var.ami}"
  instance_type = "${var.instance_type}"
  count = "${var.ps_cnt}"
  key_name = "${var.key_name}"
  security_groups = ["${aws_security_group.allow_all.name}"]

  tags {
    Name = "tf-ps${count.index + 1}"
    Desc = "Tensorflow PS"
    Owner = "${var.owner}"
    Service = "RnD"
  }

  connection {
    type = "ssh"
    user = "ubuntu"
    private_key = "${file(var.key_file)}"
  }

  provisioner "file" {
    source = ""
    destination = "/home/ubuntu/"
  }

  provisioner "remote-exec" {
    script = ""
  }
}

resource "aws_instance" "worker" {
  ami = "${var.ami}"
  instance_type = "${var.instance_type}"
  count = "${var.worker_cnt}"
  key_name = "${var.key_name}"
  security_groups = ["${aws_security_group.allow_all.name}"]

  tags {
    Name = "tf-worker${count.index + 1}"
    Desc = "Tensorflow Worker"
    Owner = "${var.owner}"
    Service = "RnD"
  }

  connection {
    type = "ssh"
    user = "ubuntu"
    private_key = "${file(var.key_file)}"
  }

  provisioner "file" {
    source = ""
    destination = "/home/ubuntu/"
  }

  provisioner "remote-exec" {
    script = ""
  }
}

resource "null_resource" "run-ps" {
  count = "${var.ps_cnt}"

  connection {
    type = "ssh"
    host = "${element(aws_instance.ps.*.public_ip, count.index)}"
    user = "ubuntu"
    private_key = "${file(var.key_file)}"
  }

  provisioner "remote-exec" {
    inline = [
      "echo 'python3 --num_gpus=0 --ps_hosts=${join(",", formatlist("%s:2222", aws_instance.ps.*.public_ip))} --worker_hosts=${join(",", formatlist("%s:2222", aws_instance.worker.*.public_ip))} --job_name=ps --task_index=${count.index}' > /home/ubuntu/",
      "chmod +x /home/ubuntu/"
    ]
  }
}

resource "null_resource" "run-worker" {
  count = "${var.worker_cnt}"

  connection {
    type = "ssh"
    host = "${element(aws_instance.worker.*.public_ip, count.index)}"
    user = "ubuntu"
    private_key = "${file(var.key_file)}"
  }

  provisioner "remote-exec" {
    inline = [
      "echo 'python3 --num_gpus=0 --ps_hosts=${join(",", formatlist("%s:2222", aws_instance.ps.*.public_ip))} --worker_hosts=${join(",", formatlist("%s:2222", aws_instance.worker.*.public_ip))} --job_name=worker --task_index=${count.index}' > /home/ubuntu/",
      "chmod +x /home/ubuntu/"
    ]
  }
}

output "ips" {
  value = "${join(" ", concat(aws_instance.ps.*.public_ip, aws_instance.worker.*.public_ip))}"
}
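The two null_resource blocks give every node the same --ps_hosts and --worker_hosts lists. They differ only in --job_name and in --task_index, which comes from count.index and therefore starts at 0 within each job. The Python sketch below rebuilds those command lines to make the mapping visible. The helper names and IP addresses are hypothetical, and the default script name mnist_replica.py is a stand-in for the script path, which is not shown in the recipe.

```python
from typing import Dict, List

PORT = 2222  # the port the recipe opens in the security group and appends to each IP

def host_list(ips: List[str], port: int = PORT) -> str:
    """Python counterpart of join(",", formatlist("%s:2222", ips))."""
    return ",".join("%s:%d" % (ip, port) for ip in ips)

def node_commands(ps_ips: List[str], worker_ips: List[str],
                  script: str = "mnist_replica.py") -> Dict[str, str]:
    """Map each node's IP to the command the run-ps / run-worker resources write for it."""
    shared = "python3 %s --num_gpus=0 --ps_hosts=%s --worker_hosts=%s" % (
        script, host_list(ps_ips), host_list(worker_ips))
    commands = {}
    for job, ips in (("ps", ps_ips), ("worker", worker_ips)):
        for task_index, ip in enumerate(ips):  # task_index plays the role of count.index
            commands[ip] = "%s --job_name=%s --task_index=%d" % (shared, job, task_index)
    return commands

# Hypothetical public IPs for ps_cnt = 1 and worker_cnt = 2.
for ip, cmd in node_commands(["203.0.113.1"], ["203.0.113.2", "203.0.113.3"]).items():
    print(ip, "->", cmd)
```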
sudo apt -y update
sudo apt -y install python3-pip
export LC_ALL=C
sudo pip3 install --upgrade pip
# TF_BINARY_URL must be set to the URL of a TensorFlow wheel for Python 3
sudo pip3 install --upgrade $TF_BINARY_URL
sudo pip3 install ipython
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distributed MNIST training and validation, with model replicas.

A simple softmax model with one hidden layer is defined. The parameters
(weights and biases) are located on two parameter servers (ps), while the
ops are defined on a worker node. The TF sessions also run on the worker
node.
Multiple invocations of this script can be done in parallel, with different
values for --task_index. There should be exactly one invocation with
--task_index=0, which will create a master session that carries out variable
initialization. The other, non-master, sessions will wait for the master
session to finish the initialization before proceeding to the training stage.

The coordination between the multiple worker invocations occurs due to
the definition of the parameters on the same ps devices. The parameter updates
from one worker are visible to all other workers. As such, the workers can
perform forward computation and gradient calculation in parallel, which
should lead to increased training speed for the simple model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import sys
import tempfile
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
                    "Directory for storing mnist data")
flags.DEFINE_boolean("download_only", False,
                     "Only perform downloading of data; Do not proceed to "
                     "session preparation, model definition or training")
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task that performs the variable "
                     "initialization")
flags.DEFINE_integer("num_gpus", 1,
                     "Total number of gpus for each machine. "
                     "If you don't use GPU, please set it to '0'")
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update "
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
flags.DEFINE_integer("hidden_units", 100,
                     "Number of units in the hidden layer of the NN")
flags.DEFINE_integer("train_steps", 200,
                     "Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
flags.DEFINE_boolean("sync_replicas", False,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     "wherein the parameter updates from workers are aggregated "
                     "before applied to avoid stale gradients")
flags.DEFINE_boolean(
    "existing_servers", False, "Whether servers already exist. If True, "
    "will use the worker hosts via their GRPC URLs (one client process "
    "per worker host). Otherwise, will create an in-process TensorFlow "
    "server.")
flags.DEFINE_string("ps_hosts", "localhost:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("job_name", None, "job name: worker or ps")
FLAGS = flags.FLAGS

IMAGE_PIXELS = 28  # MNIST images are 28x28 pixels


def main(unused_argv):
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
  if FLAGS.download_only:
    sys.exit(0)

  if FLAGS.job_name is None or FLAGS.job_name == "":
    raise ValueError("Must specify an explicit `job_name`")
  if FLAGS.task_index is None or FLAGS.task_index == "":
    raise ValueError("Must specify an explicit `task_index`")

  print("job name = %s" % FLAGS.job_name)
  print("task index = %d" % FLAGS.task_index)

  # Construct the cluster and start the server
  ps_spec = FLAGS.ps_hosts.split(",")
  worker_spec = FLAGS.worker_hosts.split(",")

  # Get the number of workers.
  num_workers = len(worker_spec)

  cluster = tf.train.ClusterSpec({
      "ps": ps_spec,
      "worker": worker_spec})

  if not FLAGS.existing_servers:
    # Not using existing servers. Create an in-process server.
    server = tf.train.Server(
        cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    if FLAGS.job_name == "ps":
      # A parameter server only hosts variables, so it blocks here.
      server.join()

  is_chief = (FLAGS.task_index == 0)
  if FLAGS.num_gpus > 0:
    if FLAGS.num_gpus < num_workers:
      raise ValueError("number of gpus is less than number of workers")
    # Avoid gpu allocation conflict: now allocate task_num -> #gpu
    # for each worker in the corresponding machine
    gpu = (FLAGS.task_index % FLAGS.num_gpus)
    worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
  elif FLAGS.num_gpus == 0:
    # Just allocate the CPU to worker server
    cpu = 0
    worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
  # The device setter will automatically place Variables ops on separate
  # parameter servers (ps). The non-Variable ops will be placed on the workers.
  # The ps use CPU and workers use corresponding GPU
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Variables of the hidden layer
    hid_w = tf.Variable(
        tf.truncated_normal(
            [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
            stddev=1.0 / IMAGE_PIXELS),
        name="hid_w")
    hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

    # Variables of the softmax layer
    sm_w = tf.Variable(
        tf.truncated_normal(
            [FLAGS.hidden_units, 10],
            stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
        name="sm_w")
    sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

    # Ops: located on the worker specified with FLAGS.task_index
    x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, [None, 10])

    hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
    hid = tf.nn.relu(hid_lin)

    y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
    cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)

    if FLAGS.sync_replicas:
      if FLAGS.replicas_to_aggregate is None:
        replicas_to_aggregate = num_workers
      else:
        replicas_to_aggregate = FLAGS.replicas_to_aggregate

      opt = tf.train.SyncReplicasOptimizer(
          opt,
          replicas_to_aggregate=replicas_to_aggregate,
          total_num_replicas=num_workers,
          name="mnist_sync_replicas")

    train_step = opt.minimize(cross_entropy, global_step=global_step)

    if FLAGS.sync_replicas:
      local_init_op = opt.local_step_init_op
      if is_chief:
        local_init_op = opt.chief_init_op

      ready_for_local_init_op = opt.ready_for_local_init_op

      # Initial token and chief queue runners required by the sync_replicas mode
      chief_queue_runner = opt.get_chief_queue_runner()
      sync_init_op = opt.get_init_tokens_op()

    init_op = tf.global_variables_initializer()
    train_dir = tempfile.mkdtemp()

    if FLAGS.sync_replicas:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          local_init_op=local_init_op,
          ready_for_local_init_op=ready_for_local_init_op,
          recovery_wait_secs=1,
          global_step=global_step)
    else:
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=train_dir,
          init_op=init_op,
          recovery_wait_secs=1,
          global_step=global_step)

    sess_config = tf.ConfigProto(
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

    # The chief worker (task_index==0) session will prepare the session,
    # while the remaining workers will wait for the preparation to complete.
    if is_chief:
      print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
      print("Worker %d: Waiting for session to be initialized..." %
            FLAGS.task_index)

    if FLAGS.existing_servers:
      server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
      print("Using existing server at: %s" % server_grpc_url)

      sess = sv.prepare_or_wait_for_session(server_grpc_url,
                                            config=sess_config)
    else:
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

    print("Worker %d: Session initialization complete." % FLAGS.task_index)

    if FLAGS.sync_replicas and is_chief:
      # Chief worker will start the chief queue runner and call the init op.
      sess.run(sync_init_op)
      sv.start_queue_runners(sess, [chief_queue_runner])

    # Perform training
    time_begin = time.time()
    print("Training begins @ %f" % time_begin)

    local_step = 0
    while True:
      # Training feed
      batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
      train_feed = {x: batch_xs, y_: batch_ys}

      _, step = sess.run([train_step, global_step], feed_dict=train_feed)
      local_step += 1

      now = time.time()
      print("%f: Worker %d: training step %d done (global step: %d)" %
            (now, FLAGS.task_index, local_step, step))

      if step >= FLAGS.train_steps:
        break

    time_end = time.time()
    print("Training ends @ %f" % time_end)
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)

    # Validation feed
    val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
    val_xent = sess.run(cross_entropy, feed_dict=val_feed)
    print("After %d training step(s), validation cross entropy = %g" %
          (FLAGS.train_steps, val_xent))


if __name__ == "__main__":
  tf.app.run()
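Two placement decisions in main() are easy to check by hand. A worker picks its device from task_index and --num_gpus. tf.train.replica_device_setter, when no ps_strategy is given, spreads variables over the ps tasks round-robin. The plain-Python sketch below mirrors the first rule for num_gpus >= 0. It models the second rule in simplified form, taking variables in the order main() creates them. The helper names are hypothetical and no TensorFlow is needed.

```python
from typing import Dict, List

def worker_device(task_index: int, num_gpus: int, num_workers: int) -> str:
    """Same rule as main() for num_gpus >= 0: GPU task_index % num_gpus, else CPU 0."""
    if num_gpus > 0:
        if num_gpus < num_workers:
            raise ValueError("number of gpus is less than number of workers")
        return "/job:worker/task:%d/gpu:%d" % (task_index, task_index % num_gpus)
    return "/job:worker/task:%d/cpu:0" % task_index

def round_robin_ps(variable_names: List[str], num_ps: int) -> Dict[str, str]:
    """Simplified round-robin: the i-th variable created goes to ps task i % num_ps."""
    return {name: "/job:ps/task:%d" % (i % num_ps) for i, name in enumerate(variable_names)}

# Variables in the order main() creates them.
VARIABLES = ["global_step", "hid_w", "hid_b", "sm_w", "sm_b"]

print(worker_device(task_index=1, num_gpus=0, num_workers=2))  # CPU-only, as in the Terraform setup
print(round_robin_ps(VARIABLES, num_ps=1))  # ps_cnt = 1: everything on ps task 0
print(round_robin_ps(VARIABLES, num_ps=2))  # two ps tasks, as the docstring describes
```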
for ip in `terraform output ips`
do
  ssh -i SSH_KEY_PATH -o "StrictHostKeyChecking no" ubuntu@$ip /home/ubuntu/ > $ip.log 2>&1 &
done
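Once every node is running, each worker repeats training steps until the global_step it gets back reaches --train_steps. The workers therefore share one step budget instead of each doing train_steps steps. The threaded sketch below is a toy model of that loop, not TensorFlow code: a lock stands in for the global_step update on the parameter server, and the function name is hypothetical. A worker checks the counter only after its own update. In this model the final global step can therefore overshoot train_steps by up to num_workers - 1 (assuming train_steps >= 1).

```python
import threading
from typing import List, Tuple

def simulate_workers(num_workers: int, train_steps: int) -> Tuple[int, List[int]]:
    """Return (final global step, steps taken by each worker) for the toy model."""
    lock = threading.Lock()
    global_step = [0]               # shared counter standing in for the ps variable
    local_steps = [0] * num_workers

    def worker(index: int) -> None:
        while True:
            with lock:              # one atomic "train step + increment global_step"
                global_step[0] += 1
                step = global_step[0]
            local_steps[index] += 1
            if step >= train_steps:  # same stopping rule as the training loop
                break

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return global_step[0], local_steps

final_step, per_worker = simulate_workers(num_workers=2, train_steps=200)
print(final_step, per_worker, sum(per_worker))
```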
region = "ap-northeast-2"
ami = "ami-f293459c" # Ubuntu AMI of Seoul
key_name = "AWS KEY NAME"
key_file = "PATH TO AWS SSH PRIVATE KEY"
owner = "YOUR EMAIL"
instance_type = "t2.small"
ps_cnt = 1
worker_cnt = 2