"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
This script is tested with TensorFlow v2.12.1 and OpenVINO v2023.1.0
Usage Example below (with required parameters):
python bit_ov_model_quantization.py
--gt_labels ./<path_to>/ground_truth_ilsvrc2012_val.txt
--dataset_dir ./<path-to-dataset>/ilsvrc2012_val_ds/
--bit_m_tf ./<path-to-tf>/model
--bit_ov_fp32 ./<path-to-ov>/fp32_ir_model
"""
import argparse
import logging
import os
import re
import sys

# Reduce log noise; TF_CPP_MIN_LOG_LEVEL must be set before TensorFlow is imported to take effect
logging.basicConfig(level=logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import pandas as pd
import nncf
import openvino.runtime as ov
from openvino.runtime import Core
import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image
from sklearn.metrics import accuracy_score

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
ie = Core()
MAX_PREDS = 1          # Keep only the top-1 prediction
BATCH_SIZE = 1
IMG_SIZE = (224, 224)  # Default ImageNet image size
NUM_CLASSES = 1000     # Number of classes in the ImageNet-1k dataset
# Data transform function for NNCF calibration
def nncf_transform(image, label):
    image = tf.io.decode_jpeg(tf.io.read_file(image), channels=3)
    image = tf.image.resize(image, IMG_SIZE)
    return image
# Data transform function for ImageNet validation dataset
def val_transform(image_path, label):
    image = tf.io.decode_jpeg(tf.io.read_file(image_path), channels=3)
    image = tf.image.resize(image, IMG_SIZE)
    img_reshaped = tf.reshape(image, [IMG_SIZE[0], IMG_SIZE[1], 3])
    image = tf.image.convert_image_dtype(img_reshaped, tf.float32)
    return image, label
# Validation dataset split (a slice of the validation set is reused for NNCF calibration)
def get_val_data_split(tf_dataset_, train_split=0.7, val_split=0.3,
                       shuffle=True, shuffle_size=50000):
    ds = tf_dataset_
    if shuffle:
        ds = tf_dataset_.shuffle(shuffle_size, seed=12)
    train_size = int(train_split * shuffle_size)
    val_size = int(val_split * shuffle_size)
    val_ds = ds.skip(train_size).take(val_size)
    return val_ds
# OpenVINO IR model inference validation
def ov_infer_validate(model: ov.Model,
                      val_loader: tf.data.Dataset) -> list:
    # Fix the input shape in case the IR was converted with a dynamic batch dimension
    model.reshape([1, IMG_SIZE[0], IMG_SIZE[1], 3])
    compiled_model = ov.compile_model(model)
    output = compiled_model.outputs[0]
    ov_predictions = []
    for img, label in val_loader:
        pred = compiled_model(img)[output]
        ov_result = tf.reshape(pred, [-1])
        top_label_idx = np.argsort(ov_result)[-MAX_PREDS:][::-1]
        ov_predictions.append(top_label_idx)
    return ov_predictions
# OpenVINO IR model NNCF quantization
def quantize(ov_model, calibration_dataset):
    print("Started NNCF quantization process")
    ov_quantized_model = nncf.quantize(ov_model, calibration_dataset, fast_bias_correction=False)
    return ov_quantized_model
# OpenVINO FP32 IR model inference
def ov_fp32_predictions(ov_fp32_model, validation_dataset):
    # Load the OV model; compilation happens inside ov_infer_validate
    ov_model = ie.read_model(ov_fp32_model)
    print("Starting OV FP32 Model Inference...!!!")
    ov_fp32_pred = ov_infer_validate(ov_model, validation_dataset)
    return ov_fp32_pred
def nncf_quantize_int8_pred_results(ov_fp32_model, calibration_dataset,
                                    validation_dataset, ov_int8_model):
    # Load the FP32 OV model
    ov_model = ie.read_model(ov_fp32_model)
    # NNCF quantization of the OpenVINO IR model
    int8_ov_model = quantize(ov_model, calibration_dataset)
    ov.serialize(int8_ov_model, ov_int8_model)
    print("NNCF Quantization Process completed..!!!")
    # Reload the serialized INT8 IR and run validation inference
    ov_int8_model = ie.read_model(ov_int8_model)
    print("Starting OV INT8 Model Inference...!!!")
    ov_int8_pred = ov_infer_validate(ov_int8_model, validation_dataset)
    return ov_int8_pred
def tf_inference(tf_saved_model_path, val_loader: tf.data.Dataset):
    tf_model = tf.keras.models.load_model(tf_saved_model_path)
    print("Starting TF FP32 Model Inference...!!!")
    tf_predictions = []
    for img, label in val_loader:
        tf_result = tf_model.predict(img, verbose=0)
        tf_result = tf.reshape(tf_result, [-1])
        top_label_idx = np.argsort(tf_result)[-MAX_PREDS:][::-1]
        tf_predictions.append(top_label_idx)
    return tf_predictions
"""
Module: bit_classificaiton
Description: API to run BiT classificaiton OpenVINO IR model INT8 Quantization on using NNCF and
perfom accuracy metrics for TF FP32, OV FP32 and OV INT8 on ImageNet2012 Validation dataset
"""
def bit_classification(args):
    ip_shape = args.inp_shape
    if isinstance(ip_shape, str):
        ip_shape = [int(i) for i in ip_shape.split(",")]
    if len(ip_shape) != 4:
        sys.exit("Input shape error. Set shape 'N,W,H,C'. For example: '1,224,224,3'")

    # ImageNet 2012 validation dataset used for TF and OV FP32 accuracy testing,
    # e.g. dataset_dir = "../dataset/ilsvrc2012_val/1.0/"
    dataset_dir = args.dataset_dir + "*.JPEG"
    tf_dataset = tf.data.Dataset.list_files(dataset_dir)

    gt_labels = open(args.gt_labels)
    val_labels = []
    for l in gt_labels:
        val_labels.append(str(l))

    # Generating the ImageNet 2012 validation (image, label) pairs
    val_images = []
    val_labels_in_img_order = []
    for i, v in enumerate(tf_dataset):
        img_path = str(v.numpy())
        # The image index is encoded in the file name, e.g. ILSVRC2012_val_00000001.JPEG
        img_id = int(img_path.split('/')[-1].split('_')[-1].split('.')[0])
        val_images.append(img_path[2:-1])  # strip the b'...' wrapper from the bytes repr
        val_labels_in_img_order.append(int(re.sub(r'\n', '', val_labels[img_id - 1])))
    val_df = pd.DataFrame(data={'images': val_images, 'label': val_labels_in_img_order})

    # Converting the ImageNet 2012 validation records into a tf.data.Dataset
    tf_dataset_ = tf.data.Dataset.from_tensor_slices((list(val_df['images'].values), val_df['label'].values))
    imgnet2012_val_dataset = tf_dataset_.map(val_transform).batch(BATCH_SIZE)

    # Dataset split used for NNCF calibration
    img2012_val_split_for_calib = get_val_data_split(tf_dataset_, train_split=0.7,
                                                     val_split=0.3, shuffle=True,
                                                     shuffle_size=50000)
    img2012_val_split_for_calib = img2012_val_split_for_calib.map(nncf_transform).batch(BATCH_SIZE)
    # TF model inference
    tf_model_path = args.bit_m_tf
    print(f"Tensorflow FP32 Model {args.bit_m_tf}")
    tf_p = tf_inference(tf_model_path, imgnet2012_val_dataset)
    acc_score = accuracy_score(tf_p, val_labels_in_img_order)
    print(f"Accuracy of FP32 TF model = {acc_score}\n")

    # OpenVINO FP32 model inference
    print(f"OpenVINO FP32 IR Model {args.bit_ov_fp32}")
    ov_fp32_p = ov_fp32_predictions(args.bit_ov_fp32, imgnet2012_val_dataset)
    acc_score = accuracy_score(ov_fp32_p, val_labels_in_img_order)
    print(f"Accuracy of FP32 IR model = {acc_score}\n")

    print("Starting NNCF dataset Calibration....!!!")
    calibration_dataset = nncf.Dataset(img2012_val_split_for_calib)

    # OpenVINO IR FP32 to INT8 model quantization with NNCF and
    # INT8 prediction results on the validation dataset
    ov_int8_p = nncf_quantize_int8_pred_results(args.bit_ov_fp32, calibration_dataset,
                                                imgnet2012_val_dataset, args.bit_ov_int8)
    print(f"OpenVINO NNCF Quantized INT8 IR Model {args.bit_ov_int8}")
    acc_score = accuracy_score(ov_int8_p, val_labels_in_img_order)
    print(f"Accuracy of INT8 IR model = {acc_score}\n")

    # Optional cross-model agreement checks:
    # acc_score = accuracy_score(tf_p, ov_fp32_p)
    # print(f"TF Vs OV FP32 Accuracy Score = {acc_score}")
    # acc_score = accuracy_score(ov_fp32_p, ov_int8_p)
    # print(f"OV FP32 Vs OV INT8 Accuracy Score = {acc_score}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="BiT Classification model quantization and accuracy measurement")
optional = parser._action_groups.pop()
required=parser.add_argument_group("required arguments")
optional.add_argument("--inp_shape", type=str, help="N,W,H,C", default="1,224,224,3", required=False)
required.add_argument("--dataset_dir", type=str, help="Directory path to ImageNet2012 validation dataset", required=True)
required.add_argument("--gt_labels", type=str, help="Path to ImageNet2012 validation ds gt labels file", required=True)
required.add_argument("--bit_m_tf", type=str, help="Path to BiT TF fp32 model file", required=True)
required.add_argument("--bit_ov_fp32", type=str, help="Path to BiT OpenVINO fp32 model file", required=True)
optional.add_argument("--bit_ov_int8", type=str, help="Path to save BiT OpenVINO INT8 model file",
default="./bit_m_r50x1_1/ov/int8/saved_model.xml", required=False)
parser._action_groups.append(optional)
args = parser.parse_args()
bit_classification(args)