Skip to content

Instantly share code, notes, and snippets.

@mckeesh
Created February 9, 2017 18:57
Show Gist options
  • Save mckeesh/01af030b147e809b0edbc9f67104e859 to your computer and use it in GitHub Desktop.
Save mckeesh/01af030b147e809b0edbc9f67104e859 to your computer and use it in GitHub Desktop.
ValueError: No variables to save
#!/usr/bin/python
import time
import os
import tensorflow as tf
import numpy as np
import skimage
import skimage.io
import skimage.transform
import freeze_graph
def outputGraph(sess, graph):
with graph.as_default():
checkpoint_dir = '/home/shane/code/tensorflow/graph_saver/'
checkpoint_state_name = "save_checkpoint.ckpt"
input_graph = "input_graph.pb"
output_graph_name = "output_graph.pb"
# printing variable lists
print(tf.get_collection(tf.GraphKeys.VARIABLES)) #deprecated
print(tf.all_variables()) #deprecated
print(tf.global_variables())
print(tf.local_variables())
print(tf.model_variables())
print(tf.trainable_variables())
print(tf.moving_average_variables())
print("Saving to checkpoint")
saver = tf.train.Saver()
# Fails on this line
# ValueError: No variables to save
saver.save(sess, checkpoint_dir + checkpoint_state_name, global_step=0,
latest_filename=checkpoint_state_name)
print("Saving the graph")
tf.train.write_graph(graph, checkpoint_dir, input_graph)
input_graph_path = os.path.join(checkpoint_dir, input_graph)
input_saver = ""
input_binary = False
input_checkpoint_path = os.path.join(checkpoint_dir, checkpoint_state_name) + "-0.index"
output_node_names = "import/softmax:0"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(checkpoint_dir, output_graph_name)
clear_devices = False
initializer_nodes = ""
freeze_graph.freeze_graph(input_graph_path,
input_saver,
input_binary,
input_checkpoint_path,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph_path,
clear_devices,
initializer_nodes)
def main():
with tf.device('/cpu:0'):
with open("tensorflow_inception_graph.pb", mode='rb') as f:
fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)
images = tf.placeholder("float", [ None, 299, 299, 3])
tf.import_graph_def(graph_def, input_map={ "Mul": images })
print "graph loaded from disk"
graph = tf.get_default_graph()
im = load_image("grace_hopper.jpg")
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print "variables initialized"
batch = im.reshape((1, 299, 299, 3))
assert batch.shape == (1, 299, 299, 3)
feed_dict = { images: batch }
prob_tensor = graph.get_tensor_by_name("import/softmax:0")
# printing variable lists
print(tf.get_collection(tf.GraphKeys.VARIABLES)) #deprecated
print(tf.all_variables()) #deprecated
print(tf.global_variables())
print(tf.local_variables())
print(tf.model_variables())
print(tf.trainable_variables())
print(tf.moving_average_variables())
ITERATIONS = 10
total = 0
for i in range(ITERATIONS):
print(i)
t1 = time.time()
prob = sess.run(prob_tensor, feed_dict=feed_dict)
t2 = time.time()
total += (t2-t1)
print_prob(prob[0])
print("%fms per run\n" % ((total/ITERATIONS)*1000))
outputGraph(sess, graph)
sess.close()
################## Helper Functions #####################
# returns image of shape [299, 299, 3]
# [height, width, depth]
def load_image(path):
# load image
img = skimage.io.imread(path)
img = img / 255.0
assert (0 <= img).all() and (img <= 1.0).all()
#print "Original Image Shape: ", img.shape
# we crop image from center
short_edge = min(img.shape[:2])
yy = int((img.shape[0] - short_edge) / 2)
xx = int((img.shape[1] - short_edge) / 2)
crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
# resize to 299, 299
resized_img = skimage.transform.resize(crop_img, (299, 299))
return resized_img
# returns the top1 string
def print_prob(prob):
synset = [l.strip() for l in open('imagenet_comp_graph_label_strings.txt').readlines()]
#print prob
print "prob shape", prob.shape
pred = np.argsort(prob)[::-1]
# Get top1 label
top1 = synset[pred[0]]
print "Top1: ", top1
# Get top5 label
top5 = [synset[pred[i]] for i in range(5)]
print "Top5: ", top5
return top1
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment