# Training a Bayesian neural network with MirroredStrategy.
# When using `model.fit`, a Keras symbolic-tensor error arises, and the common
# workaround of passing `experimental_run_tf_function=False` does not resolve it.
# Writing a custom training loop avoids the error.
# With lots of help from https://www.tensorflow.org/tutorials/distribute/custom_training#training_loop
from nobrainer.models import bayesian
import numpy as np
import tensorflow as tf
# Set up multi-gpu things.
strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 1
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
# Create fake data: two random single-channel 32**3 volumes with per-voxel
# binary labels derived from the sign of each voxel.
shape = (2, 32, 32, 32, 1)
x = np.random.randn(*shape).astype(np.float32)
y = (x > 0).astype(np.int32).squeeze(-1)  # shape (2, 32, 32, 32)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
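# `experimental_distribute_dataset` splits each global batch across the replicas,
# so every replica receives BATCH_SIZE_PER_REPLICA examples per step.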
with strategy.scope():
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction=tf.keras.losses.Reduction.NONE)
    optimizer = tf.keras.optimizers.Adam(1e-03)
    model = bayesian.variational_meshnet(2, (32, 32, 32, 1), receptive_field=37)


def train_step(inputs):
    with tf.GradientTape() as tape:
        x, y = inputs
        y_ = model(x, training=True)
        loss_value = loss_fn(y_true=y, y_pred=y_)
        loss_value = tf.nn.compute_average_loss(
            loss_value, global_batch_size=GLOBAL_BATCH_SIZE)
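        # NOTE (assumption): if the variational layers register KL divergence terms
        # in `model.losses` (as tfp-based layers typically do), the loss above does
        # not include them. A common addition (hypothetical here) would be:
        #     loss_value += tf.add_n(model.losses) / num_train_examples
        # where `num_train_examples` is the size of the training set.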
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value


@tf.function
def distributed_train_step(inputs):
    per_replica_losses = strategy.experimental_run_v2(train_step, args=(inputs,))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
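# Note: `Strategy.experimental_run_v2` was renamed to `Strategy.run` in later
# TensorFlow releases; on TF >= 2.2, use `strategy.run(train_step, args=(inputs,))`.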
# Train loop.
EPOCHS = 5
for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    for this_x in train_dist_dataset:
        total_loss += distributed_train_step(this_x)
        num_batches += 1
    train_loss = total_loss / num_batches
    print(f"loss={float(train_loss):0.4f} (epoch {epoch + 1}/{EPOCHS})")