# Source: GitHub gist by @batrlatom, created September 2, 2019 14:59
# Gist id: batrlatom/f67bbe8c26a916ad3134630fddc2018b
#
# Converts an ONNX model to TensorFlow with onnx-tf, times repeated inference
# runs, writes a Chrome trace of one run, and exports the converted graph.
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
from tensorflow.python.client import timeline
import time
# Prepare the inputs, here we use numpy to generate some random inputs for demo purpose
import numpy as np
# ---------------------------------------------------------------------------
# Configuration.
# The tensor names are specific to the exported model. The values below are
# the ones this script was written against. For a different model, check the
# printed ``tf_rep.inputs`` / ``tf_rep.outputs`` / ``tf_rep.tensor_dict``
# output below to find the right names.
# ---------------------------------------------------------------------------
MODEL_PATH = 'onnx_model_name.onnx'
INPUT_TENSOR = "input:0"
OUTPUT_TENSOR = "add_9:0"
NUM_RUNS = 10
TRACE_PATH = 'trace_file.json'
EXPORT_PATH = "onnx_tf.pb"

# Random demo input with fixed shape (1, 3, 224, 224).
# NOTE(review): presumably a single NCHW image; confirm against the model's
# declared input shape.
img = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Load the ONNX model and convert it to a TensorFlow representation.
print("Loading")
model = onnx.load(MODEL_PATH)
tf_rep = prepare(model, strict=False)
print(tf_rep.inputs)  # Input nodes to the model
print('-----')
print(tf_rep.outputs)  # Output nodes from the model
print('-----')
print(tf_rep.tensor_dict)  # All nodes in the model

print("Running")  # this should run fast
# Full tracing is used only for the single profiled run below, not the timed runs.
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
config = tf.ConfigProto()
# Allocate GPU memory on demand instead of reserving it all up front.
config.gpu_options.allow_growth = True

# Import the converted graph into a dedicated graph and bind the session to it
# explicitly. The previous version created the session on the implicit default
# graph *before* importing into it.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(tf_rep.graph.as_graph_def(), name="")

# The ``with`` block closes the session when done, releasing its resources.
with tf.Session(config=config, graph=graph) as sess:
    feed = {INPUT_TENSOR: img}

    # Warm-up run. The first call typically pays one-off setup costs, which
    # should not be counted as steady-state latency. Its result also seeds
    # ``output`` so the name is defined even if NUM_RUNS is 0.
    output = sess.run(OUTPUT_TENSOR, feed_dict=feed)

    # Timed runs, without tracing. Tracing and printing inside the loop would
    # otherwise be measured as model latency.
    start = time.perf_counter()
    for _ in range(NUM_RUNS):
        output = sess.run(OUTPUT_TENSOR, feed_dict=feed)
    end = time.perf_counter()

    print(output)
    print("time elapsed:")
    print(end - start)
    if NUM_RUNS:
        print("mean time per run:")
        print((end - start) / NUM_RUNS)

    # One separately traced run; its step stats feed the Chrome trace below.
    sess.run(OUTPUT_TENSOR, feed_dict=feed,
             options=run_options, run_metadata=run_metadata)

# Write the trace in Chrome trace format (viewable in chrome://tracing).
tl = timeline.Timeline(run_metadata.step_stats)
ctf = tl.generate_chrome_trace_format()
with open(TRACE_PATH, 'w', encoding='utf-8') as f:
    f.write(ctf)

print(tf_rep.tensor_dict)
tf_rep.export_graph(EXPORT_PATH)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment