Skip to content

Instantly share code, notes, and snippets.

@hu8813
Created January 14, 2024 15:05
Show Gist options
  • Save hu8813/cfc2249768bda21b19b5ab808e17b801 to your computer and use it in GitHub Desktop.
Save hu8813/cfc2249768bda21b19b5ab808e17b801 to your computer and use it in GitHub Desktop.
Static_quant.py
import numpy as np
import cv2
import os
from onnxruntime.quantization import quantize_static, CalibrationMethod, CalibrationDataReader, QuantType, QuantFormat
class ImageDataReader(CalibrationDataReader):
    """Calibration data reader that serves preprocessed images from a folder.

    Each call to ``get_next`` loads one image, resizes it to the model's
    spatial size, converts it to a float32 NCHW tensor scaled to [0, 1], and
    returns it in a dict keyed by the model's input name.

    Args:
        image_folder: Directory containing the calibration images.
        input_shape: Model input shape as ``(batch, channels, height, width)``.
            Only the height and width entries are used, for resizing.
        num_samples: Optional cap on the number of directory entries used.
            ``None`` uses every entry in the folder.
        input_name: Name of the model's input tensor. Defaults to ``"images"``.
    """

    def __init__(self, image_folder, input_shape, num_samples=None, input_name="images"):
        self.image_folder = image_folder
        self.input_shape = input_shape
        self.input_name = input_name
        # os.listdir returns entries in arbitrary order; sorting makes the
        # calibration subset chosen by num_samples reproducible across runs.
        self.images = sorted(os.listdir(image_folder))
        if num_samples is not None:
            self.images = self.images[:num_samples]
        self.current_index = 0

    def get_next(self):
        """Return the next calibration input, or ``None`` when exhausted.

        Entries that cannot be decoded as images (stray files, directories)
        are skipped with a warning instead of aborting the calibration run.

        Returns:
            A dict mapping the input name to a ``(1, C, H, W)`` float32 array,
            or ``None`` once every entry has been consumed.
        """
        import warnings  # local import: only needed on the skip path

        while self.current_index < len(self.images):
            image_path = os.path.join(self.image_folder, self.images[self.current_index])
            self.current_index += 1
            input_data = self.preprocess_image(image_path)
            if input_data is None:
                warnings.warn(
                    f"Skipping unreadable calibration image: {image_path}",
                    stacklevel=2,
                )
                continue
            return {self.input_name: input_data}
        return None

    def preprocess_image(self, image_path):
        """Load an image and convert it to a normalized NCHW float32 tensor.

        Args:
            image_path: Path of the image file to load.

        Returns:
            A ``(1, C, H, W)`` float32 array with values in [0, 1], or ``None``
            if the file could not be read as an image.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            return None
        # NOTE(review): cv2.imread yields BGR channel order; confirm whether the
        # model was trained on RGB input and convert here if so.
        height, width = self.input_shape[2], self.input_shape[3]
        # cv2.resize expects dsize as (width, height), not (height, width).
        image = cv2.resize(image, (width, height))
        image = image.transpose(2, 0, 1)  # Change data layout from HWC to CHW
        image = image.astype(np.float32) / 255.0  # Normalize to [0, 1]
        image = np.expand_dims(image, axis=0)  # Add batch dimension
        return image
# Update these paths and parameters
onnx_model_input_path = "model_20240112_1939-infer.onnx"  # float model to quantize
onnx_model_output_path = "model_20240112_1939_int8.onnx"  # where the int8 model is written
image_folder = "test_images"  # folder of calibration images
input_shape = (1, 3, 640, 640)  # (batch, channels, height, width); update if necessary
num_calibration_samples = 100  # Update if necessary
calibration_data_reader = ImageDataReader(image_folder, input_shape, num_samples=num_calibration_samples)

# Quantize the model to int8.
# quantize_static saves the quantized model to `model_output` and returns
# None, so its result is intentionally not bound to a variable.
quantize_static(
    model_input=onnx_model_input_path,
    model_output=onnx_model_output_path,
    calibration_data_reader=calibration_data_reader,
    activation_type=QuantType.QInt8,
    weight_type=QuantType.QInt8,
    quant_format=QuantFormat.QDQ,
    per_channel=False,
    calibrate_method=CalibrationMethod.MinMax,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment