Distributed training example in TensorFlow 1.13
import os
import time
import tensorflow as tf

def get_model(model_name, input_tensor, is_train, pretrained):
    ...

def optimisation(label, logits, param_dict):
    ...

def get_data_provider(dataset_name, dataset_path, param_dict):
    ...
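
# The three helpers above are elided ("....") in the original gist. The class below is
# only a minimal sketch of the data-provider interface the example assumes (the class
# name, the MNIST loading and every attribute are illustrative assumptions, not the
# author's code): make_distributed() receives the strategy, get_train_dataset() returns
# a distributed iterator whose batches are dicts with 'image' and 'label' keys, and
# initialize_train() runs the iterator's initializer inside a session.
class ExampleMnistProvider(object):
    def __init__(self, dataset_name, dataset_path, param_dict):
        self._strategy = None
        self._iterator = None

    def make_distributed(self, strategy):
        # Remember the strategy so the iterator can be distributed later.
        self._strategy = strategy

    def get_train_dataset(self, batch_size, shuffle, prefetch):
        (images, labels), _ = tf.keras.datasets.mnist.load_data()
        images = images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
        dataset = tf.data.Dataset.from_tensor_slices(
            {'image': images, 'label': labels.astype('int64')})
        if shuffle:
            dataset = dataset.shuffle(10000)
        dataset = dataset.repeat().batch(batch_size).prefetch(prefetch)
        # make_dataset_iterator() splits every batch across the mirrored devices.
        self._iterator = self._strategy.make_dataset_iterator(dataset)
        return self._iterator

    def initialize_train(self, sess):
        sess.run(self._iterator.initialize())
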
def test_distributed_training(model, dataset, param_dict):
    project_path = os.getcwd()
    # First, get the strategy that distributes the training.
    strategy = tf.distribute.MirroredStrategy()
    # Define the computation that takes place on each GPU with a batch of examples
    # taken from the dataset (each batch is different, obtained via get_next()).
    # Here we basically want to train the replicated (mirrored) model.
    # Variables, tensors, metrics, summaries etc. are all created under the strategy
    # scope, so the strategy knows to create the corresponding nodes on each device.
    #---------------------------------------------------------------------------
    def train_graph_replica(inputs):
        # Assume that get_next() has already been called on the iterator.
        image = inputs['image']
        label = inputs['label']
        # Since train_graph_replica is run by the strategy inside its scope, the
        # variables it creates are mirrored onto every device.
        model_fn = get_model(model, image, is_train=True, pretrained=False)
        loss_op, opt_op = optimisation(label, model_fn['feature_logits'],
                                       param_dict)
        # Evaluate the loss only after the optimisation step has run.
        with tf.control_dependencies([opt_op]):
            return tf.identity(loss_op)
    #---------------------------------------------------------------------------
    with strategy.scope():
        # Now, under the strategy scope, the dataset is created.
        # All-reduce aggregates tensors across all devices by adding them up and
        # makes the result available on each device. This is a synchronous operation.
        dp = get_data_provider(dataset, project_path, param_dict)
        dp.make_distributed(strategy)
        train_iterator = dp.get_train_dataset(batch_size=128,
                                              shuffle=1, prefetch=1)
        # Whatever happens in train_graph_replica, the per-replica tensors are
        # aggregated synchronously.
        rep_loss = strategy.experimental_run(train_graph_replica, train_iterator)
        # This is an aggregation op: using the strategy, it reduces the
        # per-replica losses to their mean.
        avg_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, rep_loss)
        with tf.Session() as sess:
            sess.run(tf.initializers.global_variables())
            dp.initialize_train(sess)
            for i in range(100):
                loss_aggregate = sess.run(avg_loss)
                print(loss_aggregate)
if __name__ == "__main__":
    model_name = "simple_convnet"
    dataset_name = "mnist"
    parameters = {'learning_rate': 0.001}
    test_distributed_training(model_name, dataset_name, parameters)
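
# Note (not part of the original gist): by default MirroredStrategy mirrors the model
# onto every GPU visible to the process. To restrict it, you can either pass an explicit
# device list to the constructor, e.g.
#     strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
# or limit visibility with the CUDA_VISIBLE_DEVICES environment variable before launching.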