@MohanaRC
Created August 10, 2023 15:50
import tensorflow as tf


def apply_gradient(optimizer, model, x, y):
    """
    Compute the gradient of the loss for one batch and update the model weights.

    Assumes `loss_object` is defined at module level (for example, a tf.keras.losses instance).
    """
    with tf.GradientTape() as tape:
        # Get the model prediction and compute the loss inside the tape so both are recorded
        logits = model(x)
        loss_value = loss_object(y_true=y, y_pred=logits)
    # Calculate the gradients with tape.gradient, then update the model weights using the optimizer
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return logits, loss_value
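
To show what the three stages of `apply_gradient` do (forward pass and loss, gradient, weight update), here is a minimal sketch of the same pattern written by hand in plain Python, without TensorFlow. It is an illustration under stated assumptions, not the gist author's code: the model is a hypothetical one-feature linear model `w * x + b`, the loss is mean squared error, the gradients are derived analytically rather than recorded by a tape, and the update is plain SGD. The name `sgd_step` and the toy data are made up for this example.

```python
def sgd_step(w, b, xs, ys, lr=0.1):
    """One manual gradient step for y_hat = w * x + b under mean squared error."""
    n = len(xs)
    # Forward pass: predictions and loss (the part recorded inside the tape).
    preds = [w * x + b for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # Backward pass: analytic gradients of the loss with respect to w and b
    # (the role played by tape.gradient).
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
    # Update: plain SGD, assumed here in place of a configurable optimizer.
    return w - lr * grad_w, b - lr * grad_b, preds, loss


# Toy data generated from y = 2x + 1 (hypothetical, for illustration only).
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]

w, b = 0.0, 0.0
losses = []
for step in range(200):
    w, b, preds, loss = sgd_step(w, b, xs, ys)
    losses.append(loss)
```

Like `apply_gradient`, each call returns the predictions and the loss computed before the update, so a training loop can log them while the weights move toward the values that generated the data.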