custom training loop in tf.function
import numpy as np
import tensorflow as tf
from tensorflow import keras

np.random.seed(2); tf.random.set_seed(5)
def make_model():
    # This constructs a Keras Model. We use the functional API and add a custom
    # layer for demo purposes, but a model of any complexity can be used here.
    from tensorflow.keras import layers

    class CustomLayer(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            l2_reg = keras.regularizers.l2(0.1)
            self.dense = layers.Dense(1, kernel_regularizer=l2_reg,
                                      name='my_layer_dense')

        def call(self, data):
            return self.dense(data)

    inputs = keras.Input(shape=(8,))
    x1 = layers.Dense(30, activation="relu", name='my_dense')(inputs)
    outputs = CustomLayer()(x1)
    return keras.Model(inputs=inputs, outputs=outputs)
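
# Not part of the original gist: an optional, commented-out sanity check that the model
# builds and maps 8 input features to a single output per sample. It is left commented out
# so it does not consume random state and alter the seeded results above.
# sanity_model = make_model()                      # illustrative name
# print(sanity_model(tf.zeros([2, 8])).shape)      # expected: (2, 1)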
# Prepare the training dataset.
def get_housing_dataset():
    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    housing = fetch_california_housing()
    X_train_full, X_test, y_train_full, y_test = train_test_split(
        housing.data, housing.target)
    X_train, X_valid, y_train, y_valid = train_test_split(
        X_train_full, y_train_full)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train).astype(np.float32)
    X_valid = scaler.transform(X_valid).astype(np.float32)
    X_test = scaler.transform(X_test).astype(np.float32)
    return X_train, X_valid, X_test, y_train.astype(np.float32), \
           y_valid.astype(np.float32), y_test.astype(np.float32)
X_train, X_valid, _, y_train, y_valid, _ = get_housing_dataset()
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
valid_dataset = tf.data.Dataset.from_tensor_slices((X_valid, y_valid)).batch(batch_size)
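
# Not part of the original gist: element_spec shows the structure of each batch without
# iterating the dataset. Here it is a (features, target) pair of TensorSpecs of shape
# (None, 8) and (None,), both float32, matching the input_signatures declared below.
print('dataset element spec:', train_dataset.element_spec)
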
class CustomModule(tf.Module):
    def __init__(self):
        super(CustomModule, self).__init__()
        self.model = make_model()
        self.opt = keras.optimizers.Adam(learning_rate=0.001)

    # Adding @tf.function here makes the call faster (it runs in graph mode) and enforces
    # the declared shapes and types (optional).
    # To debug we can
    # - either use tf.print() statements, which execute in graph mode,
    # - or run in eager mode by removing the @tf.function annotation or by calling
    #   tf.config.run_functions_eagerly(True). In eager mode, print() or any Python
    #   statement can be used (instead of tf.print()) and debugger breakpoints work.
    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
    def __call__(self, X):
        return self.model(X)
    # my_train processes one batch (one step): it computes the loss and applies the
    # loss gradients to update the model weights.
    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32),
                                  tf.TensorSpec([None], tf.float32)])
    def my_train(self, X, y):
        with tf.GradientTape() as tape:
            # squeeze the (batch, 1) model output to (batch,) so the MSE is computed
            # element-wise against y rather than broadcasting to (batch, batch)
            preds = tf.squeeze(self.model(X, training=True), axis=-1)
            main_loss = tf.reduce_mean(keras.losses.mean_squared_error(y, preds))
            # self.model.losses contains the regularization loss (see l2_reg above)
            loss_value = tf.add_n([main_loss] + self.model.losses)
        grads = tape.gradient(loss_value, self.model.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.model.trainable_weights))
        return loss_value
# set to True to force eager execution despite the @tf.function annotations (for debugging)
tf.config.run_functions_eagerly(False)
# instantiate the module
module = CustomModule()
# demo a call to the module (this invokes the __call__() method)
print('sample prediction: ', module(X_train[0:1]).numpy())
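
# The gist's title refers to a custom training loop, but the loop itself is not shown above.
# A minimal sketch of how my_train() could drive training over the batched datasets follows;
# n_epochs and the per-epoch reporting are illustrative choices, not part of the original gist.
n_epochs = 5
for epoch in range(n_epochs):
    for X_batch, y_batch in train_dataset:
        loss_value = module.my_train(X_batch, y_batch)
    # measure validation MSE with the module's __call__ (graph-mode prediction)
    valid_mse = keras.metrics.MeanSquaredError()
    for X_batch, y_batch in valid_dataset:
        valid_mse.update_state(y_batch, tf.squeeze(module(X_batch), axis=-1))
    tf.print('epoch', epoch, 'last batch loss', loss_value, 'valid mse', valid_mse.result())

# Because __call__ and my_train are tf.functions with explicit input_signatures, the module
# (model weights, optimizer state and the training step) can be exported with the standard
# SavedModel API; the target directory below is only an example.
tf.saved_model.save(module, 'saved_module')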