# Writing TensorFlow 2 Custom Loops: A step-by-step guide from Keras to TensorFlow 2
# Author: Ygor Rebouças
#
### The Training Loop
#
# 0) Imports
import tensorflow as tf
import numpy as np
# 1) Dataset loading and preparation
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
# 2) Model loading / creation
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Input(shape=X_train.shape[1:]))
for n_filters in [32, 64, 128]:
    model.add(tf.keras.layers.Conv2D(n_filters, (3, 3), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('elu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='elu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
# 3) Compile and fit
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(x=X_train, y=y_train, validation_data=(X_test, y_test), batch_size=128, epochs=10, shuffle=True)
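# (Optional sanity check, not in the original gist: evaluate() reports the test
# loss of the fit() baseline, giving a number to compare the custom loop against.)
model.evaluate(x=X_test, y=y_test, batch_size=128)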
### The Custom Loop
# The train_on_batch function
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
def train_on_batch(X, y):
    with tf.GradientTape() as tape:
        ŷ = model(X, training=True)
        loss_value = loss(y, ŷ)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
train_on_batch(X_train[0:128], y_train[0:128])
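# (Hypothetical variant, not in the original: returning the mean batch loss lets
# the caller log training progress; train_on_batch_logged is an illustrative name.)
def train_on_batch_logged(X, y):
    with tf.GradientTape() as tape:
        ŷ = model(X, training=True)
        loss_value = loss(y, ŷ)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return tf.reduce_mean(loss_value)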
# The validate_on_batch function
def validate_on_batch(X, y):
    ŷ = model(X, training=False)
    loss_value = loss(y, ŷ)
    return loss_value
validate_on_batch(X_test[0:128], y_test[0:128])
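# (Hypothetical extension: a stateful Keras metric can track accuracy alongside
# the loss; accuracy_metric and accuracy_on_batch are not part of the original gist.)
accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
def accuracy_on_batch(X, y):
    ŷ = model(X, training=False)
    accuracy_metric.update_state(y, ŷ)
    return accuracy_metric.result()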
# Putting it all together
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam(0.001)
batch_size = 1024
epochs = 10
for epoch in range(0, epochs):
    for i in range(0, len(X_train) // batch_size):
        X = X_train[i * batch_size:min(len(X_train), (i + 1) * batch_size)]
        y = y_train[i * batch_size:min(len(y_train), (i + 1) * batch_size)]
        train_on_batch(X, y)

    val_loss = []
    for i in range(0, len(X_test) // batch_size):
        X = X_test[i * batch_size:min(len(X_test), (i + 1) * batch_size)]
        y = y_test[i * batch_size:min(len(y_test), (i + 1) * batch_size)]
        val_loss.append(validate_on_batch(X, y))
    print('Validation Loss: ' + str(np.mean(val_loss)))
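# (One gap worth noting: unlike fit(shuffle=True), the slicing loop above visits
# batches in the same order every epoch. A minimal fix, assuming NumPy arrays,
# is to re-permute the training set at the start of each epoch:)
permutation = np.random.permutation(len(X_train))
X_train, y_train = X_train[permutation], y_train[permutation]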
## Improving the Loop
# The Dataset API
train_data = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(buffer_size=len(X_train)).batch(batch_size)
test_data = tf.data.Dataset.from_tensor_slices((X_test, y_test)).shuffle(buffer_size=len(X_test)).batch(batch_size)
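# (Optional tweak, not in the original: prefetching lets tf.data prepare the next
# batch on the CPU while the current one trains on the GPU.)
train_data = train_data.prefetch(tf.data.experimental.AUTOTUNE)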
# Enumerating the Dataset
for epoch in range(0, epochs):
    for batch, (X, y) in enumerate(train_data):
        train_on_batch(X, y)

    val_loss = []
    for batch, (X, y) in enumerate(test_data):
        # Average within each batch first: the final batch may be smaller,
        # so the per-batch loss vectors cannot be stacked directly.
        val_loss.append(np.mean(validate_on_batch(X, y)))
    print('Validation Loss: ' + str(np.mean(val_loss)))
# Model Checkpointing and better prints
best_loss = 99999
for epoch in range(0, epochs):
    for batch, (X, y) in enumerate(train_data):
        train_on_batch(X, y)
        print('\rEpoch [%d/%d] Batch: %d%s' % (epoch + 1, epochs, batch, '.' * (batch % 10)), end='')

    val_loss = np.mean([np.mean(validate_on_batch(X, y)) for (X, y) in test_data])
    print('. Validation Loss: ' + str(val_loss))
    if val_loss < best_loss:
        model.save_weights('model.h5')
        best_loss = val_loss
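# (To resume from the best checkpoint later, load the saved weights back into an
# identically built model:)
model.load_weights('model.h5')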
### The tf.function
@tf.function
def train_on_batch(X, y):
    with tf.GradientTape() as tape:
        ŷ = model(X, training=True)
        loss_value = loss(y, ŷ)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
@tf.function
def validate_on_batch(X, y):
    ŷ = model(X, training=False)
    loss_value = loss(y, ŷ)
    return loss_value
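# (Rough wall-clock check of the graph-mode speedup; timings are illustrative.
# The first call is slower because tf.function traces and compiles the graph.)
import time
start = time.time()
train_on_batch(X_train[0:128], y_train[0:128])  # includes one-time tracing cost
print('First call: %.3fs' % (time.time() - start))
start = time.time()
train_on_batch(X_train[0:128], y_train[0:128])  # reuses the compiled graph
print('Second call: %.3fs' % (time.time() - start))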