
@thierryherrmann
Created August 2, 2020 22:38
Typical custom training loop
import time

import tensorflow as tf

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Record the forward pass on the tape so gradients can be computed.
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        # Backpropagate the loss and update the model's weights.
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
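The loop assumes that `epochs`, `model`, `loss_fn`, `optimizer`, and `train_dataset` are already defined. A minimal setup that makes the loop runnable might look like the following sketch; the toy regression data and the one-layer model are hypothetical stand-ins, not part of the original gist:

import tensorflow as tf

# Hypothetical toy regression data: y = 3x plus a little noise.
x = tf.random.normal((256, 1))
y = 3.0 * x + tf.random.normal((256, 1), stddev=0.1)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)

# A single dense layer is enough to fit this linear relationship.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
epochs = 2

With these pieces in place, the custom loop above trains the model exactly as `model.fit` would, while leaving every step (forward pass, loss, gradient update) open to customization.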