Last active
August 9, 2020 13:41
-
-
Save thierryherrmann/f6d4b8b1dc4ec745f2e285f58280d316 to your computer and use it in GitHub Desktop.
custom training loop in tf.function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Fix the global NumPy and TensorFlow random seeds.
np.random.seed(2)
tf.random.set_seed(5)
def make_model():
    """Build the demo Keras model with the functional API.

    The network takes 8 input features, applies a 30-unit ReLU Dense layer
    and finishes with a custom layer wrapping a single-unit Dense layer whose
    kernel carries an L2 regularizer. The custom layer is there for demo
    purposes only; a model of any complexity could be returned here instead.

    Returns:
        The assembled ``keras.Model``.
    """
    from tensorflow.keras import layers

    class CustomLayer(keras.layers.Layer):
        """Layer delegating to one L2-regularized, single-unit Dense layer."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = layers.Dense(
                1,
                kernel_regularizer=keras.regularizers.l2(0.1),
                name='my_layer_dense',
            )

        def call(self, data):
            return self.dense(data)

    # Layers are created and applied in the same order as the data flows.
    model_input = keras.Input(shape=(8,))
    hidden_layer = layers.Dense(30, activation="relu", name='my_dense')
    hidden = hidden_layer(model_input)
    head = CustomLayer()
    model_output = head(hidden)
    return keras.Model(inputs=model_input, outputs=model_output)
# Prepare the training dataset. | |
def get_housing_dataset():
    """Load the California housing data as scaled float32 splits.

    The data is split twice with ``train_test_split`` defaults: first into
    (train + validation, test), then the first part into (train, validation).
    Features are standardized with a ``StandardScaler`` fitted on the training
    split only, and every returned array is cast to ``np.float32``.

    Returns:
        Tuple ``(X_train, X_valid, X_test, y_train, y_valid, y_test)``.
    """
    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    dataset = fetch_california_housing()
    X_rest, X_test, y_rest, y_test = train_test_split(
        dataset.data, dataset.target)
    X_train, X_valid, y_train, y_valid = train_test_split(X_rest, y_rest)

    standardizer = StandardScaler()
    # Tuple elements are evaluated left to right, so the scaler is fitted on
    # the training features before it transforms the other two splits.
    features = (
        standardizer.fit_transform(X_train),
        standardizer.transform(X_valid),
        standardizer.transform(X_test),
    )
    targets = (y_train, y_valid, y_test)
    return tuple(arr.astype(np.float32) for arr in features + targets)
# Load the splits; the test features and targets are discarded here.
X_train, X_valid, _, y_train, y_valid, _ = get_housing_dataset()

batch_size = 64

# Training batches are drawn through a 1024-element shuffle buffer.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

# Validation data is batched in its original order (no shuffling).
valid_dataset = tf.data.Dataset.from_tensor_slices((X_valid, y_valid))
valid_dataset = valid_dataset.batch(batch_size)
class CustomModule(tf.Module):
    """tf.Module bundling the Keras model and its optimizer for a custom training loop.

    Attributes are created in ``__init__``:
      - ``model``: the Keras model returned by ``make_model()``.
      - ``opt``: an Adam optimizer with learning rate 0.001.
    """

    def __init__(self):
        super().__init__()
        self.model = make_model()
        self.opt = keras.optimizers.Adam(learning_rate=0.001)

    # @tf.function makes the method run in graph mode (faster); the optional
    # input_signature pins the accepted shapes and dtypes.
    # To debug we can
    # - either use tf.print() statements, which execute in graph mode,
    # - or run in eager mode by removing the @tf.function annotation or by
    #   calling tf.config.run_functions_eagerly(True). In eager mode print()
    #   or any Python statement can be used (instead of tf.print()) and
    #   debugger breakpoints work.
    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
    def __call__(self, X):
        """Return the model output for a float32 batch ``X`` of shape [batch, 8]."""
        return self.model(X)

    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32),
                                  tf.TensorSpec([None], tf.float32)])
    def my_train(self, X, y):
        """Run one training step on a single batch.

        Computes the loss (mean squared error plus the model's regularization
        losses), then applies the loss gradient to update the model weights.

        Args:
            X: float32 features of shape [batch, 8].
            y: float32 targets of shape [batch].

        Returns:
            The scalar total loss for this batch.
        """
        with tf.GradientTape() as tape:
            predictions = self.model(X, training=True)
            # The model ends in Dense(1), so predictions are [batch, 1] while
            # y is [batch]. Depending on the Keras version, the functional
            # mean_squared_error does not reconcile ranks, in which case the
            # difference broadcasts to [batch, batch] and every prediction is
            # compared against every target. Align the shapes explicitly.
            targets = tf.expand_dims(y, axis=-1)
            main_loss = tf.reduce_mean(
                keras.losses.mean_squared_error(targets, predictions))
            # self.model.losses contains the regularization loss (the L2
            # kernel regularizer set up in make_model).
            loss_value = tf.add_n([main_loss] + self.model.losses)
        grads = tape.gradient(loss_value, self.model.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.model.trainable_weights))
        return loss_value
# Pass True instead to force eager execution despite the @tf.function
# decorators (useful for debugging).
tf.config.run_functions_eagerly(False)

# Instantiate the module.
module = CustomModule()

# Demo a call to the module (invokes its __call__() method) on the first
# training row.
print('sample prediction: ', module(X_train[:1]).numpy())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment