Skip to content

Instantly share code, notes, and snippets.

@Ingwar
Last active January 31, 2018 10:37
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ingwar/3b8a7232f4cb906ff1f0 to your computer and use it in GitHub Desktop.
Save Ingwar/3b8a7232f4cb906ff1f0 to your computer and use it in GitHub Desktop.
Script that reproduce issues with tf.train.write_graph() and graph_utils.convert_to_constant_graph()
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import os
import tensorflow as tf
from tensorflow.python.client import graph_util
DATA_PLACEHOLDER_NAME = 'data'
IS_TRAIN_PHASE_PLACEHOLDER_NAME = 'is_train_phase'
RESULT_NODE_NAME = 'result'
DATA_BATCH_SHAPE = (5, 5)
def train_and_save_graph(checkpoints_dir):
with tf.Graph().as_default():
data_placeholder = tf.placeholder(dtype=tf.float32, shape=DATA_BATCH_SHAPE, name=DATA_PLACEHOLDER_NAME)
is_train_phase_placeholder = tf.placeholder(dtype=tf.bool, name=IS_TRAIN_PHASE_PLACEHOLDER_NAME)
# The actual computations are meaningless, but the pattern are similar to
# calculation of batch mean/variance for batch normalization
batch_mean, batch_var = tf.nn.moments(data_placeholder, [0], name='moments')
ema = tf.train.ExponentialMovingAverage(decay=0.9)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean = ema.average(batch_mean)
ema_var = ema.average(batch_var)
def mean_var_with_update():
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
mean, var = tf.cond(is_train_phase_placeholder, mean_var_with_update, lambda: (ema_mean, ema_var))
normalized_data = (data_placeholder - mean) / var
result_op = tf.reduce_sum(normalized_data ** 2, name=RESULT_NODE_NAME)
saver = tf.train.Saver(tf.all_variables())
init_op = tf.initialize_all_variables()
with tf.Session() as session:
session.run(init_op)
current_step = None
for i in range(100):
data = np.random.randn(*DATA_BATCH_SHAPE) * np.random.uniform(0.1, 10)
result = session.run([result_op], feed_dict={data_placeholder: data, is_train_phase_placeholder: True})
print('Step {:d}. result: {}'.format(i, result))
current_step = i
path_to_checkpoint = os.path.join(checkpoints_dir, 'model.ckpt')
try:
os.mkdir(checkpoints_dir)
except OSError:
pass
saver.save(session, path_to_checkpoint, global_step=current_step)
def convert_to_constant_graph(checkpoints_dir, constant_graph_dir, constant_graph_file):
with tf.Graph().as_default():
with tf.Session() as session:
checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
if checkpoint and checkpoint.model_checkpoint_path:
saver = tf.train.import_meta_graph(checkpoint.model_checkpoint_path + '.meta')
else:
raise ValueError('No checkpoint file found')
saver.restore(session, checkpoint.model_checkpoint_path)
constant_graph = graph_util.convert_variables_to_constants(session, session.graph_def, [RESULT_NODE_NAME])
tf.train.write_graph(constant_graph, constant_graph_dir, constant_graph_file, as_text=False)
def use_saved_graph(constant_graph_dir, constant_graph_file):
with tf.Graph().as_default():
graph_def = tf.GraphDef()
path_to_graph_file = os.path.join(constant_graph_dir, constant_graph_file)
with open(path_to_graph_file, 'rb') as graph_file:
graph_def.ParseFromString(graph_file.read())
tf.import_graph_def(graph_def)
with tf.Session() as session:
result_op = session.graph.get_tensor_by_name(RESULT_NODE_NAME + ':0')
data_placeholder = session.graph.get_tensor_by_name(DATA_PLACEHOLDER_NAME + ':0')
is_train_phase_placeholder = session.get_tensor_by_name(DATA_PLACEHOLDER_NAME + ':0')
data = 4.5 * np.random.randn(*DATA_BATCH_SHAPE)
result = session.run([result_op], feed_dict={data_placeholder: data, is_train_phase_placeholder: False})
print('Result is ', result)
checkpoint_dir = 'checkpoints'
graph_dir = 'graphs'
graph_file_name = 'constant_graph.pb'
train_and_save_graph(checkpoint_dir)
convert_to_constant_graph(checkpoint_dir, graph_dir, graph_file_name)
use_saved_graph(graph_dir, graph_file_name)
Traceback (most recent call last):
File "scripts/demo.py", line 99, in <module>
use_saved_graph(graph_dir, graph_file_name)
File "scripts/demo.py", line 80, in use_saved_graph
tf.import_graph_def(graph_def)
File "/home/ingwar/virtualenvs/upwork_data_science/local/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 320, in import_graph_def
node, 'Input tensor %r %s' % (input_name, te)))
ValueError: graph_def is invalid at node u'ExponentialMovingAverage/AssignMovingAvg': Input tensor 'moments/moments_1/mean/ExponentialMovingAverage:0' Cannot convert a tensor of type float32 to an input of type float32_ref.
@lilac
Copy link

lilac commented May 21, 2016

Hi, I encountered exactly the same error. Have you worked around this issue?

@bcaine
Copy link

bcaine commented Jul 28, 2016

Did you ever find a solution to this issue? Also hitting it when using ExponentialMovingAverage.

Thanks!

@ntk148v
Copy link

ntk148v commented Jan 31, 2018

Did you find a solution? I also faced the same issue 😢

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment