Skip to content

Instantly share code, notes, and snippets.

@k-w-w
Last active August 12, 2019 23:17
Show Gist options
  • Save k-w-w/01c58789c2baf2a350c28e5a89566420 to your computer and use it in GitHub Desktop.
Save k-w-w/01c58789c2baf2a350c28e5a89566420 to your computer and use it in GitHub Desktop.
Demonstrating how to save Keras model weights in Estimator, and load them into Keras.
# https://github.com/tensorflow/tensorflow/issues/30233
import tensorflow as tf
def keras_model():
    """Build the single-layer Keras model shared by the Estimator and Keras code.

    Returns:
        A `tf.keras.models.Sequential` holding one `Dense` layer with a single
        unit and `input_shape=[1]`.
    """
    return tf.keras.models.Sequential(
        [tf.keras.layers.Dense(1, input_shape=[1])])
def model_fn(features, labels, mode):
    """Model function that adds 2 to the `bias` variable every training step.

    Args:
        features: Features passed in by the Estimator; not used by this demo.
        labels: Labels passed in by the Estimator; not used by this demo.
        mode: Mode value forwarded unchanged to the returned `EstimatorSpec`.

    Returns:
        A `tf.estimator.EstimatorSpec` with a constant loss, a train op that
        increments both the bias and the global step, and a `Scaffold` whose
        saver is an object-based checkpoint.
    """
    model = keras_model()
    bias = model.layers[-1].bias
    global_step = tf.train.get_global_step()

    # Create a Checkpoint object which saves the checkpoint in a format that is
    # easily reproducible: the script code that runs after training rebuilds a
    # Checkpoint with the same `model` / `global_step` attributes to restore it.
    # Use CheckpointV2 -- in TF v1.14 there's a bug that prevents V1 checkpoint
    # from being used as a saver.
    checkpoint = tf.compat.v2.train.Checkpoint(
        model=model, global_step=global_step)

    train_op = tf.group([
        # Increase variable by 2 every training step.
        tf.assign_add(bias, [2]),
        # The global step is advanced by hand because no optimizer does it.
        tf.assign_add(global_step, 1),
    ])

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=tf.constant(1),
        train_op=train_op,
        # Pass the checkpoint into the `scaffold` argument.
        scaffold=tf.train.Scaffold(saver=checkpoint),
    )
def dummy_input_fn():
    """Return a one-element dataset of `([1], [1])` (features, labels) pairs.

    Returns:
        A `tf.data.Dataset` produced by zipping two single-element datasets,
        each built from the tensor `[1]`.
    """
    features = tf.data.Dataset.from_tensors([1])
    labels = tf.data.Dataset.from_tensors([1])
    return tf.data.Dataset.zip((features, labels))
# Create Estimator and train for one step.
est = tf.estimator.Estimator(model_fn, '/tmp/estimator')
est.train(dummy_input_fn, steps=1)

# Create Keras model, and recreate the checkpoint object used in the model
# function.
model = keras_model()
global_step = tf.train.get_or_create_global_step()
checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)

with tf.Session() as sess:
    # Load weights into the model using `checkpoint.restore`.
    latest_checkpoint = tf.train.latest_checkpoint('/tmp/estimator')
    status = checkpoint.restore(latest_checkpoint)
    status.initialize_or_restore(session=sess)

    # Check the value of the bias variable
    print(sess.run(model.layers[-1].bias))  # Expected value: 2

    sess.run(tf.assign_add(model.layers[-1].bias, [3]))
    print(sess.run(model.layers[-1].bias))  # New value: 5

    # Save out new checkpoint. Estimator uses the 'model.ckpt' filename by
    # default. This call stays inside the `with` block so that `sess` is the
    # default session when the checkpoint is written.
    checkpoint.save('/tmp/estimator/model.ckpt')

# Train estimator another step. The expected value of the bias variable is
# now 7. `dummy_input_fn` is the only input function defined in this file.
est.train(dummy_input_fn, steps=1)

with tf.Session() as sess:
    # Restore the checkpoint written by the second training run and verify.
    latest_checkpoint = tf.train.latest_checkpoint('/tmp/estimator')
    status = checkpoint.restore(latest_checkpoint)
    status.initialize_or_restore(session=sess)
    print(sess.run(model.layers[-1].bias))  # Expected Value: 7
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment