@dpoulopoulos
Created January 20, 2022 14:29
```python
import tensorflow as tf
from tensorflow import keras


class MyModel(keras.Model):
    def train_step(self, data):
        # Get the data batch
        inputs, targets = data
        # Get the model's weights
        trainable_vars = self.trainable_variables
        # Forward pass: the tape records operations so gradients can be computed
        with tf.GradientTape() as tape:
            # Get the predictions
            preds = self(inputs, training=True)
            # Compute the loss value (the loss passed to `compile()`)
            loss = self.compiled_loss(targets, preds)
        # Backward pass: gradients of the loss w.r.t. each trainable weight
        grads = tape.gradient(loss, trainable_vars)
        # Update weights with the optimizer passed to `compile()`
        self.optimizer.apply_gradients(zip(grads, trainable_vars))
        # Update metrics (the metrics passed to `compile()`)
        self.compiled_metrics.update_state(targets, preds)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
```
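`fit()` calls `train_step` once per batch. To make that cycle concrete, the same forward pass, backward pass and weight update can be run by hand outside of `fit()`. The sketch below is an illustration, not part of the original gist: it assumes TensorFlow 2.x, a plain `keras.Sequential` model with a single `Dense` unit, mean squared error, SGD, and synthetic linear data (y = 3x + 2 plus noise). The function name `manual_train_step` is hypothetical.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical synthetic data: y = 3x + 2 with a little noise.
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(256, 1)).astype("float32")
y = (3.0 * x + 2.0 + 0.05 * rng.normal(size=(256, 1))).astype("float32")

# A single Dense unit is enough to fit a straight line.
model = keras.Sequential([keras.layers.Dense(1)])
model.build(input_shape=(None, 1))
optimizer = keras.optimizers.SGD(learning_rate=0.1)
loss_fn = keras.losses.MeanSquaredError()


def manual_train_step(inputs, targets):
    """One forward/backward/update cycle, mirroring MyModel.train_step."""
    with tf.GradientTape() as tape:
        preds = model(inputs, training=True)  # forward pass, recorded by the tape
        loss = loss_fn(targets, preds)         # scalar MSE loss for this batch
    grads = tape.gradient(loss, model.trainable_variables)  # backward pass
    optimizer.apply_gradients(zip(grads, model.trainable_variables))  # update
    return float(loss)


# Keep the loss of the last batch of every epoch.
losses = []
for epoch in range(50):
    for start in range(0, len(x), 32):
        batch_loss = manual_train_step(x[start:start + 32], y[start:start + 32])
    losses.append(batch_loss)

kernel, bias = model.layers[0].get_weights()
print(f"loss {losses[0]:.4f} -> {losses[-1]:.4f}, w={kernel[0, 0]:.2f}, b={bias[0]:.2f}")
```

Overriding `train_step` on a `keras.Model` subclass, as in the gist, puts this per-batch logic inside `fit()`, so batching, callbacks and the returned history still come from `fit()` rather than from a hand-written loop.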