Skip to content

Instantly share code, notes, and snippets.

@Namburger
Last active May 26, 2020 14:12
Show Gist options
  • Save Namburger/f44a938886ad4a0325ca2f30263fcee0 to your computer and use it in GitHub Desktop.
Save Namburger/f44a938886ad4a0325ca2f30263fcee0 to your computer and use it in GitHub Desktop.
Example code for post training quantization with tensorflow from_frozen_graph API (deprecated in tensorflow2.0).
# More info on Post Training Quantization here:
# https://www.tensorflow.org/lite/performance/post_training_quantization
# The from_frozen_graph API is not in tf2.0 but can still be used via tf.compat.v1.lite; more on this API:
# https://www.tensorflow.org/api_docs/python/tf/compat/v1/lite/TFLiteConverter#from_frozen_graph
# This is an example for converting a frozen graph model to a fully quantized tflite model
# The model used here is http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192.tgz
# Note that with post training quantization, it is sometimes not guaranteed that the model will be fully quantized.
import sys, os, glob
import tensorflow as tf
import pathlib
import numpy as np
# --- CLI argument check and TF runtime setup --------------------------------
if len(sys.argv) != 2:
    # Report usage on stderr and exit non-zero so callers can detect misuse
    # (the original bare exit() returned status 0 and depends on the `site`
    # module being loaded).
    print(f'Usage: {sys.argv[0]} <frozen_graph_file>', file=sys.stderr)
    sys.exit(1)

# Eager execution + verbose logging so conversion progress is visible.
tf.compat.v1.enable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
def fake_representative_data_gen(num_samples=100, input_shape=(1, 224, 224, 3)):
    """Yield random float32 batches as a representative dataset for quantization.

    Post-training full-integer quantization needs sample inputs to calibrate
    activation ranges. Random noise only exercises the conversion pipeline —
    it will not produce well-calibrated ranges for a real model; feed real
    preprocessed images for production use.

    Args:
        num_samples: number of calibration batches to yield (default 100,
            matching the original hard-coded loop).
        input_shape: shape of each fake batch; must match the model's input.
            NOTE(review): the checkpoint referenced at the top of this file is
            mobilenet_v1_0.25_192 (192x192) while the default here is 224x224
            — confirm which input size the frozen graph actually expects.

    Yields:
        A single-element list containing one float32 ndarray of input_shape,
        the structure TFLiteConverter.representative_dataset expects.
    """
    for _ in range(num_samples):
        fake_image = np.random.random(input_shape).astype(np.float32)
        yield [fake_image]
# --- Convert the frozen graph to a fully quantized TFLite flatbuffer --------
graph_path = sys.argv[1]
input_tensors = ['input']
output_tensors = ['MobilenetV1/Predictions/Reshape_1']

# Build a v1 converter (from_frozen_graph is absent from the tf2 namespace).
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_path, input_tensors, output_tensors)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = fake_representative_data_gen
# Restrict to int8 builtin kernels and use uint8 tensors at the model boundary.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()

# Write the result under ./output/, creating the directory if needed.
out_dir = pathlib.Path(os.getcwd(), 'output')
out_dir.mkdir(exist_ok=True, parents=True)
out_file = out_dir / 'mobilenet_v1_0.25_224_quant.tflite'
out_file.write_bytes(tflite_model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment