Skip to content

Instantly share code, notes, and snippets.

@dpattison3
Created January 13, 2018 00:28
Show Gist options
Save dpattison3/26bf10fabc0dc08c4b19920c2330e39b to your computer and use it in GitHub Desktop.
import tensorflow as tf
import scipy.misc
from timeit import default_timer as timer
import cv2
import numpy as np
def load_graph(frozen_graph_filename, prefix=''):
    """Load a frozen TensorFlow GraphDef from disk into a fresh tf.Graph.

    Args:
        frozen_graph_filename: Path to a serialized (binary protobuf) GraphDef
            file, e.g. a ``.pb`` produced by freezing a model.
        prefix: Name scope passed to ``tf.import_graph_def`` and prepended to
            every imported op/tensor name. Defaults to '' so tensors keep
            their original names (e.g. 'Placeholder_1:0'), which is the
            previous behavior.

    Returns:
        A new tf.Graph containing the imported nodes.
    """
    # Read the protobuf file from disk and parse it into an (unserialized)
    # GraphDef message.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the GraphDef into a brand-new Graph (not the process-wide default
    # graph) and return it.
    with tf.Graph().as_default() as graph:
        # The name argument would prefix every op/node in the graph. Since
        # everything is loaded into a new, empty graph, no prefix is needed
        # by default.
        tf.import_graph_def(graph_def, name=prefix)
    return graph
graph = load_graph('test_frozen_model.pb')

im_path = 'CHANGE_THIS'
im = cv2.imread(im_path)
# cv2.imread does not raise on a missing/unreadable file; it returns None,
# which would otherwise surface later as an opaque error inside cv2.cvtColor.
if im is None:
    raise IOError('Could not read image: {}'.format(im_path))
# OpenCV loads images in BGR channel order; convert to RGB before feeding the
# network (presumably the model expects RGB input -- TODO confirm).
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

image_pl = graph.get_tensor_by_name('Placeholder_1:0')
softmax = graph.get_tensor_by_name('Validation/decoder/Softmax:0')

# Using the session as a context manager closes it (releasing its resources)
# as soon as inference is done.
with tf.Session(graph=graph) as sess:
    output = sess.run([softmax], feed_dict={image_pl: im})

shape = im.shape
# sess.run with a list of fetches returns a list, so output[0] is the softmax
# array. Column 1 is taken as the per-pixel score of the class of interest and
# reshaped back to (height, width). This assumes the softmax is laid out as
# (height * width, num_classes) -- TODO confirm against the model.
output = output[0][:, 1].reshape(shape[0], shape[1])

threshold = 0.5
# Binarize: pixels scoring above the threshold become 255 (white), others 0.
im_threshold = output > threshold
im_threshold = np.uint8(255 * im_threshold)
# NOTE(review): JPEG is lossy, so the saved binary mask will pick up
# compression artifacts; consider a lossless format such as PNG.
# cv2.imwrite signals failure by returning False rather than raising.
if not cv2.imwrite('TEST.jpg', im_threshold):
    raise IOError('Could not write TEST.jpg')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment