Training a Bayesian neural network with MirroredStrategy
# Training a Bayesian neural network with MirroredStrategy.
#
# When using `model.fit`, a Keras symbolic tensor error arises, and the common fix of
# `experimental_run_tf_function=False` does not resolve it. Writing our own training
# loop avoids the error.
# With lots of help from https://www.tensorflow.org/tutorials/distribute/custom_training#training_loop
import numpy as np
import tensorflow as tf

from nobrainer.models import bayesian
# Set up multi-GPU things.
strategy = tf.distribute.MirroredStrategy()

BATCH_SIZE_PER_REPLICA = 1
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
# Create fake data: random volumes with a binary label per voxel.
shape = (2, 32, 32, 32, 1)
x = np.random.randn(*shape).astype(np.float32)
y = (x > 0).astype(np.int32).squeeze(-1)

train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
# Variables (model and optimizer state) must be created inside the strategy's scope.
with strategy.scope():
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction=tf.keras.losses.Reduction.NONE)
    optimizer = tf.keras.optimizers.Adam(1e-03)
    model = bayesian.variational_meshnet(2, (32, 32, 32, 1), receptive_field=37)
def train_step(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
        y_ = model(x, training=True)
        loss_value = loss_fn(y_true=y, y_pred=y_)
        loss_value = tf.nn.compute_average_loss(
            loss_value, global_batch_size=GLOBAL_BATCH_SIZE)
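        # Note (assumption): if the variational layers register KL divergence
        # terms in `model.losses` (as TensorFlow Probability layers typically
        # do), an ELBO-style objective would also add them here, scaled by the
        # number of training examples, e.g. with a hypothetical
        # NUM_TRAIN_EXAMPLES:
        #   loss_value += tf.add_n(model.losses) / NUM_TRAIN_EXAMPLES
        # This gist trains on the cross-entropy alone.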
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value
@tf.function
def distributed_train_step(inputs):
    per_replica_losses = strategy.experimental_run_v2(train_step, args=(inputs,))
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
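# Note: `experimental_run_v2` matches the TF 2.0/2.1 API current when this was
# written; in TF >= 2.2 it was renamed, and the call above becomes
# `strategy.run(train_step, args=(inputs,))`.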
# Train loop.
EPOCHS = 5
for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    for batch in train_dist_dataset:
        total_loss += distributed_train_step(batch)
        num_batches += 1
    # Convert the eager tensor to a Python float so the f-string format works.
    train_loss = float(total_loss) / num_batches
    print(f"loss={train_loss:0.4f} (epoch {epoch + 1}/{EPOCHS})")