@applenob
Last active August 27, 2020 03:18
TensorFlow distributed training demo for multi-worker training with `collective_all_reduce_strategy`.
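# Overview: this script spins up an in-process multi-worker cluster, then
# starts one client thread per worker. Each client builds its own graph
# (between-graph replication), wraps a small MNIST-style CNN in a
# CollectiveAllReduceStrategy scope, and runs one synchronized training step.
# It is assumed to target a TF 1.x release that still ships `tf.contrib`
# (contrib was removed in TF 2.x), running under Python 3.6+ for the f-strings.
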
import threading
import contextlib
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
from tensorflow.python.eager import context
from tensorflow.python import keras
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import adam
from tensorflow.python.training import training_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.client import session
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.distribute import reduce_util
print(tf.__version__)
coord = coordinator.Coordinator()  # manages the worker threads
thread_local = threading.local()
thread_local.cached_session = None
worker_num = 3
gpu_num = 1

def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None):
    sess_config = tf.ConfigProto()
    if num_gpus is None:
        num_gpus = context.num_gpus()
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
        num_gpus_per_worker=num_gpus)
    if task_type and task_id is not None:
        # Attach the cluster information to the strategy and build the gRPC
        # session target for this particular worker task.
        strategy.configure(
            session_config=sess_config,
            cluster_spec=cluster_spec,
            task_type=task_type,
            task_id=task_id)
        target = 'grpc://' + cluster_spec[task_type][task_id]
    else:
        target = ''
    return strategy, target, sess_config

cluster_spec = multi_worker_test_base.create_in_process_cluster(num_workers=worker_num, num_ps=0)
print("cluster created!")
print(f"cluster_spec: {cluster_spec}")

def model_fn():
    """MNIST model with synthetic input."""
    data_format = 'channels_last'
    input_shape = [28, 28, 1]
    l = keras.layers
    max_pool = l.MaxPooling2D((2, 2), (2, 2),
                              padding='same',
                              data_format=data_format)
    model = keras.Sequential([
        l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
        l.Conv2D(
            32,
            5,
            padding='same',
            data_format=data_format,
            activation=nn.relu), max_pool,
        l.Conv2D(
            64,
            5,
            padding='same',
            data_format=data_format,
            activation=nn.relu), max_pool,
        l.Flatten(),
        l.Dense(1024, activation=nn.relu),
        l.Dropout(0.4),
        l.Dense(10)
    ])
    # Synthetic batch of 2 random "images" with integer labels in [0, 10).
    image = random_ops.random_uniform([2, 28, 28])
    label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32)
    logits = model(image, training=True)
    loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
    optimizer = adam.AdamOptimizer(learning_rate=1e-4)
    train_op = optimizer.minimize(loss, training_util.get_or_create_global_step())
    return train_op, loss


def run_client(client_fn, task_type, task_id, num_gpus, eager_mode,
               *args, **kwargs):
    def wrapped_client_fn():
        with coord.stop_on_exception():
            client_fn(task_type, task_id, num_gpus, *args, **kwargs)

    if eager_mode:
        with context.eager_mode():
            wrapped_client_fn()
    else:
        with context.graph_mode():
            wrapped_client_fn()


def run_between_graph_clients(client_fn, cluster_spec, num_gpus, *args,
                              **kwargs):
    """Runs several clients for between-graph replication.

    Args:
      client_fn: a function that needs to accept `task_type`, `task_id`,
        `num_gpus`.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
    """
    threads = []
    for task_type in ['chief', 'worker']:
        for task_id in range(len(cluster_spec.get(task_type, []))):
            t = threading.Thread(
                target=run_client,
                args=(client_fn, task_type, task_id, num_gpus,
                      False) + args,
                kwargs=kwargs)
            t.start()
            threads.append(t)
    coord.join(threads)

# @contextlib.contextmanager
# def cached_session(graph=None, config=None, target=None):
#     if getattr(thread_local, 'cached_session', None) is None:
#         thread_local.cached_session = session.Session(
#             graph=None, config=config, target=target)
#     sess = thread_local.cached_session
#     with sess.graph.as_default(), sess.as_default():
#         yield sess

def test_complex_model(task_type, task_id, num_gpus):
    strategy, target, session_config = create_test_objects(
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id,
        num_gpus=num_gpus)
    print(f"get strategy in thread: {threading.get_ident()}")
    with ops.Graph().as_default(), \
            session.Session(config=session_config, target=target) as sess:
        with strategy.scope():
            # Build one model per local replica, then group the per-replica
            # train ops and sum the per-replica losses.
            train_op, loss = strategy.extended.call_for_each_replica(model_fn)
            train_op = strategy.group(strategy.experimental_local_results(train_op))
            loss = strategy.reduce(reduce_util.ReduceOp.SUM, loss, axis=None)
        print(f"start variable initialize in thread: {threading.get_ident()}")
        sess.run(variables.global_variables_initializer())
        print(f"start train_op in thread: {threading.get_ident()}")
        loss_v, _ = sess.run([loss, train_op])
        print(loss_v)

logging.set_verbosity(logging.INFO)
run_between_graph_clients(test_complex_model, cluster_spec, num_gpus=gpu_num)
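# With worker_num = 3 the coordinator launches three client threads; each one
# logs its thread id at the strategy/init/train steps and prints one reduced
# loss value after its synchronized training step completes.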