Skip to content

Instantly share code, notes, and snippets.

@thierryherrmann
Created November 2, 2019 02:28
Show Gist options
  • Save thierryherrmann/aa176a8107e2cfafe95963396cf75ff9 to your computer and use it in GitHub Desktop.
Reproduce TF issue 33150
# https://github.com/tensorflow/tensorflow/issues/33150
import tensorflow as tf
class Net(tf.keras.Model):
    """Minimal Keras model used to reproduce TF issue 33150.

    A single Dense layer with 5 output units; the input dimension is
    inferred from the first call.
    """

    def __init__(self):
        super(Net, self).__init__()
        # One trainable layer is enough for the checkpoint repro.
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        """Forward pass: apply the dense layer to ``x``."""
        return self.l1(x)
# ----- create model, optimizer, and checkpoint machinery -----
net = Net()
checkpoint_dir = 'ckpts'
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(opt=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)

# ----- train with one example -----
example_x = tf.constant([[1.]])
example_y = tf.constant([[1., 2., 3., 4., 5.]])
with tf.GradientTape() as tape:
    # Both the forward pass and the loss must be computed under the tape,
    # otherwise tape.gradient() below would return None gradients.
    output = net(example_x)
    loss = tf.reduce_mean(tf.abs(output - example_y))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
# Applying gradients creates the Adam slot/iteration variables, so they
# are included in the checkpoint written below.
opt.apply_gradients(zip(gradients, variables))

save_path = manager.save()
print("Saved checkpoint: {}".format(save_path))
# ========== restart from scratch but restore from checkpoint
# Rebuild fresh (untrained) model/optimizer objects, then restore the
# checkpoint saved above into them.
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(opt=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)
print('restoring...')
# NOTE(review): restore() is lazy — values are matched to variables as the
# variables get created. At this point neither the model nor the Adam
# optimizer has created its variables yet (no forward/train step has run),
# which presumably is why assert_consumed() below still sees unresolved
# checkpoint entries — confirm against the tf.train.Checkpoint docs.
status = ckpt.restore(manager.latest_checkpoint)
# assert_consumed() fails with:
# AssertionError: Unresolved object in checkpoint (root).opt.iter: attributes {
# name: "VARIABLE_VALUE"
# full_name: "Adam/iter"
# checkpoint_key: "opt/iter/.ATTRIBUTES/VARIABLE_VALUE"
# This failing assertion IS the repro for TF issue 33150 — do not "fix" it.
status.assert_consumed()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment