Skip to content

Instantly share code, notes, and snippets.

@batrlatom
Created September 2, 2019 14:58
Show Gist options
  • Save batrlatom/e938b5e17b04ba84123bc1a8fa1a0b8e to your computer and use it in GitHub Desktop.
Save batrlatom/e938b5e17b04ba84123bc1a8fa1a0b8e to your computer and use it in GitHub Desktop.
import tensorflow as tf
import argparse
import numpy as np
import cv2
def load_graph(frozen_graph_filename):
    """Read a frozen binary GraphDef and import it into a brand-new graph.

    All imported ops are placed under the ``prefix`` import scope, so their
    tensors are addressed as ``prefix/<op_name>:<index>``.

    Args:
        frozen_graph_filename: Path to a serialized binary ``GraphDef`` file.

    Returns:
        A new ``tf.Graph`` containing the imported ops.
    """
    with tf.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        serialized = pb_file.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)

    imported = tf.Graph()
    with imported.as_default():
        tf.import_graph_def(graph_def, name="prefix")
    return imported
def load_pb(path_to_pb, name=""):
    """Load a frozen binary GraphDef from *path_to_pb* into a new ``tf.Graph``.

    Args:
        path_to_pb: Path to a serialized binary ``GraphDef`` (``.pb``) file.
        name: Import scope passed to ``tf.import_graph_def``. The default
            ``""`` keeps the original op names, so tensors can be looked up
            directly (e.g. ``"input:0"``). Pass a non-empty string to prefix
            every imported op name with ``"<name>/"``.

    Returns:
        The ``tf.Graph`` containing the imported ops.
    """
    with tf.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=name)
    return graph
def main():
    """Load a frozen graph and run it repeatedly on random input.

    Every option is optional, and the defaults reproduce the original
    hard-coded behaviour:
        --frozen_model_filename  frozen ``.pb`` to load (default ``onnx_tf.pb``)
        --input_tensor           tensor fed with random data (default ``input:0``)
        --output_tensor          tensor fetched and printed (default ``add_9:0``)
        --input_shape            shape of the random input (default ``1 3 224 224``)
        --iterations             number of forward passes (default ``100``)

    Prints the list of all op names in the graph once, then prints the first
    element of the fetched output for each iteration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default="onnx_tf.pb", type=str, help="Frozen model file to import")
    parser.add_argument("--input_tensor", default="input:0", type=str, help="Name of the tensor to feed with random data")
    parser.add_argument("--output_tensor", default="add_9:0", type=str, help="Name of the tensor to fetch and print")
    parser.add_argument("--input_shape", default=[1, 3, 224, 224], type=int, nargs="+", help="Shape of the random input")
    parser.add_argument("--iterations", default=100, type=int, help="Number of forward passes to run")
    args = parser.parse_args()

    graph = load_pb(args.frozen_model_filename)
    # Dump every op name so the right --input_tensor / --output_tensor can be
    # found when the defaults do not match the loaded model.
    print([node.name for node in graph.as_graph_def().node])

    input_img = graph.get_tensor_by_name(args.input_tensor)
    y = graph.get_tensor_by_name(args.output_tensor)
    # np.random.rand produces float64. Cast to the fed tensor's own dtype up
    # front so the conversion is explicit rather than left to Session.run.
    feed_dtype = input_img.dtype.as_numpy_dtype

    with tf.Session(graph=graph) as sess:
        for _ in range(args.iterations):
            img_in = np.random.rand(*args.input_shape).astype(feed_dtype)
            y_out = sess.run(y, feed_dict={input_img: img_in})
            print(y_out[0])
# Run only when executed as a script, so importing this module (e.g. to reuse
# load_pb / load_graph) does not kick off the benchmark loop.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment