Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Save and restore models in TensorFlow 0.12 (graph + variables, without rebuilding the model at load time).
######################################## train and save model in train.py
# Input, output, and hyperparameter placeholders. Giving them explicit names
# (and adding them to a collection below) lets predict.py look them up after
# `import_meta_graph` without rebuilding the model.
x = tf.placeholder(tf.float32, (None, 32, 32, 3), name="x")
# BUG FIX: `(None)` is just `None` (fully unconstrained shape); `(None,)` is
# the intended rank-1 shape with a variable batch dimension.
y = tf.placeholder(tf.int32, (None,), name="y")
keep_prob = tf.placeholder(tf.float32, name="keep_prob")

# Build the model and training op (model/optimizer are defined elsewhere).
yhat, loss = build_whatevermodel(x)
train_op = whateveroptimizer.minimize(loss)

# Train the model, then save both the graph and the variable values.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(epochs):
        # BUG FIX: `keep_prob=0.5` is a syntax error inside a dict literal;
        # feed_dict entries use `key: value` syntax.
        sess.run(train_op,
                 feed_dict={x: train_batch_x, y: train_batch_y, keep_prob: 0.5})
    # ...
    # Register the tensors predict.py needs under a named collection so they
    # can be retrieved by index after restoring the meta-graph.
    tf.add_to_collection("vars", x)
    tf.add_to_collection("vars", yhat)
    tf.add_to_collection("vars", keep_prob)
    # Saver().save writes the checkpoint ("yourmodel.*") plus "yourmodel.meta",
    # which contains the graph definition including the collection above.
    saver = tf.train.Saver()
    saver.save(sess, "./yourmodel")
################################################## restore and use model in predict.py
# When restoring, you don't have to recreate the model (`build_whatevermodel`)
# and you don't have to run `tf.global_variables_initializer()` again:
# `import_meta_graph` rebuilds the graph and `restore` reloads variable values.
# BUG FIX: the session was bound as `ses` but the body used `sess`,
# which would raise NameError at runtime.
with tf.Session() as sess:
    saver = tf.train.import_meta_graph("yourmodel.meta")
    saver.restore(sess, tf.train.latest_checkpoint("./"))
    # Retrieve the placeholders/tensors explicitly, in the same order they
    # were added to the "vars" collection in train.py.
    x = tf.get_collection("vars")[0]
    yhat = tf.get_collection("vars")[1]
    keep_prob = tf.get_collection("vars")[2]
    # keep_prob = 1.0 disables dropout at inference time.
    new_yhat = sess.run(yhat, feed_dict={x: new_x, keep_prob: 1.})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment