# Writing TensorFlow 2 Custom Loops: A step-by-step guide from Keras to TensorFlow 2
# Author: Ygor Rebouças
#
### The Training Loop
#
# 0) Imports
import tensorflow as tf
import numpy as np

# 1) Dataset loading and preparation
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255  # scale pixel values to [0, 1]
X_test = X_test.astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train, 10)  # one-hot encode the 10 classes
y_test = tf.keras.utils.to_categorical(y_test, 10)
# 2) Model loading / creation
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Input(shape=X_train.shape[1:]))
for n_filters in [32, 64, 128]:
    # three conv blocks: convolution -> batch norm -> ELU -> 2x2 max-pooling
    model.add(tf.keras.layers.Conv2D(n_filters, (3, 3), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('elu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='elu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
# 3) Compile and fit
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(x=X_train, y=y_train, validation_data=(X_test, y_test), batch_size=128, epochs=10, shuffle=True)
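
# A quick sanity check of the fitted baseline (our addition, not in the original
# gist; model.evaluate is standard Keras API): returns the mean test loss.
test_loss = model.evaluate(x=X_test, y=y_test, batch_size=128, verbose=0)
print('Baseline test loss: %.4f' % test_loss)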
### The Custom Loop
# The train_on_batch function
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()

def train_on_batch(X, y):
    with tf.GradientTape() as tape:
        ŷ = model(X, training=True)   # forward pass, recorded by the tape
        loss_value = loss(y, ŷ)       # per-sample cross-entropy
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

train_on_batch(X_train[0:128], y_train[0:128])
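
# A minimal GradientTape illustration (our addition, for intuition): the tape
# records operations on watched tensors, and gradient() differentiates through
# them, which is exactly what train_on_batch relies on.
w = tf.Variable(3.0)
with tf.GradientTape() as tape:
    f = w * w
print(tape.gradient(f, w))  # tf.Tensor(6.0, ...), since d(w^2)/dw = 2w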
# The validate_on_batch function
def validate_on_batch(X, y):
    ŷ = model(X, training=False)  # training=False: batch norm uses its moving statistics
    loss_value = loss(y, ŷ)
    return loss_value

validate_on_batch(X_test[0:128], y_test[0:128])
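
# Note that loss(y, ŷ) is computed per sample, so the call above returns a
# tensor of shape (128,). A minimal sketch of reducing it to one scalar per batch:
batch_loss = np.mean(validate_on_batch(X_test[0:128], y_test[0:128]))
print('Batch validation loss: %.4f' % batch_loss)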
# Putting it all together
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam(0.001)

batch_size = 1024
epochs = 10

for epoch in range(0, epochs):
    # integer division drops the final partial batch; the Dataset API below handles it
    for i in range(0, len(X_train) // batch_size):
        X = X_train[i * batch_size:min(len(X_train), (i + 1) * batch_size)]
        y = y_train[i * batch_size:min(len(y_train), (i + 1) * batch_size)]
        train_on_batch(X, y)

    val_loss = []
    for i in range(0, len(X_test) // batch_size):
        X = X_test[i * batch_size:min(len(X_test), (i + 1) * batch_size)]
        y = y_test[i * batch_size:min(len(y_test), (i + 1) * batch_size)]
        val_loss.append(validate_on_batch(X, y))
    print('Validation Loss: ' + str(np.mean(val_loss)))
### Improving the Loop
# The Dataset API
train_data = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(buffer_size=len(X_train)).batch(batch_size)
test_data = tf.data.Dataset.from_tensor_slices((X_test, y_test)).shuffle(buffer_size=len(X_test)).batch(batch_size)
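
# An optional pipeline tweak (our addition, not in the original gist): prefetch
# lets tf.data prepare the next batch while the current one is training.
# tf.data.experimental.AUTOTUNE is the constant used in TF 2.x releases of this era.
train_data = train_data.prefetch(tf.data.experimental.AUTOTUNE)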
# Enumerating the Dataset
for epoch in range(0, epochs):
    for batch, (X, y) in enumerate(train_data):
        train_on_batch(X, y)

    val_loss = []
    for batch, (X, y) in enumerate(test_data):
        # the final batch is smaller, so reduce each batch to a scalar before averaging
        val_loss.append(np.mean(validate_on_batch(X, y)))
    print('Validation Loss: ' + str(np.mean(val_loss)))
# Model checkpointing and better prints
best_loss = float('inf')
for epoch in range(0, epochs):
    for batch, (X, y) in enumerate(train_data):
        train_on_batch(X, y)
        print('\rEpoch [%d/%d] Batch: %d%s' % (epoch + 1, epochs, batch, '.' * (batch % 10)), end='')

    val_loss = np.mean([np.mean(validate_on_batch(X, y)) for (X, y) in test_data])
    print('. Validation Loss: ' + str(val_loss))

    # keep only the weights that achieve the best validation loss so far
    if val_loss < best_loss:
        model.save_weights('model.h5')
        best_loss = val_loss
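
# To resume from the best checkpoint later, the weights can be loaded back into
# an identically-built model; load_weights is the counterpart of save_weights.
model.load_weights('model.h5')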
### The tf.function
# Decorating the two step functions compiles them into TensorFlow graphs on the
# first call, removing Python overhead from every call after that.
@tf.function
def train_on_batch(X, y):
    with tf.GradientTape() as tape:
        ŷ = model(X, training=True)
        loss_value = loss(y, ŷ)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

@tf.function
def validate_on_batch(X, y):
    ŷ = model(X, training=False)
    loss_value = loss(y, ŷ)
    return loss_value
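
# A rough way to see the speedup (our addition; timings are machine-dependent):
# the first call traces and builds the graph, so trigger it before timing.
import time
train_on_batch(X_train[0:1024], y_train[0:1024])  # one-off tracing cost
start = time.perf_counter()
for _ in range(10):
    train_on_batch(X_train[0:1024], y_train[0:1024])
print('10 traced steps: %.3fs' % (time.perf_counter() - start))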