
@nikitakit
Last active December 15, 2018 21:13
Restoring TensorFlow Models
"""
By default, TensorFlow's GraphDef only saves the graph architecture
(not the parameter values), while the Saver class only writes parameter
values to each checkpoint.

This module combines the data from a GraphDef and a checkpoint file
to restore a functioning model.

Sample usage:
```
import tensorflow as tf
from tf_restore_graph import restore_graph
from tensorflow.python.summary.event_accumulator import EventAccumulator

sess = tf.InteractiveSession()
events = EventAccumulator('path-to-tfevents-file')
events.Reload()
(x, y), saver = restore_graph(
    events.Graph(),
    tf.train.get_checkpoint_state('checkpoint').model_checkpoint_path,
    # Tensor names ('op:output_index') make import_graph_def return Tensors;
    # bare op names ('x', 'y') would return Operations, which cannot be fed.
    return_elements=['x:0', 'y:0']
)
print(sess.run(y, feed_dict={x: 1.0}))
```
"""
import tensorflow as tf
from tensorflow.python import ops
import random


class RestoredVariable(tf.Variable):
    """
    A variable restored from disk
    """
    def __init__(self, name, trainable=True, collections=None, graph=None):
        # Note: tf.Variable.__init__ is deliberately not called; instead of
        # creating new ops, this wraps the ops already present in the graph.
        if graph is None:
            graph = tf.get_default_graph()
        if collections is None:
            collections = [ops.GraphKeys.VARIABLES]
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            # pylint: disable=g-no-augmented-assignment
            #
            # Pylint wants us to write collections += [...TRAINABLE_VARIABLES] which
            # is not the same (it modifies the list in place.) Here, we only want to
            # modify the value of the variable, not the list.
            collections = collections + [ops.GraphKeys.TRAINABLE_VARIABLES]
            # pylint: enable=g-no-augmented-assignment
        # Look up the variable op, its '/read' snapshot and its '/Assign' initializer
        self._variable = graph.as_graph_element(name).outputs[0]
        self._snapshot = graph.as_graph_element(name + '/read').outputs[0]
        self._initializer_op = graph.as_graph_element(name + '/Assign')
        # Expect exactly one node directly under '<name>/Initializer/'
        i_name = name + '/Initializer/'
        keys = [k for k in graph._nodes_by_name.keys()
                if k.startswith(i_name) and '/' not in k[len(i_name):]]
        if len(keys) != 1:
            raise ValueError('Could not find initializer for variable', keys)
        self._initial_value = None  # initial_value node
        for key in collections:
            graph.add_to_collection(key, self)
        self._save_slice_info = None
def restore_graph(graph_def, save_path=None,
                  saver_def=None,
                  input_map=None, return_elements=None, op_dict=None,
                  trainable=True, collections=None,
                  ):
    """
    Restore a graph from a GraphDef.

    Args:
      graph_def: a GraphDef instance, representing the model architecture
      save_path: path where parameter values were saved
      saver_def: SaverDef for restoring the saver
      input_map, return_elements, op_dict: passed to tf.import_graph_def
      trainable: whether the restored variables should be marked as trainable
      collections: which collections to add the restored variables to

    Returns: (graph_elements, saver)
      graph_elements: The return value of tf.import_graph_def
      saver: The saver can be used to load further checkpoints
    """
    res = tf.import_graph_def(graph_def, name='', input_map=input_map,
                              return_elements=return_elements, op_dict=op_dict)
    # Wrap every variable op of the imported graph so that a Saver can manage it
    restored_vars = []
    for node in graph_def.node:
        if node.op == 'Variable':
            restored_vars.append(RestoredVariable(node.name, trainable=trainable,
                                                  collections=collections))
    if saver_def is not None:
        # Saver's first positional argument is var_list, so saver_def must be
        # passed by keyword. The Saver then reuses the save/restore ops it names.
        saver = tf.train.Saver(saver_def=saver_def)
    else:
        # Saver names must be unique, but without the saver_def we can't reuse the
        # old saver's ops. So we generate a random name, and hope that the variable
        # ordering and packing are deterministic and unchanged since the checkpoint
        # was saved.
        saver = tf.train.Saver(var_list=restored_vars,
                               name='restored-' + ('%016x' % random.randrange(16**16)))
    if save_path is not None:
        saver.restore(tf.get_default_session(), save_path)
    return res, saver
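The initializer lookup in `RestoredVariable.__init__` is the step that raises `ValueError('Could not find initializer for variable', ...)`. The stdlib-only sketch below pulls that filter out into a hypothetical helper, `find_initializer_children`, and runs it on made-up node names (they are illustrative, not taken from a real graph). Only names sitting directly under `<name>/Initializer/` count, so deeper nodes are skipped, and a variable with no such node yields an empty list, which is exactly the case that triggers the error.

```python
def find_initializer_children(node_names, var_name):
    """Hypothetical helper mirroring the list comprehension in RestoredVariable.

    Returns the names that start with '<var_name>/Initializer/' and have no
    further '/' after that prefix, i.e. the direct children of the prefix.
    """
    prefix = var_name + '/Initializer/'
    return [n for n in node_names
            if n.startswith(prefix) and '/' not in n[len(prefix):]]


# Made-up node names for a variable 'w' that was created with an initializer.
nodes = [
    'w', 'w/read', 'w/Assign',
    'w/Initializer/random_uniform',        # direct child: kept
    'w/Initializer/random_uniform/shape',  # nested deeper: skipped
    'w/Initializer/random_uniform/min',    # nested deeper: skipped
]
print(find_initializer_children(nodes, 'w'))  # ['w/Initializer/random_uniform']

# No '<name>/Initializer/' node at all: empty list, so len(keys) != 1 and
# RestoredVariable would raise ValueError.
print(find_initializer_children(['v1', 'v1/read', 'v1/Assign'], 'v1'))  # []
```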
@Vikramank

Hi, thanks for the code. I am trying to save and restore variables using a simple example. I created two Python scripts, save_model.py and restore_model.py, in the same folder to save and restore the model, respectively. I can save the model but I'm unable to restore it. Could you please help me with it? Thanks.
save_model.py

import tensorflow as tf
v1 = tf.Variable(1.32, name="v1")
v2 = tf.Variable(1.33, name="v2")

init = tf.initialize_all_variables()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print(v2.eval(sess))
    save_path = "model.ckpt"
    saver.save(sess, save_path)
    saver.restore(sess, save_path)
    print("Model restored.")

restore.py

import tensorflow as tf
v1 = tf.Variable(0, name="v1")
v2 = tf.Variable(0, name="v2")
saver = tf.train.Saver()
init = tf.initialize_all_variables()
with tf.Session() as sess:
  save_path="model.ckpt"
  saver.restore(sess, save_path)
  print("Model restored.")

@nikitakit
Author

@Vikramank
(I only saw this now because, for some reason, I'm not subscribed to comments on my own gist.)
Your problem is that restore.py initializes the variables to the integer 0, not the float 0.0. That gives them an integer dtype (int32), so they can't hold the float32 values stored in the checkpoint, and loading the checkpoint fails.
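A minimal sketch of the fix, assuming a TensorFlow version that provides the `tf.compat.v1` module (the scripts above use the TF 0.x API, where the same calls live directly under `tf`). As assumptions of this sketch, it saves into a temporary directory instead of `model.ckpt` in the working directory, and uses `tf.global_variables_initializer()` in place of the deprecated `tf.initialize_all_variables()`. The substantive change is `0.0` instead of `0` in the restore step.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TF 1.x-style graph mode through tf.compat.v1.
tf.disable_eager_execution()

save_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# save_model.py: Python floats give float32 variables.
with tf.Graph().as_default():
    v1 = tf.Variable(1.32, name="v1")
    v2 = tf.Variable(1.33, name="v2")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, save_path)

# restore.py: 0.0 (not 0) keeps the dtype float32, matching the checkpoint.
with tf.Graph().as_default():
    v1 = tf.Variable(0.0, name="v1")
    v2 = tf.Variable(0.0, name="v2")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # No initializer run needed: restore assigns the saved values.
        saver.restore(sess, save_path)
        restored = sess.run([v1, v2])
        print("Model restored:", restored)
```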

@zaatko

zaatko commented Jun 14, 2016

Hello, many thanks for this!
I'm testing your script, but I always get an assertion error on events.Graph().
When I use a plain tf.import_graph_def(graph_def, name='') instead, I get the error "Could not find initializer for variable".
Any ideas why?

@nikitakit
Author

@zaatko

It's not immediately clear to me what the issue is.

However, if you have a recent enough version of TensorFlow, could you try out the API described here? https://www.tensorflow.org/versions/r0.9/how_tos/meta_graph/index.html

I haven't used it myself, but it does seem that graph restoration has finally made it into the official release.
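For reference, a rough sketch of what the MetaGraph route looks like, assuming `tf.compat.v1` is available. The tiny graph (`x`, `w`, `y`) is hypothetical and only mirrors the `x`/`y` names from the sample usage in the gist. `Saver.save` writes a `.meta` file next to the checkpoint by default, and `tf.train.import_meta_graph` rebuilds the graph from it and returns a Saver built from the stored SaverDef, so no model-building code is needed on the restoring side.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TF 1.x-style graph mode through tf.compat.v1.
tf.disable_eager_execution()

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Training side: Saver.save also writes the MetaGraphDef to "<prefix>.meta".
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[], name="x")
    w = tf.Variable(3.0, name="w")
    y = tf.multiply(w, x, name="y")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# Restoring side: rebuild the graph and its saver from the .meta file.
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    x = graph.get_tensor_by_name("x:0")
    y = graph.get_tensor_by_name("y:0")
    with tf.Session() as sess:
        saver.restore(sess, ckpt_prefix)  # load w = 3.0 from the checkpoint
        result = sess.run(y, feed_dict={x: 2.0})
        print(result)  # 6.0
```

Compared with `restore_graph`, this keeps the original SaverDef, so the random saver name and the assumption about variable ordering are no longer needed.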

@ic

ic commented Aug 10, 2016

It seems the better way to go is the "freeze graph" approach, which appears to be available from 0.10.0.


I actually found that script while trying to do the same thing (converting a training graph into a constant one). Several models shared with examples and tutorials (e.g. Inception's ZIP file) are binary GraphDefs with the parameters stored as constants in the graph, and for now this is the only convenient way to load them with the C++ API.
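A rough sketch of the freezing idea, assuming `tf.compat.v1` is available; it calls `tf.graph_util.convert_variables_to_constants` directly rather than the `freeze_graph` script mentioned above, and the tiny graph (`x`, `w`, `y`) is hypothetical. Each variable that the requested output depends on is replaced by a constant holding its current value, and nodes not needed for that output are pruned, so the resulting GraphDef can be imported and run without any checkpoint or Saver.

```python
import tensorflow.compat.v1 as tf

# Assumption: TF 1.x-style graph mode through tf.compat.v1.
tf.disable_eager_execution()

# Build a tiny graph and give its variable a value.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[], name="x")
    w = tf.Variable(3.0, name="w")
    y = tf.multiply(w, x, name="y")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Replace the variables "y" depends on with constants and prune the rest.
        frozen = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), ["y"])

# The frozen GraphDef is self-contained: import it and run, no restore step.
with tf.Graph().as_default():
    x_t, y_t = tf.import_graph_def(frozen, name="",
                                   return_elements=["x:0", "y:0"])
    with tf.Session() as sess:
        result = sess.run(y_t, feed_dict={x_t: 2.0})
        print(result)  # 6.0
```

A GraphDef like `frozen` can be written with `frozen.SerializeToString()`, which yields the kind of single binary file described above.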
