Skip to content

Instantly share code, notes, and snippets.

@mainyaa
Forked from nikitakit/_NOTICE.md
Created October 11, 2016 02:07
Show Gist options
  • Save mainyaa/99a217ec0602c248a929f252d1ad4f05 to your computer and use it in GitHub Desktop.
Save mainyaa/99a217ec0602c248a929f252d1ad4f05 to your computer and use it in GitHub Desktop.
Restoring TensorFlow Models
"""
By default, TensorFlow's GraphDef only saves the graph architecture
(not the parameter values), while the Saver class only writes parameter
values to each checkpoint.

This code allows combining data from the GraphDef and a checkpoint file
to restore a functioning model.

Sample usage:

```
import tensorflow as tf
from tf_restore_graph import restore_graph
from tensorflow.python.summary.event_accumulator import EventAccumulator

sess = tf.InteractiveSession()
events = EventAccumulator('path-to-tfevents-file')
events.Reload()
(x,y), saver = restore_graph(
    events.Graph(),
    tf.train.get_checkpoint_state('checkpoint').model_checkpoint_path,
    return_elements=['x', 'y']
)
print(sess.run(y, feed_dict={x:1.0}))
```
"""
import tensorflow as tf
from tensorflow.python import ops
import random
class RestoredVariable(tf.Variable):
    """
    A variable restored from disk

    Instead of creating new graph nodes, an instance binds itself to the
    nodes of a variable that already exists in a graph (for example after
    tf.import_graph_def). For that reason tf.Variable.__init__ is
    deliberately not called; the private attributes are filled in by hand.
    """

    def __init__(self, name, trainable=True, collections=None, graph=None):
        """
        Bind to the existing nodes of the variable called `name`

        Args:
          name: name of the variable node inside `graph`
          trainable: if True, the variable is also added to the
            GraphKeys.TRAINABLE_VARIABLES collection
          collections: list of collection keys to register the variable in;
            defaults to [GraphKeys.VARIABLES]
          graph: graph that contains the variable; defaults to the current
            default graph

        Raises:
          ValueError: if there is not exactly one node directly below
            `name + '/Initializer/'`
        """
        if graph is None:
            graph = tf.get_default_graph()
        if collections is None:
            collections = [ops.GraphKeys.VARIABLES]
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            # pylint: disable=g-no-augmented-assignment
            #
            # Pylint wants us to write collections += [...TRAINABLE_VARIABLES] which
            # is not the same (it modifies the list in place.) Here, we only want to
            # modify the value of the variable, not the list.
            collections = collections + [ops.GraphKeys.TRAINABLE_VARIABLES]
            # pylint: enable=g-no-augmented-assignment

        # Look up the nodes that make up the variable by their conventional
        # suffixes: the variable itself, its read snapshot and its assign op.
        self._variable = graph.as_graph_element(name).outputs[0]
        self._snapshot = graph.as_graph_element(name + '/read').outputs[0]
        self._initializer_op = graph.as_graph_element(name + '/Assign')

        # Only direct children of "<name>/Initializer/" count as candidates;
        # deeper nodes (with a further '/') are ignored.
        i_name = name + '/Initializer/'
        keys = [k for k in graph._nodes_by_name
                if k.startswith(i_name) and '/' not in k[len(i_name):]]
        if len(keys) != 1:
            # Report the variable name and what was found, so the failing
            # variable can be identified from the message alone.
            raise ValueError(
                'Could not find a unique initializer for variable %r '
                '(found %d candidate node(s): %r)' % (name, len(keys), keys))
        # The initializer node is only checked for; it is not stored.
        self._initial_value = None

        for key in collections:
            graph.add_to_collection(key, self)
        self._save_slice_info = None
def restore_graph(graph_def, save_path=None,
                  saver_def=None,
                  input_map=None, return_elements=None, op_dict=None,
                  trainable=True, collections=None,
                  ):
    """
    Restore a graph from a GraphDef

    The GraphDef is imported into the default graph (with an empty name
    prefix), every node whose op is 'Variable' is wrapped in a
    RestoredVariable, and a Saver over those variables is created. If
    `save_path` is given, the parameter values are restored into the session
    returned by tf.get_default_session().

    Args:
      graph_def: a GraphDef instance, representing the model architecture
      save_path: path where parameter values were saved; if None, nothing
        is restored and only the graph and saver are set up
      saver_def: SaverDef for restoring the saver
      input_map, return_elements, op_dict: passed to tf.import_graph_def
      trainable: whether the restored variables should be marked as trainable
      collections: which collections to add the restored variables to

    Returns: (graph_elements, saver)
      graph_elements: The return value of tf.import_graph_def
      saver: The saver can be used to load further checkpoints
    """
    res = tf.import_graph_def(graph_def, name='', input_map=input_map,
                              return_elements=return_elements, op_dict=op_dict)

    # Constructing a RestoredVariable registers it in the requested
    # collections, so this loop is kept explicit for its side effects.
    restored_vars = []
    for node in graph_def.node:
        if node.op == 'Variable':
            restored_vars.append(
                RestoredVariable(node.name, trainable=trainable, collections=collections))

    if saver_def is not None:
        # saver_def has to be passed by keyword: the first positional
        # parameter of tf.train.Saver is var_list, so passing saver_def
        # positionally together with var_list= raises a TypeError.
        saver = tf.train.Saver(var_list=restored_vars, saver_def=saver_def)
    else:
        # Saver names must be unique, but we can't reuse the old saver variables without the saver_def
        # So we generate a random name, and hope the variable ordering and packing is deterministic and
        # unchanged since the checkpoint was saved
        saver = tf.train.Saver(var_list=restored_vars,
                               name='restored-' + ('%016x' % random.randrange(16**16)))

    if save_path is not None:
        # Values are loaded into the current default session.
        saver.restore(tf.get_default_session(), save_path)
    return res, saver
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment