Skip to content

Instantly share code, notes, and snippets.

@pingsutw
Created February 28, 2020 03:12
Show Gist options
  • Save pingsutw/bd4e1b1b2d4055d4e807abcdd2f0d6f4 to your computer and use it in GitHub Desktop.
Save pingsutw/bd4e1b1b2d4055d4e807abcdd2f0d6f4 to your computer and use it in GitHub Desktop.
Distributed MNIST training
# Targets TensorFlow 2.
# Uses tf.distribute MultiWorkerMirroredStrategy.
# Requires tf-nightly to be installed.
# https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf
# Turn off the tensorflow_datasets progress bar.
tfds.disable_progress_bar()
# Optional device-placement debugging switches, left disabled:
# tf.config.set_soft_device_placement(True)
# tf.debugging.set_log_device_placement(True)
# Size of the shuffle buffer used by input_fn().
BUFFER_SIZE = 10000
# Batch size for a single worker.
BATCH_SIZE = 64
# Number of workers the global batch size is scaled for.
NUM_WORKERS = 2
# Derived from BATCH_SIZE so the two values cannot drift apart.
GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_WORKERS
# Distribution strategy shared by the whole script; it is created when the
# module is loaded, and main() builds the model inside strategy.scope().
# The commented-out assignments are alternative strategies to swap in.
# strategy = tf.distribute.experimental.ParameterServerStrategy()
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
# strategy = tf.distribute.MirroredStrategy()
def input_fn():
    """Build the shuffled, batched MNIST training dataset.

    Loads MNIST through ``tfds.load`` (data directory ``/tmp/data``), takes
    the ``train`` split, rescales every image, shuffles with a buffer of
    ``BUFFER_SIZE`` elements and batches by ``GLOBAL_BATCH_SIZE``.

    Returns:
        The dataset obtained by chaining ``map``, ``shuffle`` and ``batch``
        on ``datasets['train']``; each element is an ``(image, label)`` pair.
    """

    def scale(image, label):
        """Cast *image* to float32 and divide by 255; *label* is unchanged."""
        # Maps pixel values from the 0..255 range into 0.0..1.0.
        image = tf.cast(image, tf.float32)
        image /= 255.0
        return image, label

    # with_info=True makes tfds.load return (datasets, info); the info
    # object is not needed here, so it is discarded.
    datasets, _ = tfds.load(
        name='mnist',
        data_dir='/tmp/data',
        as_supervised=True,
        with_info=True,
    )
    train_split = datasets['train']

    # Batch by the global batch size (per-worker size * number of workers)
    # instead of a hard-coded 64, as the TensorFlow multi-worker Keras
    # tutorial does for datasets fed to a multi-worker strategy.
    return train_split.map(scale).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
def build_and_compile_cnn_model():
    """Create and compile the convolutional classifier used for training.

    The network is a ``tf.keras.Sequential`` stack: a 32-filter 3x3
    convolution with ReLU on ``(28, 28, 1)`` inputs, max pooling, flatten,
    a 64-unit ReLU dense layer and a 10-unit dense output layer with no
    activation.

    Returns:
        The compiled model: sparse categorical cross-entropy on logits,
        SGD with learning rate 0.001, and the ``accuracy`` metric.
    """
    keras = tf.keras

    # Layers in forward order; the last Dense emits raw logits.
    stack = [
        keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        keras.layers.MaxPooling2D(),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10),
    ]
    network = keras.Sequential(stack)

    # from_logits=True matches the activation-free output layer above.
    cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    sgd = keras.optimizers.SGD(learning_rate=0.001)
    network.compile(loss=cross_entropy, optimizer=sgd, metrics=['accuracy'])
    return network
def main():
    """Train the CNN on MNIST under the module-level ``strategy``.

    Builds the compiled model and the training dataset inside
    ``strategy.scope()``, then runs ``fit`` for 3 epochs of 5 steps each.
    """
    # Model and dataset creation happen inside the strategy scope.
    with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model()
        train_datasets = input_fn()

    # steps_per_epoch keeps each epoch to a fixed number of batches.
    multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment