@dudeperf3ct
Created August 2, 2022 14:09
Export to TFLite
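# Example invocations (illustrative file and path names; assumes this script is
# saved as export_tflite.py):
#   python export_tflite.py --model-path models/classifier.h5 --fp16
#   python export_tflite.py --model-path models/classifier.h5 --int8 --test-dir data/test --img-height 100 --img-width 100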
from functools import partial
import numpy as np
import os
import glob
import argparse
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from config.config import logger


# For testing purposes: yields random images as a dummy representative dataset
def test_representative_dataset_gen(img_height, img_width):
    for _ in range(100):
        data = np.random.rand(1, img_height, img_width, 3)
        yield [data.astype(np.float32)]


def load_test_data(args):
    """Load sample images from args.test_dir as a batched tf.data.Dataset."""
    if args.test_dir != "":
        files = glob.glob(f"{args.test_dir}/*")
        arr = []
        for i in range(len(files)):
            im = np.array(
                tf.keras.utils.load_img(
                    files[i], target_size=(args.img_height, args.img_width)
                )
            )
            im = im.astype(np.float32, copy=False)
            arr.append(im)
        arr = tf.convert_to_tensor(np.array(arr, dtype="float32"))
        ds = tf.data.Dataset.from_tensor_slices(arr).batch(1)
        return ds


def representative_dataset_gen(ds):
    # Yield a small number of real samples for the converter's representative dataset
    for input_value in ds.take(10):
        yield [input_value]


# TODO: implement representative_dataset that takes in a directory of images and returns a dataset for int8 quantization
def export_tf_model(args):
    """Export TensorFlow model to TensorFlow Lite format.

    Parameters
    ----------
    args : argparse.Namespace
        Arguments parsed from the command line.
    """
    logger.info("================= Converting .h5 to .tflite ================")
    # path to the saved TF model (.h5)
    if args.model_path.endswith(".h5") and os.path.isfile(args.model_path):
        model = tf.keras.models.load_model(args.model_path)
    else:
        raise ValueError("TF Model must end with .h5")
    # Convert the model
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model_name = ""
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # FP16 quantization: roughly halves model size with minimal accuracy loss
    if args.fp16 and not args.int8:
        logger.info("FP16 quantization enabled")
        converter.target_spec.supported_types = [tf.float16]
        tflite_model_name = "model_fp16.tflite"
    # INT8 quantization
    elif args.int8 and not args.fp16:
        logger.info("INT8 quantization enabled")
        # full-integer quantization requires a representative dataset,
        # unless this is a test run that only checks file sizes
        if args.test_dir == "" and not args.test_run:
            logger.error(
                "Please provide a representative dataset for INT8 quantization and set image height and width accordingly"
            )
            raise ValueError(
                "Please provide a representative dataset for INT8 quantization"
            )
        if args.test_run:
            logger.info("Running using dummy representative dataset")
            converter.representative_dataset = partial(
                test_representative_dataset_gen, args.img_height, args.img_width
            )
        else:
            logger.info(
                "Creating representative dataset using samples from {}".format(
                    args.test_dir
                )
            )
            ds = load_test_data(args)
            converter.representative_dataset = partial(representative_dataset_gen, ds)
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.uint8  # or tf.int8
        converter.inference_output_type = tf.uint8  # or tf.int8
        tflite_model_name = "model_int8.tflite"
    else:
        logger.error("Specify exactly one of --fp16 or --int8")
        raise NotImplementedError("Specify exactly one of --fp16 or --int8")
    tflite_model = converter.convert()
    # Save the model.
    root_dir = os.path.dirname(args.model_path)
    with open(os.path.join(root_dir, tflite_model_name), "wb") as f:
        f.write(tflite_model)
    # Compare model sizes (in MB)
    original_model_size = float(os.path.getsize(args.model_path)) / 1e6
    tflite_model_size = float(
        os.path.getsize(os.path.join(root_dir, tflite_model_name)) / 1e6
    )
    logger.info("==================== Stats =========================")
    logger.info(
        "TF Lite model saved at: {}".format(os.path.join(root_dir, tflite_model_name))
    )
    logger.info(f"Original TF Model size: {original_model_size} MB")
    logger.info(f"TF Lite size: {tflite_model_size} MB")
    logger.info(f"Reduction in size: {original_model_size - tflite_model_size:.2f} MB")
    logger.info(
        f"Reduction percentage: {(original_model_size - tflite_model_size) / original_model_size * 100:.2f} %"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export to TF Lite")
    parser.add_argument(
        "--model-path", type=str, required=True, help="Path to saved TF Model"
    )
    parser.add_argument(
        "--img-height", type=int, default=100, help="Image height used for training"
    )
    parser.add_argument(
        "--img-width", type=int, default=100, help="Image width used for training"
    )
    parser.add_argument("--fp16", action="store_true", help="Enable FP16 quantization")
    parser.add_argument("--int8", action="store_true", help="Enable INT8 quantization")
    parser.add_argument(
        "--test-run", action="store_true", help="Use dummy representative dataset"
    )
    parser.add_argument("--test-dir", type=str, default="", help="Path to test data")
    args = parser.parse_args()
    if args.test_dir == "" and args.int8 and not args.test_run:
        logger.error("INT8 quantization requires test data")
        raise ValueError("INT8 quantization requires test data")
    if args.int8:
        logger.info(
            "Set proper image height and width for INT8 quantization (default = 100)"
        )
    export_tf_model(args)
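
To sanity-check the exported file, it can be loaded back with the TF Lite interpreter and run on a dummy input. A minimal sketch, assuming the FP16 variant was written; the model path is a placeholder and should point at whichever .tflite file the script produced:

    import numpy as np
    import tensorflow as tf

    # Load the exported model (placeholder path)
    interpreter = tf.lite.Interpreter(model_path="model_fp16.tflite")
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Feed a random input with the shape and dtype the model expects
    dummy = np.random.rand(*input_details[0]["shape"]).astype(input_details[0]["dtype"])
    interpreter.set_tensor(input_details[0]["index"], dummy)
    interpreter.invoke()

    output = interpreter.get_tensor(output_details[0]["index"])
    print("Output shape:", output.shape)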