Skip to content

Instantly share code, notes, and snippets.

@vishal-keshav
Created April 9, 2019 06:07
Show Gist options
  • Save vishal-keshav/94217d4cc5fbd8d1a0790434161b4dc4 to your computer and use it in GitHub Desktop.
Save vishal-keshav/94217d4cc5fbd8d1a0790434161b4dc4 to your computer and use it in GitHub Desktop.
Infer the output of a TensorFlow `.pb` model
import tensorflow as tf
import numpy as np
def preprocess(img):
    """Return *img* unchanged; hook for model-specific input preprocessing.

    Replace the body with whatever transformation the model expects
    before inference. As written it is the identity function.
    """
    processed = img  # identity: no preprocessing applied yet
    return processed
def _as_tensor_name(name):
    """Return *name* in ``"<op_name>:<output_index>"`` form, defaulting the index to 0.

    ``Graph.get_tensor_by_name`` rejects bare operation names such as
    ``"input"``; it requires an explicit output index.
    """
    return name if ":" in name else name + ":0"


def _as_name_list(names, default):
    """Return *names* as a list: ``[default]`` for ``None``, a lone string wrapped."""
    if names is None:
        return [default]
    if isinstance(names, str):
        # Avoid indexing a plain string character by character.
        return [names]
    return list(names)


def inference_from_pb(img, pb_file="model.pb", inputs=None, outputs=None):
    """Run a single forward pass of a serialized TensorFlow graph on *img*.

    Args:
        img: Value fed to the first input tensor, after passing through
            ``preprocess``.
        pb_file: Path to a binary-serialized ``GraphDef`` file.
        inputs: Input tensor/op names (list, or a single string). Defaults to
            ``["input"]``. Only the first entry is fed.
        outputs: Output tensor/op names (list, or a single string). Defaults
            to ``["output"]``. Only the first entry is fetched.
            Names may omit the ``:<index>`` suffix, in which case ``:0`` is used.

    Returns:
        The value of the first output tensor, as returned by ``Session.run``.

    Raises:
        ValueError: If *inputs* or *outputs* is empty. TensorFlow also raises
            it when a requested name is not present in the graph.
    """
    inputs = _as_name_list(inputs, "input")
    outputs = _as_name_list(outputs, "output")
    if not inputs or not outputs:
        raise ValueError("inputs and outputs must each name at least one tensor")

    # This code uses the TF1 graph/session API, which TF 2.x exposes under
    # tf.compat.v1; fall back to the top-level module on older releases.
    tf1 = tf.compat.v1 if hasattr(getattr(tf, "compat", None), "v1") else tf

    img = preprocess(img)

    with tf1.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf1.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf1.Graph().as_default() as graph:
        # name="" keeps node names exactly as stored in the file; the default
        # would prefix every node with "import/", so "input" would not resolve.
        tf1.import_graph_def(graph_def, name="")

    input_tensor = graph.get_tensor_by_name(_as_tensor_name(inputs[0]))
    output_tensor = graph.get_tensor_by_name(_as_tensor_name(outputs[0]))

    # NOTE(review): no variable initialization is run; this assumes the graph
    # needs none (e.g. a frozen graph with weights as constants) — confirm.
    with tf1.Session(graph=graph) as sess:
        return sess.run(output_tensor, feed_dict={input_tensor: img})
def main():
    """Demo entry point: run the model on a small hard-coded 2x2 array."""
    sample = np.array([[1, 2], [3, 4]])
    # The demo computes the prediction but does not use it further.
    _ = inference_from_pb(sample)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment