Test of a custom training loop
import numpy as np
import tensorflow as tf
from tensorflow import keras

# A trivial linear model: 5 inputs, 1 output.
model = keras.models.Sequential([keras.layers.Dense(1, input_shape=[5])])
optimizer = keras.optimizers.SGD()
loss_fn = keras.losses.MeanSquaredError()

def step(model, optimizer, X_batch, y_batch):
    # Record the forward pass so the tape can differentiate the loss
    # with respect to the model's trainable weights.
    with tf.GradientTape() as tape:
        y_pred = model(X_batch)
        loss = loss_fn(y_batch, y_pred)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

@tf.function  # compiles the whole loop, dataset iteration included, into one graph
def train(model, optimizer, dataset):
    for X_batch, y_batch in dataset:
        step(model, optimizer, X_batch, y_batch)

X = np.random.rand(1000, 5)
y = np.random.rand(1000, 1)
dataset = tf.data.Dataset.from_tensor_slices((X, y)).batch(32).repeat(5)  # 5 epochs
train(model, optimizer, dataset)
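
The compiled loop above runs silently; to watch the loss go down, here is a minimal sketch of a logging variant (the name train_verbose and the log_every parameter are invented here, not part of the original gist). It uses tf.print, which works inside a tf.function graph:

@tf.function
def train_verbose(model, optimizer, dataset, log_every=10):
    # Dataset.enumerate() yields (index, (X_batch, y_batch)) pairs.
    for i, (X_batch, y_batch) in dataset.enumerate():
        with tf.GradientTape() as tape:
            y_pred = model(X_batch)
            loss = loss_fn(y_batch, y_pred)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        if i % log_every == 0:
            tf.print("batch", i, "loss", loss)  # graph-compatible logging

train_verbose(model, optimizer, dataset)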
ageron commented Mar 25, 2019

Replacing the model with a subclassed model works just as well:

class MyModel(keras.models.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.out = keras.layers.Dense(1)  # layers are created in the constructor

    def call(self, inputs):
        return self.out(inputs)  # call() defines the forward pass

model = MyModel()
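
Since the custom loop only relies on calling the model and on model.trainable_weights, the subclassed version plugs straight into the same train function (the Dense layer builds its weights on the first call, so no input shape is needed):

train(model, optimizer, dataset)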
