Created
February 9, 2017 18:57
-
-
Save mckeesh/01af030b147e809b0edbc9f67104e859 to your computer and use it in GitHub Desktop.
ValueError: No variables to save
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import time | |
import os | |
import tensorflow as tf | |
import numpy as np | |
import skimage | |
import skimage.io | |
import skimage.transform | |
import freeze_graph | |
def outputGraph(sess, graph): | |
with graph.as_default(): | |
checkpoint_dir = '/home/shane/code/tensorflow/graph_saver/' | |
checkpoint_state_name = "save_checkpoint.ckpt" | |
input_graph = "input_graph.pb" | |
output_graph_name = "output_graph.pb" | |
# printing variable lists | |
print(tf.get_collection(tf.GraphKeys.VARIABLES)) #deprecated | |
print(tf.all_variables()) #deprecated | |
print(tf.global_variables()) | |
print(tf.local_variables()) | |
print(tf.model_variables()) | |
print(tf.trainable_variables()) | |
print(tf.moving_average_variables()) | |
print("Saving to checkpoint") | |
saver = tf.train.Saver() | |
# Fails on this line | |
# ValueError: No variables to save | |
saver.save(sess, checkpoint_dir + checkpoint_state_name, global_step=0, | |
latest_filename=checkpoint_state_name) | |
print("Saving the graph") | |
tf.train.write_graph(graph, checkpoint_dir, input_graph) | |
input_graph_path = os.path.join(checkpoint_dir, input_graph) | |
input_saver = "" | |
input_binary = False | |
input_checkpoint_path = os.path.join(checkpoint_dir, checkpoint_state_name) + "-0.index" | |
output_node_names = "import/softmax:0" | |
restore_op_name = "save/restore_all" | |
filename_tensor_name = "save/Const:0" | |
output_graph_path = os.path.join(checkpoint_dir, output_graph_name) | |
clear_devices = False | |
initializer_nodes = "" | |
freeze_graph.freeze_graph(input_graph_path, | |
input_saver, | |
input_binary, | |
input_checkpoint_path, | |
output_node_names, | |
restore_op_name, | |
filename_tensor_name, | |
output_graph_path, | |
clear_devices, | |
initializer_nodes) | |
def main(): | |
with tf.device('/cpu:0'): | |
with open("tensorflow_inception_graph.pb", mode='rb') as f: | |
fileContent = f.read() | |
graph_def = tf.GraphDef() | |
graph_def.ParseFromString(fileContent) | |
images = tf.placeholder("float", [ None, 299, 299, 3]) | |
tf.import_graph_def(graph_def, input_map={ "Mul": images }) | |
print "graph loaded from disk" | |
graph = tf.get_default_graph() | |
im = load_image("grace_hopper.jpg") | |
with tf.Session() as sess: | |
init = tf.global_variables_initializer() | |
sess.run(init) | |
print "variables initialized" | |
batch = im.reshape((1, 299, 299, 3)) | |
assert batch.shape == (1, 299, 299, 3) | |
feed_dict = { images: batch } | |
prob_tensor = graph.get_tensor_by_name("import/softmax:0") | |
# printing variable lists | |
print(tf.get_collection(tf.GraphKeys.VARIABLES)) #deprecated | |
print(tf.all_variables()) #deprecated | |
print(tf.global_variables()) | |
print(tf.local_variables()) | |
print(tf.model_variables()) | |
print(tf.trainable_variables()) | |
print(tf.moving_average_variables()) | |
ITERATIONS = 10 | |
total = 0 | |
for i in range(ITERATIONS): | |
print(i) | |
t1 = time.time() | |
prob = sess.run(prob_tensor, feed_dict=feed_dict) | |
t2 = time.time() | |
total += (t2-t1) | |
print_prob(prob[0]) | |
print("%fms per run\n" % ((total/ITERATIONS)*1000)) | |
outputGraph(sess, graph) | |
sess.close() | |
################## Helper Functions ##################### | |
# returns image of shape [299, 299, 3] | |
# [height, width, depth] | |
def load_image(path): | |
# load image | |
img = skimage.io.imread(path) | |
img = img / 255.0 | |
assert (0 <= img).all() and (img <= 1.0).all() | |
#print "Original Image Shape: ", img.shape | |
# we crop image from center | |
short_edge = min(img.shape[:2]) | |
yy = int((img.shape[0] - short_edge) / 2) | |
xx = int((img.shape[1] - short_edge) / 2) | |
crop_img = img[yy : yy + short_edge, xx : xx + short_edge] | |
# resize to 299, 299 | |
resized_img = skimage.transform.resize(crop_img, (299, 299)) | |
return resized_img | |
# returns the top1 string | |
def print_prob(prob): | |
synset = [l.strip() for l in open('imagenet_comp_graph_label_strings.txt').readlines()] | |
#print prob | |
print "prob shape", prob.shape | |
pred = np.argsort(prob)[::-1] | |
# Get top1 label | |
top1 = synset[pred[0]] | |
print "Top1: ", top1 | |
# Get top5 label | |
top5 = [synset[pred[i]] for i in range(5)] | |
print "Top5: ", top5 | |
return top1 | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment