Skip to content

Instantly share code, notes, and snippets.

@takotab
Created April 26, 2019 09:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save takotab/ccd131fb06d4cdd2b786b3ef83c62ade to your computer and use it in GitHub Desktop.
Save takotab/ccd131fb06d4cdd2b786b3ef83c62ade to your computer and use it in GitHub Desktop.
Please note that this script also includes the code needed to load and convert models from ONNX.
import tensorflow as tf
import onnx
from onnx_tf.backend import prepare
import numpy as np
import os
# from PIL import Image
def export(export_path):
    """Wrap the frozen graph in ``saved_model.pb`` as a servable SavedModel.

    The function reads the serialized ``GraphDef`` from ``saved_model.pb`` in
    the current working directory, runs one forward pass on random input as a
    smoke test, and then writes a SavedModel with two signatures:

    * ``"predict_images"``: predict signature, input key ``"images"``,
      output key ``"scores"``.
    * the default serving signature: classify signature over the same
      input and output tensors.

    Args:
        export_path: Directory the SavedModel is written to.

    Raises:
        AssertionError: presumably raised by ``SavedModelBuilder`` when
            ``export_path`` already exists -- confirm against the installed
            TensorFlow version.
    """
    # NOTE(review): the ONNX model is loaded and converted, but ``tf_rep`` is
    # not used again in this function.  Presumably ``saved_model.pb`` was
    # produced from it beforehand (e.g. via ``tf_rep.export_graph``) --
    # confirm.  The calls are kept so the function still requires, and still
    # validates, ``resnet_fine.pt.onnx`` exactly as before.
    onnx_model = onnx.load("resnet_fine.pt.onnx")
    tf_rep = prepare(onnx_model, strict=False)

    with tf.Graph().as_default():
        graph_def = tf.GraphDef()
        with open("saved_model.pb", "rb") as f:
            graph_def.ParseFromString(f.read())
        # name="" keeps the original node names, so the tensors can be looked
        # up below without an "import/" prefix.
        tf.import_graph_def(graph_def, name="")

        with tf.Session() as sess:
            input_x = sess.graph.get_tensor_by_name("0:0")  # input
            outputs1 = sess.graph.get_tensor_by_name("add_10:0")

            # Smoke test: one forward pass on random data before exporting.
            output_tf_pb = sess.run(
                [outputs1], feed_dict={input_x: np.random.randn(1, 3, 64, 64)}
            )
            print("output_tf_pb = {}".format(output_tf_pb))

            # Build the tensor infos once and share them between signatures.
            tensor_info_x = tf.saved_model.utils.build_tensor_info(input_x)
            tensor_info_y = tf.saved_model.utils.build_tensor_info(outputs1)

            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={"images": tensor_info_x},
                    outputs={"scores": tensor_info_y},
                    method_name=(
                        tf.saved_model.signature_constants.PREDICT_METHOD_NAME
                    ),
                )
            )

            # NOTE(review): a classify signature normally expects a string
            # tensor of serialized examples as input; here it is wired to the
            # raw image tensor -- verify that the serving side accepts this.
            classification_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        tf.saved_model.signature_constants.CLASSIFY_INPUTS: (
                            tensor_info_x
                        ),
                    },
                    outputs={
                        tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES: (
                            tensor_info_y
                        ),
                    },
                    method_name=(
                        tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME
                    ),
                )
            )

            # Honour the caller-supplied destination instead of a fixed
            # directory name.
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)
            builder.add_meta_graph_and_variables(
                sess,
                [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    "predict_images": prediction_signature,
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: (
                        classification_signature
                    ),
                },
                main_op=tf.tables_initializer(),
            )
            builder.save()
if __name__ == "__main__":
    # Script entry point: run the export only when this file is executed
    # directly, passing "output/export" as the export path.
    export("output/export")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment