Created April 6, 2021 15:42
Save vkuzo/d863f53c8809198b3e0a4fd2af1563a7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/detect.py b/detect.py | |
index 2a4d6f4..2c6a832 100644 | |
--- a/detect.py | |
+++ b/detect.py | |
@@ -14,6 +14,20 @@ from utils.general import check_img_size, check_requirements, check_imshow, non_ | |
from utils.plots import plot_one_box | |
from utils.torch_utils import select_device, load_classifier, time_synchronized | |
class QuantizationModule(torch.nn.Module):
    """Wrap a float model with quant/dequant stubs for eager-mode static quantization.

    ``torch.quantization.convert`` swaps ``self.quant`` for a quantize op and
    ``self.dequant`` for a dequantize op, so the wrapped model receives
    quantized input while callers keep receiving float tensors. Before
    conversion both stubs pass tensors through unchanged.

    Args:
        model: the float ``torch.nn.Module`` to wrap.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Quantize ``x``, run the wrapped model, and dequantize its output.

        Args:
            x: input tensor for the wrapped model.

        Returns:
            The wrapped model's output with every tensor passed through
            ``self.dequant``; tuple/list outputs keep their structure.
        """
        x = self.quant(x)
        out = self.model(x)
        return self._dequantize_output(out)

    def _dequantize_output(self, out):
        """Apply ``self.dequant`` to each tensor in a possibly nested tuple/list."""
        # A converted DeQuantize module only accepts a single tensor, but the
        # wrapped model may return a tuple (the caller indexes the result with [0]).
        if isinstance(out, list):
            return [self._dequantize_output(item) for item in out]
        if isinstance(out, tuple):
            items = [self._dequantize_output(item) for item in out]
            # namedtuples are rebuilt from positional fields; plain tuples from a list
            return type(out)(*items) if hasattr(out, "_fields") else tuple(items)
        if torch.is_tensor(out):
            return self.dequant(out)
        # Non-tensor leaves (e.g. None) are returned untouched.
        return out
+ | |
def detect(save_img=False): | |
source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size | |
@@ -33,6 +47,21 @@ def detect(save_img=False): | |
# Load model | |
model = attempt_load(weights, map_location=device) # load FP32 model | |
stride = int(model.stride.max()) # model stride | |
+ # Get names and colors | |
+ names = model.module.names if hasattr(model, 'module') else model.names | |
+ colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] | |
+ | |
+ print(model) | |
+ | |
+ # quantize (repro https://discuss.pytorch.org/t/static-quantization-for-yolov5-model/116446) | |
+ model = QuantizationModule(model) # model here is loaded by the code in repo in export.py | |
+ model.qconfig = torch.quantization.get_default_qconfig('qnnpack') | |
+ torch.backends.quantized.engine = "qnnpack" | |
+ model_static_quantized = torch.quantization.prepare(model, inplace=False) | |
+ # TODO: add calibration for PTQ | |
+ model_static_quantized = torch.quantization.convert(model_static_quantized.cpu(), inplace=False) | |
+ print(model_static_quantized) | |
+ | |
imgsz = check_img_size(imgsz, s=stride) # check img_size | |
if half: | |
model.half() # to FP16 | |
@@ -52,10 +81,6 @@ def detect(save_img=False): | |
else: | |
dataset = LoadImages(source, img_size=imgsz, stride=stride) | |
- # Get names and colors | |
- names = model.module.names if hasattr(model, 'module') else model.names | |
- colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] | |
- | |
# Run inference | |
if device.type != 'cpu': | |
model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once | |
@@ -69,7 +94,11 @@ def detect(save_img=False): | |
# Inference | |
t1 = time_synchronized() | |
- pred = model(img, augment=opt.augment)[0] | |
+ # pred = model(img, augment=opt.augment)[0] | |
+ pred = model(img)[0] | |
+ | |
+ # try quantized forward | |
+ pred_q = model_static_quantized(img.float().cpu())[0] | |
# Apply NMS | |
pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.