Last active
August 12, 2019 23:17
-
-
Save k-w-w/01c58789c2baf2a350c28e5a89566420 to your computer and use it in GitHub Desktop.
Demonstrates how to save Keras model weights from within an Estimator, and then load them back into a Keras model.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/tensorflow/tensorflow/issues/30233 | |
import tensorflow as tf | |
def keras_model():
    """Build a minimal Keras model: one Dense(1) layer over a scalar input."""
    output_layer = tf.keras.layers.Dense(1, input_shape=[1])
    return tf.keras.models.Sequential([output_layer])
def model_fn(features, labels, mode):
    """Model function that adds 2 to the `bias` variable every training step."""
    model = keras_model()
    bias = model.layers[-1].bias
    global_step = tf.train.get_global_step()

    # Create a Checkpoint object which saves the checkpoint in a format that
    # can easily be reloaded outside the Estimator.
    # Use CheckpointV2 -- in TF v1.14 there's a bug that prevents a V1
    # checkpoint from being used as a saver.
    checkpoint = tf.compat.v2.train.Checkpoint(
        model=model, global_step=global_step)

    # Increase the bias by 2 and advance the global step on every train step.
    train_op = tf.group([
        tf.assign_add(bias, [2]),
        tf.assign_add(global_step, 1),
    ])

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=tf.constant(1),
        train_op=train_op,
        # Pass the checkpoint into the `scaffold` argument so the Estimator
        # writes object-based (Checkpoint-format) checkpoints.
        scaffold=tf.train.Scaffold(saver=checkpoint),
    )
def dummy_input_fn():
    """Return a single-element dataset of (features, labels) == ([1], [1])."""
    features = tf.data.Dataset.from_tensors([1])
    labels = tf.data.Dataset.from_tensors([1])
    return tf.data.Dataset.zip((features, labels))
# Create Estimator and train for one step.
est = tf.estimator.Estimator(model_fn, '/tmp/estimator')
est.train(dummy_input_fn, steps=1)

# Create Keras model, and recreate the checkpoint object used in the model
# function.
model = keras_model()
global_step = tf.train.get_or_create_global_step()
checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)

with tf.Session() as sess:
    # Load weights into the model using `checkpoint.restore`.
    latest_checkpoint = tf.train.latest_checkpoint('/tmp/estimator')
    status = checkpoint.restore(latest_checkpoint)
    status.initialize_or_restore(session=sess)
    # Check the value of the bias variable.
    print(sess.run(model.layers[-1].bias))  # Expected value: 2
    sess.run(tf.assign_add(model.layers[-1].bias, [3]))
    print(sess.run(model.layers[-1].bias))  # New value: 5
    # Save out new checkpoint. Estimator uses the 'model.ckpt' filename by
    # default, so write under the same prefix for the Estimator to pick it up.
    checkpoint.save('/tmp/estimator/model.ckpt')

# Train estimator another step. The expected value of the bias variable is
# now 7 (5 restored from the Keras-saved checkpoint, +2 from the train step).
# BUG FIX: the original called `est.train(input_fn, steps=1)`, but `input_fn`
# is undefined in this script -- the intended function is `dummy_input_fn`.
est.train(dummy_input_fn, steps=1)

with tf.Session() as sess:
    latest_checkpoint = tf.train.latest_checkpoint('/tmp/estimator')
    status = checkpoint.restore(latest_checkpoint)
    status.initialize_or_restore(session=sess)
    print(sess.run(model.layers[-1].bias))  # Expected value: 7
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment