Created January 14, 2024 15:05
-
-
Save hu8813/cfc2249768bda21b19b5ab808e17b801 to your computer and use it in GitHub Desktop.
Static_quant.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import cv2 | |
import os | |
from onnxruntime.quantization import quantize_static, CalibrationMethod, CalibrationDataReader, QuantType, QuantFormat | |
class ImageDataReader(CalibrationDataReader):
    """Calibration data reader that streams preprocessed images from a folder.

    ONNX Runtime's ``quantize_static`` calls ``get_next`` repeatedly; each call
    returns one input-feed dict ``{input_name: batch}``, and ``None`` marks the
    end of the calibration data.

    Args:
        image_folder: Directory containing the calibration images. Only regular
            files whose (case-insensitive) extension is in ``IMAGE_EXTENSIONS``
            are used, in sorted order.
        input_shape: Model input shape, read as ``(N, C, H, W)``; only ``H``
            (index 2) and ``W`` (index 3) are used, to resize each image.
        num_samples: If not ``None``, use at most this many images.
        input_name: Name of the model graph input to feed. Defaults to
            ``'images'``.

    Raises:
        ValueError: If no usable image files remain after filtering/limiting.
    """

    # Lower-case file extensions treated as images; any other directory entry
    # (sub-folders, label .txt files, .DS_Store, ...) is skipped.
    IMAGE_EXTENSIONS = (
        ".bmp", ".jpeg", ".jpg", ".jpe", ".png", ".webp",
        ".tif", ".tiff", ".pbm", ".pgm", ".ppm",
    )

    def __init__(self, image_folder, input_shape, num_samples=None, input_name="images"):
        self.image_folder = image_folder
        self.input_shape = input_shape
        self.input_name = input_name
        # os.listdir order is arbitrary; sort so the calibration subset (and
        # therefore the computed quantization ranges) is reproducible.
        self.images = sorted(
            name
            for name in os.listdir(image_folder)
            if os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
            and os.path.isfile(os.path.join(image_folder, name))
        )
        if num_samples is not None:
            self.images = self.images[:num_samples]
        if not self.images:
            raise ValueError(f"No calibration images found in {image_folder!r}")
        self.current_index = 0

    def get_next(self):
        """Return the next input-feed dict, or ``None`` once all images are used."""
        if self.current_index >= len(self.images):
            return None
        image_path = os.path.join(self.image_folder, self.images[self.current_index])
        input_data = self.preprocess_image(image_path)
        self.current_index += 1
        return {self.input_name: input_data}

    def rewind(self):
        """Restart iteration from the first image so the reader can be reused."""
        self.current_index = 0

    def preprocess_image(self, image_path):
        """Load one image as a ``(1, C, H, W)`` float32 batch scaled to [0, 1].

        Raises:
            ValueError: If OpenCV cannot read or decode the file.
        """
        # cv2.imread returns None (it does not raise) when a file cannot be
        # read or decoded; fail here with a clear message.
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")
        height, width = self.input_shape[2], self.input_shape[3]
        # cv2.resize expects dsize as (width, height), not (height, width).
        image = cv2.resize(image, (width, height))
        # NOTE(review): cv2.imread yields BGR channel order. If the model was
        # trained on RGB input, add cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # here -- confirm against the model's inference preprocessing.
        chw = image.transpose(2, 0, 1)  # HWC -> CHW
        # transpose returns a strided view; take a C-contiguous float32 copy,
        # then scale pixel values to [0, 1].
        batch = np.ascontiguousarray(chw, dtype=np.float32) / 255.0
        return batch[np.newaxis, ...]  # add batch dimension -> (1, C, H, W)
# Paths and parameters -- update these for your model and calibration data.
onnx_model_input_path = "model_20240112_1939-infer.onnx"
onnx_model_output_path = "model_20240112_1939_int8.onnx"
image_folder = "test_images"
input_shape = (1, 3, 640, 640)  # (N, C, H, W); update if necessary
num_calibration_samples = 100  # update if necessary

calibration_data_reader = ImageDataReader(
    image_folder, input_shape, num_samples=num_calibration_samples
)

# Quantize the model to int8 (signed activations and weights) in QDQ format,
# using per-tensor weights and MinMax calibration over the reader's images.
# quantize_static writes the quantized model to model_output and returns
# None, so its result is intentionally not bound to a name.
quantize_static(
    model_input=onnx_model_input_path,
    model_output=onnx_model_output_path,
    calibration_data_reader=calibration_data_reader,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    quant_format=QuantFormat.QDQ,
    per_channel=False,
    calibrate_method=CalibrationMethod.MinMax,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.