Skip to content

Instantly share code, notes, and snippets.

@guschmue
Last active September 25, 2020 23:13
Show Gist options
  • Save guschmue/fdd66ec852c3d179a97bc08b3331d640 to your computer and use it in GitHub Desktop.
Save guschmue/fdd66ec852c3d179a97bc08b3331d640 to your computer and use it in GitHub Desktop.
tf.keras to onnx
import tensorflow as tf
import tf2onnx
from tensorflow.python.keras.saving import saving_utils as _saving_utils
def to_onnx(model, output=None, *, opset=11):
    """Convert a tf.keras model to an ONNX ModelProto using tf2onnx.

    The model's call is traced into a concrete function, frozen into a
    GraphDef with ``tf2onnx.tf_loader.from_function``, imported into a fresh
    ``tf.Graph`` and then converted and optimized by tf2onnx.

    Args:
        model: The tf.keras model to convert (passed to
            ``saving_utils.trace_model_call``).
        output: Optional file path. When truthy, the resulting ModelProto is
            also written there via ``tf2onnx.utils.save_protobuf``.
        opset: ONNX opset version to target (keyword-only). Defaults to 11,
            the value this helper previously hard-coded.

    Returns:
        A tuple ``(model_proto, inputs, outputs)``: the ONNX ModelProto and
        the lists of names of its graph inputs and graph outputs.
    """
    # NOTE(review): saving_utils lives in a private TensorFlow module path;
    # trace_model_call may move or change between TF releases.
    function = _saving_utils.trace_model_call(model)
    concrete_func = function.get_concrete_function()

    # Resource-dtype tensors are excluded from the I/O lists. Presumably
    # these are captured variable handles rather than real data tensors
    # (TODO confirm).
    input_names = [tensor.name for tensor in concrete_func.inputs
                   if tensor.dtype != tf.dtypes.resource]
    output_names = [tensor.name for tensor in concrete_func.outputs
                    if tensor.dtype != tf.dtypes.resource]

    frozen_graph = tf2onnx.tf_loader.from_function(
        concrete_func, input_names, output_names)

    # Import and convert inside a freshly created graph that is the default
    # only for the duration of this block.
    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(frozen_graph, name='')
        g = tf2onnx.tfonnx.process_tf_graph(
            tf_graph,
            opset=opset,
            input_names=input_names,
            output_names=output_names,
        )

    onnx_graph = tf2onnx.optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model("test model")

    if output:
        tf2onnx.utils.save_protobuf(output, model_proto)

    inputs = [n.name for n in model_proto.graph.input]
    outputs = [n.name for n in model_proto.graph.output]
    return model_proto, inputs, outputs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment