Skip to content

Instantly share code, notes, and snippets.

@apivovarov
Created June 13, 2019 01:31
Show Gist options
Save apivovarov/df0c502a45755702974b42dce3c9e858 to your computer and use it in GitHub Desktop.
Convert to tflite format
#!/usr/bin/env python3
import tensorflow.lite as lite
from tensorflow.lite.python import lite_constants
import sys
# Converting a GraphDef from file.
def from_frozen_graph(graph_def_file, input_arrays=None, output_arrays=None,
                      input_shapes=None):
    """Create a TFLiteConverter for a frozen GraphDef file.

    The defaults are the tensor names and input shape this script was written
    for: a single 1x300x300x3 ``normalized_input_image_tensor`` input and the
    four ``TFLite_Detection_PostProcess`` outputs. This is presumably an SSD
    detection graph exported for TFLite; confirm for other models.

    Args:
        graph_def_file: Path to the frozen GraphDef (.pb) file.
        input_arrays: Names of the input tensors. Defaults to
            ``["normalized_input_image_tensor"]``.
        output_arrays: Names of the output tensors. Defaults to the four
            ``TFLite_Detection_PostProcess`` outputs.
        input_shapes: Dict mapping input tensor name to its shape. When
            ``input_arrays`` is left at its default, this defaults to
            ``{"normalized_input_image_tensor": [1, 300, 300, 3]}``. Otherwise
            it stays ``None`` and the converter uses the shapes in the graph.

    Returns:
        The configured ``lite.TFLiteConverter`` (not yet converted).
    """
    if input_arrays is None:
        input_arrays = ["normalized_input_image_tensor"]
        # Only supply the default shape when it refers to the default input;
        # a caller-chosen input name must not get a shape map keyed on a
        # tensor it did not ask for.
        if input_shapes is None:
            input_shapes = {"normalized_input_image_tensor": [1, 300, 300, 3]}
    if output_arrays is None:
        output_arrays = [
            "TFLite_Detection_PostProcess",
            "TFLite_Detection_PostProcess:1",
            "TFLite_Detection_PostProcess:2",
            "TFLite_Detection_PostProcess:3",
        ]
    return lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, input_arrays, output_arrays, input_shapes)
# Converting a SavedModel.
def from_saved_model(saved_model_dir):
    """Build a TFLiteConverter from the SavedModel stored in *saved_model_dir*.

    Returns the converter unconverted, so the caller can configure it first.
    """
    return lite.TFLiteConverter.from_saved_model(saved_model_dir)
def convert(converter, out_name, is_quant,
            input_name="normalized_input_image_tensor",
            input_stats=(128., 127.)):
    """Configure *converter*, run the conversion and write the result to disk.

    Args:
        converter: A ``lite.TFLiteConverter`` (e.g. from ``from_frozen_graph``
            or ``from_saved_model``).
        out_name: Path of the .tflite file to write; it is overwritten if it
            already exists.
        is_quant: If true, convert with QUANTIZED_UINT8 inference and set
            input quantization stats. Otherwise convert with FLOAT inference.
        input_name: Input tensor the quantization stats apply to. Only used
            when ``is_quant`` is true.
        input_stats: ``(mean, std_dev)`` tuple for ``input_name``. Only used
            when ``is_quant`` is true.
    """
    converter.inference_type = (
        lite_constants.QUANTIZED_UINT8 if is_quant else lite_constants.FLOAT)
    converter.output_format = lite_constants.TFLITE
    # Allow ops without a TFLite builtin. This is presumably needed for the
    # TFLite_Detection_PostProcess outputs; confirm before removing.
    converter.allow_custom_ops = True
    converter.quantized_input_stats = (
        {input_name: input_stats} if is_quant else None)
    print("Converting...")
    tflite_model = converter.convert()
    # Context manager guarantees the file is flushed and closed, even if the
    # write raises, instead of relying on garbage collection.
    with open(out_name, "wb") as model_file:
        model_file.write(tflite_model)
    print("tflite file: {}".format(out_name))
# Command-line entry point: convert the model at argv[1] (a frozen .pb file
# or a SavedModel directory) to a .tflite file written next to it.
if __name__ == "__main__":
    import os

    if len(sys.argv) < 2:
        sys.exit("usage: {} <frozen_graph.pb | saved_model_dir>".format(
            sys.argv[0]))
    path = sys.argv[1]
    # Quantized conversion is inferred from the whole path, so a parent
    # directory whose name contains "quant" also selects it.
    is_quant = "quant" in path.lower()
    print("is_quant: {}".format(is_quant))
    if path.endswith(".pb"):
        out_name = path[:-3] + ".tflite"
        converter = from_frozen_graph(path)
    else:
        # Normalize away trailing separators so "model_dir/" produces
        # "model_dir.tflite" rather than a hidden "model_dir/.tflite".
        out_name = os.path.normpath(path) + ".tflite"
        converter = from_saved_model(path)
    convert(converter, out_name, is_quant)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment