@vkuzo
Created April 6, 2021 15:42
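
This patch against YOLOv5's detect.py wraps the loaded FP32 model in a QuantStub/DeQuantStub module and runs PyTorch eager-mode post-training static quantization with the qnnpack backend (calibration is still a TODO), then tries a quantized forward pass alongside the float one, reproducing the setup from the discuss.pytorch.org thread linked in the diff.
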
diff --git a/detect.py b/detect.py
index 2a4d6f4..2c6a832 100644
--- a/detect.py
+++ b/detect.py
@@ -14,6 +14,20 @@ from utils.general import check_img_size, check_requirements, check_imshow, non_
 from utils.plots import plot_one_box
 from utils.torch_utils import select_device, load_classifier, time_synchronized
 
+class QuantizationModule(torch.nn.Module):
+    def __init__(self, model):
+        super(QuantizationModule, self).__init__()
+        self.model = model
+        self.quant = torch.quantization.QuantStub()  # marks where float tensors become quantized
+        self.dequant = torch.quantization.DeQuantStub()  # marks where quantized tensors become float again
+
+    def forward(self, x):
+        print(x)  # debug: inspect the input tensor
+        x = self.quant(x)
+        x = self.model(x)
+        x = self.dequant(x)
+        return x
+
 
 def detect(save_img=False):
     source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
@@ -33,6 +47,21 @@ def detect(save_img=False):
     # Load model
     model = attempt_load(weights, map_location=device)  # load FP32 model
     stride = int(model.stride.max())  # model stride
+    # Get names and colors (moved up from below: read from the FP32 model before it is wrapped)
+    names = model.module.names if hasattr(model, 'module') else model.names
+    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
+
+    print(model)
+
+    # quantize (repro https://discuss.pytorch.org/t/static-quantization-for-yolov5-model/116446)
+    model = QuantizationModule(model)  # model here is loaded by the repo's code in export.py
+    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
+    torch.backends.quantized.engine = "qnnpack"
+    model_static_quantized = torch.quantization.prepare(model, inplace=False)
+    # TODO: add calibration for PTQ
+    model_static_quantized = torch.quantization.convert(model_static_quantized.cpu(), inplace=False)
+    print(model_static_quantized)
+
     imgsz = check_img_size(imgsz, s=stride)  # check img_size
     if half:
         model.half()  # to FP16
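
The "# TODO: add calibration for PTQ" line above is the step this repro skips: between prepare() and convert(), the observers that prepare() inserts must see representative inputs so they can record activation ranges, and convert() then derives scales and zero points from those ranges; without calibration it falls back to default quantization parameters. A minimal sketch of that step, assuming a hypothetical calibration_images iterable of preprocessed float tensors, which is not part of the original patch:

model_static_quantized = torch.quantization.prepare(model, inplace=False)

# Calibration: run representative inputs through the prepared model so the
# observers can record min/max activation statistics for every layer.
with torch.no_grad():
    for img in calibration_images:  # hypothetical iterable of (1, 3, H, W) float tensors
        model_static_quantized(img)

model_static_quantized = torch.quantization.convert(model_static_quantized.cpu(), inplace=False)
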
@@ -52,10 +81,6 @@
     else:
         dataset = LoadImages(source, img_size=imgsz, stride=stride)
 
-    # Get names and colors
-    names = model.module.names if hasattr(model, 'module') else model.names
-    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
-
     # Run inference
     if device.type != 'cpu':
         model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
@@ -69,7 +94,11 @@
 
         # Inference
         t1 = time_synchronized()
-        pred = model(img, augment=opt.augment)[0]
+        # pred = model(img, augment=opt.augment)[0]
+        pred = model(img)[0]  # the wrapper's forward takes no augment kwarg
+
+        # try quantized forward
+        pred_q = model_static_quantized(img.float().cpu())[0]
 
         # Apply NMS
         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
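
For reference, here is the same eager-mode recipe in a self-contained form that runs without YOLOv5. The Conv/ReLU toy model and input shapes are stand-ins; the rest is the standard QuantStub/DeQuantStub wrap, prepare, calibrate, convert sequence the patch follows:

import torch
import torch.nn as nn

class QuantizationModule(nn.Module):
    def __init__(self, model):
        super(QuantizationModule, self).__init__()
        self.model = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))  # stand-in for the detector

model = QuantizationModule(float_model).eval()  # post-training quantization expects eval mode
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.backends.quantized.engine = 'qnnpack'

prepared = torch.quantization.prepare(model, inplace=False)
with torch.no_grad():
    for _ in range(10):  # calibration pass; random inputs here only for the sketch
        prepared(torch.randn(1, 3, 64, 64))
quantized = torch.quantization.convert(prepared.cpu(), inplace=False)

out = quantized(torch.randn(1, 3, 64, 64))  # runs int8 conv kernels internally
print(out.shape)

Note that convert() only yields a usable model when every op inside the wrapper has a quantized implementation in the selected backend; ops without one are left as float modules and raise errors once they receive quantized tensors at inference time, which is the usual failure mode when applying this recipe to a full detection network like YOLOv5.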