Skip to content

Instantly share code, notes, and snippets.

@vishal-keshav
Created April 3, 2019 10:30
Show Gist options
  • Save vishal-keshav/7fa502ffc9f8fd592a1fc400c031c113 to your computer and use it in GitHub Desktop.
Save vishal-keshav/7fa502ffc9f8fd592a1fc400c031c113 to your computer and use it in GitHub Desktop.
Optimizing a TensorFlow model for inference
import sys
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib
def convert_frozen_to_inference(model_path="generated_model",
                                frozen_file="generated_model.pb",
                                inputs=None,
                                outputs=None,
                                out_file="generated_model_opt.pb",
                                placeholder_type_enum=None):
    """Optimize a frozen TensorFlow graph for inference and save the result.

    Reads the serialized frozen GraphDef at ``model_path/frozen_file``, runs
    ``optimize_for_inference_lib.optimize_for_inference`` on it with the given
    input and output node names, and writes the serialized result to
    ``model_path/out_file``.

    Args:
        model_path: Directory holding the frozen graph. The optimized graph is
            written into the same directory.
        frozen_file: File name of the serialized frozen GraphDef to read.
        inputs: Names of the graph's input nodes. Defaults to ``["input"]``.
        outputs: Names of the graph's output nodes. Defaults to ``["output"]``.
        out_file: File name for the serialized optimized GraphDef.
        placeholder_type_enum: DataType enum passed to ``optimize_for_inference``
            for the input placeholders. Defaults to
            ``tf.float32.as_datatype_enum``.

    Returns:
        The optimized ``GraphDef`` that was written to disk.
    """
    # None sentinels instead of list literals: a list default is created once
    # at definition time and shared by every call.
    if inputs is None:
        inputs = ["input"]
    if outputs is None:
        outputs = ["output"]
    if placeholder_type_enum is None:
        placeholder_type_enum = tf.float32.as_datatype_enum

    # Plain "/" joining is kept so gfile-supported remote paths still work.
    in_path = model_path + "/" + frozen_file
    out_path = model_path + "/" + out_file

    # A serialized GraphDef is binary protobuf data. Read it in binary mode so
    # ParseFromString receives bytes; text mode returns decoded str instead.
    frozen_graph = tf.GraphDef()
    with tf.gfile.GFile(in_path, "rb") as f:
        frozen_graph.ParseFromString(f.read())

    # Copy the name sequences so tuples are accepted and the caller's lists
    # are never handed to the library directly.
    optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
        frozen_graph, list(inputs), list(outputs), placeholder_type_enum)

    # Binary mode for serialized bytes; `with` guarantees flush and close.
    with tf.gfile.GFile(out_path, "wb") as f:
        f.write(optimized_graph_def.SerializeToString())
    return optimized_graph_def
def main():
    """Convert the default frozen graph using this model's node names.

    Adjust the node-name lists below to match the graph being converted.
    A commented-out alternative output node in this script was "DepthToSpace".
    """
    # Set these lists appropriately for the graph being converted.
    input_node_names = ["Placeholder"]
    output_node_names = ["g_conv10/BiasAdd"]  # alternative: ["DepthToSpace"]
    convert_frozen_to_inference(
        inputs=input_node_names,
        outputs=output_node_names,
    )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment