Skip to content

Instantly share code, notes, and snippets.

@aurotripathy
Created April 18, 2023 20:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aurotripathy/914701add7837a8d1fb114b3da1e1318 to your computer and use it in GitHub Desktop.
Save aurotripathy/914701add7837a8d1fb114b3da1e1318 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Calibrate and quantize model.onnx with the Furiosa SDK, writing model_quantized.dfg.
Calibration images are read from data/images (the data folder of
https://github.com/derronqi/yolov7-face).
"""
import sys
import onnx
import torch
import torchvision
from torchvision import transforms
import tqdm
from furiosa.optimizer import optimize_model
from furiosa.quantizer import quantize, Calibrator, CalibrationMethod
from utils.datasets import ImageDataset
from torch.utils.data import DataLoader
import numpy as np
from pudb import set_trace
def create_quantized_dfg(
    model_path="model.onnx",
    source="data/images",
    output_path="model_quantized.dfg",
    imgsz=640,
    stride=1,
):
    """Calibrate and quantize an ONNX model, writing the result to *output_path*.

    The model is loaded from *model_path* and passed through
    ``furiosa.optimizer.optimize_model``. A ``MIN_MAX_ASYM`` calibrator then
    collects ranges over every image produced by ``ImageDataset(source)``.
    Finally, ``furiosa.quantizer.quantize`` is applied and its bytes are
    written to *output_path*.

    Args:
        model_path: Path of the ONNX model to quantize.
        source: Folder of calibration images, read through ``ImageDataset``.
        output_path: Destination file for the quantized model bytes.
        imgsz: ``img_size`` passed to ``ImageDataset``.
        stride: ``stride`` passed to ``ImageDataset``.

    Returns:
        None. The ``__main__`` guard passes this to ``sys.exit``, so success
        must keep mapping to exit status 0.

    Raises:
        RuntimeError: If *source* yields no calibration images. Without this
            check, the failure would only surface later, inside
            ``compute_range``.
    """
    model = onnx.load_model(model_path)

    calib_dataset = ImageDataset(source, img_size=imgsz, stride=stride)
    # Batches are only turned into NumPy arrays on the host (img.numpy() below)
    # and never moved to an accelerator, so pinned memory buys nothing here.
    calib_dataloader = DataLoader(
        calib_dataset,
        batch_size=1,
        num_workers=1,
        sampler=None,
        pin_memory=False,
    )

    model = optimize_model(model)
    calibrator = Calibrator(model, CalibrationMethod.MIN_MAX_ASYM)

    num_images = 0
    for _, img in tqdm.tqdm(
        calib_dataloader,
        desc="Calibration",
        unit="images",
        mininterval=0.5,
    ):
        # NOTE(review): img is fed to the model exactly as ImageDataset yields
        # it. An earlier torch.permute(img, (0, 3, 2, 1)) was computed but
        # never used; confirm this layout matches the model's input.
        # One call per image: a single dataset containing one input list.
        calibrator.collect_data([[img.numpy().astype(np.float32)]])
        num_images += 1

    if num_images == 0:
        raise RuntimeError(f"no calibration images found in {source!r}")

    ranges = calibrator.compute_range()
    model_quantized = quantize(model, ranges)
    with open(output_path, "wb") as f:
        f.write(bytes(model_quantized))
if __name__ == "__main__":
    # Script entry point. create_quantized_dfg() has no return value, so
    # sys.exit() receives None and the process exits with status 0.
    status = create_quantized_dfg()
    sys.exit(status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment