Created
August 2, 2022 14:09
-
-
Save dudeperf3ct/b1d761e538bbec167196c94914440d68 to your computer and use it in GitHub Desktop.
Export to TFLite
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import numpy as np | |
import os | |
import glob | |
import argparse | |
import tensorflow as tf | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
from config.config import logger | |
# For testing purposes only: dummy calibration data for --test-run.
def test_representative_dataset_gen(img_height, img_width, num_samples=100):
    """Yield random calibration samples for a dry-run INT8 conversion.

    Used with ``--test-run`` to exercise the INT8 conversion path (e.g. to
    compare file sizes) without real images. The samples are uniform noise in
    [0, 1), so the resulting quantization ranges are not meaningful for real
    inference.

    Parameters
    ----------
    img_height, img_width : int
        Spatial size of each generated sample.
    num_samples : int, optional
        Number of samples to yield (default 100, the original fixed count).

    Yields
    ------
    list of np.ndarray
        A one-element list holding a float32 array of shape
        ``(1, img_height, img_width, 3)``; this list-of-inputs form is what is
        assigned to ``converter.representative_dataset``.
    """
    for _ in range(num_samples):
        data = np.random.rand(1, img_height, img_width, 3)
        yield [data.astype(np.float32)]


# The ``test_`` prefix would make pytest collect this generator as a test in
# any test module that imports it; opt out explicitly.
test_representative_dataset_gen.__test__ = False
def load_test_data(args):
    """Load a directory of images as a batched ``tf.data.Dataset``.

    Used to build the representative (calibration) dataset for INT8
    quantization.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``test_dir`` (directory of images) plus ``img_height`` and
        ``img_width``; every image is resized to that size.

    Returns
    -------
    tf.data.Dataset
        Batches of one float32 image each, i.e. shape
        ``(1, img_height, img_width, channels)``. With ``load_img``'s default
        colour mode this should be 3 (RGB) channels. Pixel values are left
        as loaded; no normalisation is applied here.

    Raises
    ------
    ValueError
        If ``args.test_dir`` is empty or contains no regular files.
    """
    if not args.test_dir:
        # The original fell through to ``return ds`` with ``ds`` never bound.
        raise ValueError("test_dir must be set to load representative data")

    # Keep regular files only: a subdirectory would make load_img fail.
    # Sort so the calibration order is reproducible across filesystems.
    files = sorted(
        path
        for path in glob.glob(os.path.join(args.test_dir, "*"))
        if os.path.isfile(path)
    )
    if not files:
        raise ValueError(f"No image files found in {args.test_dir!r}")

    images = [
        np.asarray(
            tf.keras.utils.load_img(
                path, target_size=(args.img_height, args.img_width)
            ),
            dtype=np.float32,
        )
        for path in files
    ]
    # NOTE(review): pixels are not rescaled here; confirm the model normalises
    # its input internally, otherwise calibration ranges won't match inference.
    arr = tf.convert_to_tensor(np.stack(images))
    return tf.data.Dataset.from_tensor_slices(arr).batch(1)
def representative_dataset_gen(ds, num_samples=10):
    """Yield calibration samples from ``ds`` for INT8 quantization.

    Parameters
    ----------
    ds : tf.data.Dataset
        Dataset of model inputs, e.g. the output of ``load_test_data``.
    num_samples : int, optional
        How many elements of ``ds`` to use for calibration (default 10, the
        original fixed count).

    Yields
    ------
    list
        A one-element list wrapping each input; this is the form assigned to
        ``converter.representative_dataset``.
    """
    for input_value in ds.take(num_samples):
        yield [input_value]
# TODO: implement a representative_dataset that takes a directory of images and returns a dataset for INT8 quantization (load_test_data + representative_dataset_gen currently serve this role).
def _configure_int8(converter, args):
    """Configure ``converter`` for full-integer INT8 quantization (uint8 I/O)."""
    # Full-integer quantization requires a representative dataset.
    if args.test_run:
        # Dummy data: only useful for checking file sizes, not accuracy.
        logger.info("Running using dummy representative dataset")
        converter.representative_dataset = partial(
            test_representative_dataset_gen, args.img_height, args.img_width
        )
    else:
        logger.info(
            "Creating representative dataset using samples from {}".format(
                args.test_dir
            )
        )
        ds = load_test_data(args)
        converter.representative_dataset = partial(representative_dataset_gen, ds)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8  # or tf.int8
    converter.inference_output_type = tf.uint8  # or tf.int8


def _log_size_stats(original_path, tflite_path):
    """Log on-disk sizes (MB) of the source and converted models."""
    original_model_size = os.path.getsize(original_path) / 1e6
    tflite_model_size = os.path.getsize(tflite_path) / 1e6
    reduction = original_model_size - tflite_model_size
    logger.info("==================== Stats =========================")
    logger.info("TF Lite model saved at: {}".format(tflite_path))
    logger.info(f"Original TF Model size: {original_model_size:.2f} MB")
    logger.info(f"TF Lite size: {tflite_model_size:.2f} MB")
    logger.info(f"Reduction in size: {reduction:.2f} MB")
    logger.info(f"Reduction percentage: {reduction / original_model_size * 100:.2f} %")


def export_tf_model(args):
    """Export a Keras ``.h5`` model to TensorFlow Lite format.

    The ``.tflite`` file is written next to ``args.model_path``. Its name
    depends on the quantization mode:

    * ``fp16`` only  -> ``model_fp16.tflite`` (float16 weights)
    * ``int8`` only  -> ``model_int8.tflite`` (full-integer ops, uint8 I/O)
    * neither flag   -> ``model_dynamic_range.tflite`` (dynamic-range
      quantization; previously this case raised a misleading
      "Both fp16 and int8" error)
    * both flags     -> ``NotImplementedError``

    Parameters
    ----------
    args : argparse.Namespace
        Arguments parsed from the command line (``model_path``, ``fp16``,
        ``int8``, ``test_run``, ``test_dir``, ``img_height``, ``img_width``).

    Raises
    ------
    ValueError
        If the model path is not an existing ``.h5`` file, or if INT8 is
        requested without a ``test_dir`` and without ``test_run``.
    NotImplementedError
        If both ``fp16`` and ``int8`` are requested.
    """
    logger.info("================= Converting .h5 to .tflite ================")

    # Validate everything cheap before the (slow) model load.
    if not args.model_path.endswith(".h5"):
        raise ValueError("TF Model must end with .h5")
    if not os.path.isfile(args.model_path):
        raise ValueError(f"TF Model not found: {args.model_path}")
    if args.fp16 and args.int8:
        logger.error("Both fp16 and int8 are not supported")
        raise NotImplementedError("Both fp16 and int8 are not supported")
    # A test run uses dummy calibration data, so no directory is needed.
    if args.int8 and not args.test_dir and not args.test_run:
        logger.error(
            "Please provide representative dataset for INT8 quantization and set image height and width accordingly"
        )
        raise ValueError("Please provide representative dataset for INT8 quantization")

    model = tf.keras.models.load_model(args.model_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Optimize.DEFAULT on its own yields dynamic-range quantization; the
    # fp16 / int8 branches below refine it further.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if args.fp16:
        # fp16 gives a smaller size reduction than int8 but keeps more accuracy.
        logger.info("FP16 quantization enabled")
        converter.target_spec.supported_types = [tf.float16]
        tflite_model_name = "model_fp16.tflite"
    elif args.int8:
        logger.info("INT8 quantization enabled")
        _configure_int8(converter, args)
        tflite_model_name = "model_int8.tflite"
    else:
        logger.info("No fp16/int8 flag given: using dynamic range quantization")
        tflite_model_name = "model_dynamic_range.tflite"

    tflite_model = converter.convert()

    # Save next to the source model.
    tflite_path = os.path.join(os.path.dirname(args.model_path), tflite_model_name)
    with open(tflite_path, "wb") as f:
        f.write(tflite_model)

    _log_size_stats(args.model_path, tflite_path)
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="Export to TF Lite") | |
parser.add_argument( | |
"--model-path", type=str, required=True, help="Path to saved TF Model" | |
) | |
parser.add_argument( | |
"--img-height", type=int, default=100, help="Image height used for training" | |
) | |
parser.add_argument( | |
"--img-width", type=int, default=100, help="Image height used for training" | |
) | |
parser.add_argument("--fp16", action="store_true", help="Enable FP16 quantization") | |
parser.add_argument("--int8", action="store_true", help="Enable INT8 quantization") | |
parser.add_argument( | |
"--test-run", action="store_true", help="Use dummy representative dataset" | |
) | |
parser.add_argument("--test-dir", type=str, default="", help="Path to test data") | |
args = parser.parse_args() | |
if args.test_dir == "" and args.int8 and not args.test_run: | |
logger.error("INT8 quantization requires test data") | |
raise ValueError("INT8 quantization requires test data") | |
if args.int8: | |
logger.info( | |
"Set proper image height and width for INT8 quantization (default = 100)" | |
) | |
export_tf_model(args) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment