@nathanin
Created June 28, 2020 08:02
"""
Distributed MNIST to demonstrate slow autographing of larger models when using a MirroredStrategy
Using the VGG16 model ~ 8 minutes to start iterating on 4 GPUs
Using the 2-CNN model ~ 1 minute to start iterating on 4 GPUs
"""
import tensorflow as tf
import numpy as np
import tqdm
from tensorflow.keras.layers import Dense, Conv2D, GlobalMaxPool2D
global_batch = 512
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.reshape((x_train.shape[0], 28, 28, 1)) / 255.).astype(np.float32)
y_train = np.eye(10)[y_train]  # one-hot encode the integer labels
print(x_train.shape, y_train.shape)
print(x_train.min(), x_train.max())
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Model, optimizer, and loss are created under the strategy scope.
    model = tf.keras.Sequential([
        tf.keras.applications.VGG16(include_top=False,
                                    weights=None,
                                    input_shape=[32, 32, 1],
                                    pooling='max'),
        # Smaller 2-layer CNN alternative (~1 minute start-up vs ~8 for VGG16):
        # Conv2D(32, (5, 5), (2, 2), activation='relu', input_shape=[32, 32, 1]),
        # Conv2D(32, (5, 5), (2, 2), activation='relu'),
        # GlobalMaxPool2D(),
        # Dense(64, activation='relu'),
        Dense(10, activation='softmax')
    ])
    optimizer = tf.keras.optimizers.Adam(0.0001)

    def calc_loss(labels, pred):
        # Scale the per-example loss by the global batch size rather than the
        # per-replica batch size, as recommended for distributed training.
        sample_loss = tf.keras.losses.categorical_crossentropy(labels, pred)
        return tf.reduce_sum(sample_loss) / tf.constant(global_batch, dtype=tf.float32)
model.summary()
def resize_image(x, y):
    # VGG16 requires inputs of at least 32x32, so upsample the 28x28 digits.
    return tf.image.resize(x, (32, 32)), y

dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
           .repeat(20)
           .shuffle(global_batch)
           .map(resize_image)
           .batch(global_batch, drop_remainder=True))
distributed_dataset = strategy.experimental_distribute_dataset(dataset)
def train_step(inputs, labels):
    # Runs on each replica with its local slice of the global batch.
    with tf.GradientTape() as tape:
        yhat = model(inputs)
        loss = calc_loss(labels, yhat)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
@tf.function
def distributed_train_step(inputs, labels):
    # The first call traces this function for every replica, which is where
    # the slow start-up shows up for the larger VGG16 model.
    losses = strategy.run(train_step, args=(inputs, labels))
    summed_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)
    return summed_loss
for epoch in range(10):
    pbar = tqdm.tqdm(distributed_dataset)
    for xbatch, ybatch in pbar:
        summed_loss = distributed_train_step(xbatch, ybatch)
        pbar.set_description(f'epoch {epoch} {np.mean(summed_loss.numpy()):3.5f}')
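To quantify the start-up cost mentioned in the docstring, one option is to time the first call to distributed_train_step (which triggers tf.function tracing for every replica) separately from a later call. A minimal sketch, not part of the original gist, assuming the script above has already been run:

import time

first_batch = next(iter(distributed_dataset))

t0 = time.time()
distributed_train_step(*first_batch)   # first call: traces the graph, slow for VGG16
print(f'first step (includes tracing): {time.time() - t0:.1f}s')

t0 = time.time()
distributed_train_step(*first_batch)   # later calls reuse the traced graph
print(f'second step: {time.time() - t0:.1f}s')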
@nathanin (Author)

Running with TF 2.1, with the line `losses = strategy.run(train_step, args=(inputs, labels))` replaced by `losses = strategy.experimental_run_v2(train_step, args=(inputs, labels))`, gives the expected start-up time.
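For reference, a sketch of the distributed step under TF 2.1, assuming only the strategy call changes (experimental_run_v2 was renamed to run in TF 2.2):

@tf.function
def distributed_train_step(inputs, labels):
    # TF 2.1 API: strategy.experimental_run_v2 is the pre-2.2 name of strategy.run
    losses = strategy.experimental_run_v2(train_step, args=(inputs, labels))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)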
