Script that reproduces an issue with tf.train.write_graph() and graph_util.convert_variables_to_constants()
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os

import numpy as np
import tensorflow as tf
from tensorflow.python.client import graph_util

DATA_PLACEHOLDER_NAME = 'data'
IS_TRAIN_PHASE_PLACEHOLDER_NAME = 'is_train_phase'
RESULT_NODE_NAME = 'result'
DATA_BATCH_SHAPE = (5, 5)
def train_and_save_graph(checkpoints_dir):
    with tf.Graph().as_default():
        data_placeholder = tf.placeholder(dtype=tf.float32, shape=DATA_BATCH_SHAPE, name=DATA_PLACEHOLDER_NAME)
        is_train_phase_placeholder = tf.placeholder(dtype=tf.bool, name=IS_TRAIN_PHASE_PLACEHOLDER_NAME)

        # The actual computations are meaningless, but the pattern is similar to
        # the batch mean/variance calculation used for batch normalization.
        batch_mean, batch_var = tf.nn.moments(data_placeholder, [0], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=0.9)
        ema_apply_op = ema.apply([batch_mean, batch_var])
        ema_mean = ema.average(batch_mean)
        ema_var = ema.average(batch_var)

        def mean_var_with_update():
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(is_train_phase_placeholder, mean_var_with_update, lambda: (ema_mean, ema_var))
        normalized_data = (data_placeholder - mean) / var
        result_op = tf.reduce_sum(normalized_data ** 2, name=RESULT_NODE_NAME)

        saver = tf.train.Saver(tf.all_variables())
        init_op = tf.initialize_all_variables()

        with tf.Session() as session:
            session.run(init_op)
            current_step = None
            for i in range(100):
                data = np.random.randn(*DATA_BATCH_SHAPE) * np.random.uniform(0.1, 10)
                result = session.run([result_op], feed_dict={data_placeholder: data, is_train_phase_placeholder: True})
                print('Step {:d}. result: {}'.format(i, result))
                current_step = i

            path_to_checkpoint = os.path.join(checkpoints_dir, 'model.ckpt')
            try:
                os.mkdir(checkpoints_dir)
            except OSError:
                pass
            saver.save(session, path_to_checkpoint, global_step=current_step)
def convert_to_constant_graph(checkpoints_dir, constant_graph_dir, constant_graph_file):
    with tf.Graph().as_default():
        with tf.Session() as session:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if checkpoint and checkpoint.model_checkpoint_path:
                saver = tf.train.import_meta_graph(checkpoint.model_checkpoint_path + '.meta')
            else:
                raise ValueError('No checkpoint file found')
            saver.restore(session, checkpoint.model_checkpoint_path)
            constant_graph = graph_util.convert_variables_to_constants(session, session.graph_def, [RESULT_NODE_NAME])
            tf.train.write_graph(constant_graph, constant_graph_dir, constant_graph_file, as_text=False)
def use_saved_graph(constant_graph_dir, constant_graph_file):
    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        path_to_graph_file = os.path.join(constant_graph_dir, constant_graph_file)
        with open(path_to_graph_file, 'rb') as graph_file:
            graph_def.ParseFromString(graph_file.read())
        # name='' avoids the default 'import/' prefix, so the lookups below
        # can use the original node names.
        tf.import_graph_def(graph_def, name='')
        with tf.Session() as session:
            result_op = session.graph.get_tensor_by_name(RESULT_NODE_NAME + ':0')
            data_placeholder = session.graph.get_tensor_by_name(DATA_PLACEHOLDER_NAME + ':0')
            is_train_phase_placeholder = session.graph.get_tensor_by_name(IS_TRAIN_PHASE_PLACEHOLDER_NAME + ':0')
            data = 4.5 * np.random.randn(*DATA_BATCH_SHAPE)
            result = session.run([result_op], feed_dict={data_placeholder: data, is_train_phase_placeholder: False})
            print('Result is ', result)
checkpoint_dir = 'checkpoints'
graph_dir = 'graphs'
graph_file_name = 'constant_graph.pb'

train_and_save_graph(checkpoint_dir)
convert_to_constant_graph(checkpoint_dir, graph_dir, graph_file_name)
use_saved_graph(graph_dir, graph_file_name)
Traceback (most recent call last):
  File "scripts/demo.py", line 99, in <module>
    use_saved_graph(graph_dir, graph_file_name)
  File "scripts/demo.py", line 80, in use_saved_graph
    tf.import_graph_def(graph_def)
  File "/home/ingwar/virtualenvs/upwork_data_science/local/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 320, in import_graph_def
    node, 'Input tensor %r %s' % (input_name, te)))
ValueError: graph_def is invalid at node u'ExponentialMovingAverage/AssignMovingAvg': Input tensor 'moments/moments_1/mean/ExponentialMovingAverage:0' Cannot convert a tensor of type float32 to an input of type float32_ref.
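Why it fails: through tf.cond(), the result node depends on ExponentialMovingAverage/AssignMovingAvg, and that op takes its moving-average input as a variable reference (float32_ref). Once graph_util.convert_variables_to_constants() replaces the shadow variable with a plain float32 Const, the AssignMovingAvg input type no longer matches and tf.import_graph_def() rejects the GraphDef. A commonly suggested workaround for this EMA pattern, not a fix confirmed in this thread, is to rebuild an inference-only graph that reads the moving averages directly, restore the checkpoint into it, and freeze that graph instead, so the assign ops are never reachable from the result node and get pruned. A minimal sketch, assuming the rebuilt ops end up with the same names as in the training graph (freeze_inference_graph is a hypothetical helper, not part of the original script):

def freeze_inference_graph(checkpoints_dir, constant_graph_dir, constant_graph_file):
    with tf.Graph().as_default():
        data_placeholder = tf.placeholder(dtype=tf.float32, shape=DATA_BATCH_SHAPE, name=DATA_PLACEHOLDER_NAME)
        batch_mean, batch_var = tf.nn.moments(data_placeholder, [0], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=0.9)
        # apply() is still called so the shadow variables exist (with the same
        # names as in the training graph) and can be restored from the
        # checkpoint, but its update op is never wired into the result node.
        ema.apply([batch_mean, batch_var])
        mean, var = ema.average(batch_mean), ema.average(batch_var)
        normalized_data = (data_placeholder - mean) / var
        tf.reduce_sum(normalized_data ** 2, name=RESULT_NODE_NAME)

        saver = tf.train.Saver(tf.all_variables())
        with tf.Session() as session:
            checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
            if not (checkpoint and checkpoint.model_checkpoint_path):
                raise ValueError('No checkpoint file found')
            saver.restore(session, checkpoint.model_checkpoint_path)
            # Only nodes reachable from RESULT_NODE_NAME are kept, so the
            # AssignMovingAvg ops (and their float32_ref inputs) are dropped.
            constant_graph = graph_util.convert_variables_to_constants(
                session, session.graph_def, [RESULT_NODE_NAME])
            tf.train.write_graph(constant_graph, constant_graph_dir, constant_graph_file, as_text=False)

A graph frozen this way no longer contains the is_train_phase placeholder, so use_saved_graph() would feed only data.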
Did you ever find a solution to this issue? Also hitting it when using ExponentialMovingAverage.
Thanks!
Did you find a solution? I also faced the same issue 😢
Hi, I encountered exactly the same error. Have you worked around this issue?