TensorFlow distributed-training demo: multi-worker training with `collective_all_reduce_strategy`. The script spins up an in-process three-worker cluster and trains a small MNIST-style CNN with between-graph replication, one client thread per worker.
import threading
import contextlib
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
from tensorflow.python.eager import context
from tensorflow.python import keras
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import adam
from tensorflow.python.training import training_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.client import session
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.distribute import reduce_util

print(tf.__version__)

coord = coordinator.Coordinator()  # manages the worker threads
thread_local = threading.local()
thread_local.cached_session = None

worker_num = 3
gpu_num = 1
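# Assumption: gpu_num = 1 presumes each in-process worker can see a GPU; on a
# CPU-only machine, gpu_num = 0 should let the demo fall back to CPU.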
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None):
  """Builds a CollectiveAllReduceStrategy plus session target/config for one task."""
  sess_config = tf.ConfigProto()
  if num_gpus is None:
    num_gpus = context.num_gpus()
  strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
      num_gpus_per_worker=num_gpus)
  if task_type and task_id is not None:
    strategy.configure(
        session_config=sess_config,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id)
    target = 'grpc://' + cluster_spec[task_type][task_id]
  else:
    target = ''
  return strategy, target, sess_config


cluster_spec = multi_worker_test_base.create_in_process_cluster(
    num_workers=worker_num, num_ps=0)
print("cluster created!")
print(f"cluster_spec: {cluster_spec}")
def model_fn():
  """Mnist model with synthetic input."""
  data_format = 'channels_last'
  input_shape = [28, 28, 1]
  l = keras.layers
  max_pool = l.MaxPooling2D((2, 2), (2, 2),
                            padding='same',
                            data_format=data_format)
  model = keras.Sequential([
      l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
      l.Conv2D(
          32,
          5,
          padding='same',
          data_format=data_format,
          activation=nn.relu), max_pool,
      l.Conv2D(
          64,
          5,
          padding='same',
          data_format=data_format,
          activation=nn.relu), max_pool,
      l.Flatten(),
      l.Dense(1024, activation=nn.relu),
      l.Dropout(0.4),
      l.Dense(10)
  ])
  # Synthetic batch of 2: random images and random integer labels in [0, 10).
  image = random_ops.random_uniform([2, 28, 28])
  label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32)
  logits = model(image, training=True)
  loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
  optimizer = adam.AdamOptimizer(learning_rate=1e-4)
  train_op = optimizer.minimize(loss, training_util.get_or_create_global_step())
  return train_op, loss
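
# Each client runs in its own thread under the shared Coordinator;
# coord.stop_on_exception() makes a failure on any worker stop the whole run.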
def run_client(client_fn, task_type, task_id, num_gpus, eager_mode,
               *args, **kwargs):
  """Runs `client_fn` for one task, in eager or graph mode."""

  def wrapped_client_fn():
    with coord.stop_on_exception():
      client_fn(task_type, task_id, num_gpus, *args, **kwargs)

  if eager_mode:
    with context.eager_mode():
      wrapped_client_fn()
  else:
    with context.graph_mode():
      wrapped_client_fn()


def run_between_graph_clients(client_fn, cluster_spec, num_gpus, *args,
                              **kwargs):
  """Runs several clients for between-graph replication.

  Args:
    client_fn: a function that needs to accept `task_type`, `task_id`,
      `num_gpus`.
    cluster_spec: a dict specifying jobs in a cluster.
    num_gpus: number of GPUs per worker.
    *args: will be passed to `client_fn`.
    **kwargs: will be passed to `client_fn`.
  """
  threads = []
  for task_type in ['chief', 'worker']:
    for task_id in range(len(cluster_spec.get(task_type, []))):
      t = threading.Thread(
          target=run_client,
          args=(client_fn, task_type, task_id, num_gpus,
                False) + args,  # eager_mode=False: run the clients in graph mode
          kwargs=kwargs)
      t.start()
      threads.append(t)
  coord.join(threads)
# Unused leftover: a thread-local session cache so each worker thread could
# reuse one Session across calls.
# @contextlib.contextmanager
# def cached_session(graph=None, config=None, target=None):
#   if getattr(thread_local, 'cached_session', None) is None:
#     thread_local.cached_session = session.Session(
#         graph=None, config=config, target=target)
#   sess = thread_local.cached_session
#   with sess.graph.as_default(), sess.as_default():
#     yield sess
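
# One test_complex_model call per worker thread: each builds its own graph and
# Session against its own grpc target (between-graph replication), while the
# collective all-reduce keeps gradients in sync across workers.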
def test_complex_model(task_type, task_id, num_gpus):
  strategy, target, session_config = create_test_objects(
      cluster_spec=cluster_spec,
      task_type=task_type,
      task_id=task_id,
      num_gpus=num_gpus)
  print(f"get strategy in thread: {threading.get_ident()}")
  with ops.Graph().as_default(), \
      session.Session(config=session_config, target=target) as sess:
    with strategy.scope():
      # Build one model per replica; group the per-replica train ops and
      # sum the per-replica losses.
      train_op, loss = strategy.extended.call_for_each_replica(model_fn)
      train_op = strategy.group(strategy.experimental_local_results(train_op))
      loss = strategy.reduce(reduce_util.ReduceOp.SUM, loss, axis=None)
      print(f"start variable initialize in thread: {threading.get_ident()}")
      sess.run(variables.global_variables_initializer())
      print(f"start train_op in thread: {threading.get_ident()}")
      loss_v, _ = sess.run([loss, train_op])
      print(loss_v)


logging.set_verbosity(logging.INFO)
run_between_graph_clients(test_complex_model, cluster_spec, num_gpus=gpu_num)
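
# Note: this demo imports tf.contrib and TF-internal modules
# (tensorflow.python.*), so it is expected to work on TensorFlow 1.x only;
# tf.contrib was removed in TensorFlow 2.x. On success, each of the 3 worker
# threads prints its summed per-replica loss once.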