Skip to content

Instantly share code, notes, and snippets.

View thierryherrmann's full-sized avatar

Thierry Herrmann thierryherrmann

  • Montreal, Canada
View GitHub Profile
thierryherrmann /
Created November 2, 2019 02:28
Reproduce TF issue 33150
import tensorflow as tf
class Net(tf.keras.Model):
    """Minimal Keras model consisting of a single 5-unit dense layer."""

    def __init__(self):
        super().__init__()
        # The only layer of the model: fully connected, 5 output units.
        self.l1 = tf.keras.layers.Dense(units=5)

    def call(self, x):
        """Run ``x`` through the dense layer and return the result."""
        dense_output = self.l1(x)
        return dense_output
thierryherrmann /
Created November 13, 2019 04:03
# This was tested with tf-nightly 2.1.0.dev20191111 (Linux Ubuntu 18.04)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
import os
import tensorflow as tf
from tensorflow import keras
thierryherrmann /
Created August 2, 2020 22:38
Typical custom training loop
for epoch in range(epochs):
print("\nStart of epoch %d" % (epoch,))
start_time = time.time()
# Iterate over the batches of the dataset.
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
with tf.GradientTape() as tape:
logits = model(x_batch_train, training=True)
loss_value = loss_fn(y_batch_train, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
thierryherrmann /
Last active August 9, 2020 13:41
custom training loop in tf.function
np.random.seed(2); tf.random.set_seed(5)
def make_model():
# this constructs a keras Model. We use the functional API and add a custom
# layer for demo purposes but a model of any complexity can be used here
from tensorflow.keras import layers
class CustomLayer(keras.layers.Layer):
def __init__(self, **kwargs):
def train_module(module, train_dataset, valid_dataset):
valid_metric = keras.metrics.MeanSquaredError()
loss_hist = []
for epoch in range(3):
for X, y in train_dataset:
loss = module.my_train(X, y)
if step % 100 == 0:
<tf.Variable 'Adam/dense_2/bias/m:0' shape=(30,) dtype=float32, numpy=
array([ 1.3742445e-04, 3.0024436e-05, 7.1526818e-05, -1.0563848e-03,
-2.0427089e-03, 7.6999364e-05, -3.1418181e-03, -2.4974323e-03,
3.2060378e-04, -3.7756050e-04, 1.7517927e-04, -1.3496901e-03,
3.1575797e-05, -1.4640440e-03, 1.7805261e-04, -7.5319828e-04,
2.4552579e-04, -3.8849441e-03, -1.3961941e-03, 1.4816693e-05,
-4.0749349e-03, -8.9195929e-04, 1.1976792e-04, -5.5552716e-04,
2.1161152e-04, 1.3880052e-04, -1.4332745e-03, 1.2115676e-04,
loss_hist = train_module(module, train_dataset, valid_dataset)
<tf.Variable 'Adam/dense_2/bias/m:0' shape=(30,) dtype=float32, numpy=
array([ 3.51306298e-05, 3.61366037e-05, -3.67252505e-06, 9.21028666e-04,
7.78463436e-04, 2.24373052e-05, 6.05550595e-04, 7.36912712e-04,
-4.31884764e-05, 1.44443940e-04, 1.24389135e-05, 8.46692594e-04,
1.70874955e-05, 3.72679904e-04, 5.41794288e-05, 6.08396949e-04,
1.95211032e-06, 8.75406899e-04, 9.23899701e-04, 2.17679326e-06,
8.70055985e-04, 6.87883934e-04, 5.30559737e-06, 5.81342028e-04,
2.78645912e-05, 4.61369600e-05, 7.27826264e-04, 1.64074972e-05,
def save_module(module, model_dir):
# When saving a tf.keras.Model with either model.save(),
# tf.keras.models.save_model() or tf.saved_model.save(),
# the saved model contains a `serving_default` signature used to get the
# output of the model from an input sample. Here, however, we don't save a
# keras Model but a tf.Module. This requires specifying the signatures manually.
# Note that we also export the training function here.
INFO:tensorflow:Assets written to: saved_model/assets
non-tensor: _CHECKPOINTABLE_OBJECT_GRAPH <class 'bytes'>
tensor : model/layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE (30,)
tensor : model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE (30,)
tensor : model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/v/.ATTRIBUTES/VARIABLE_VALUE (30,)
tensor : model/layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE (8, 30)
tensor : model/layer_with_weights-0/kernel/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE (8, 30)
tensor : model/layer_with_weights-0/kernel/.OPTIMIZER_SLOT/opt/v/.ATTRIBUTES/VARIABLE_VALUE (8, 30)
tensor : model/variables/2/.ATTRIBUTES/VARIABLE_VALUE (30, 1)
tensor : model/variables/2/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE (30, 1)