-
-
Save joek13/b895db0cd50a7c71a123611885057c69 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Custom TorchServe model handler for YOLOv5 models. | |
""" | |
from ts.torch_handler.base_handler import BaseHandler | |
import numpy as np | |
import base64 | |
import torch | |
import torchvision.transforms as tf | |
import torchvision | |
import io | |
from PIL import Image | |
class ModelHandler(BaseHandler):
    """TorchServe handler that serves a YOLOv5 detection model.

    ``preprocess`` decodes the request payloads into one batched float
    tensor, and ``postprocess`` runs non-maximum suppression on the raw model
    output and formats the surviving boxes as JSON-serializable dicts.
    """

    # Image size (px). Images are resized to img_size x img_size before
    # inference, and box coordinates are normalized by it afterwards.
    img_size = 640

    def __init__(self):
        # call superclass initializer
        super().__init__()

    def preprocess(self, data):
        """Converts input images to a single float tensor.

        Args:
            data (List): Request rows; each row is a mapping holding the
                image under the "data" or "body" key, as raw bytes, a base64
                string, or a nested list of numbers.

        Returns:
            Tensor: single Tensor of shape [BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE]
            placed on ``self.device``.
        """
        to_tensor = tf.ToTensor()
        resize = tf.Resize((self.img_size, self.img_size))

        images = []
        # Loading logic adapted from
        # https://github.com/pytorch/serve/blob/master/ts/torch_handler/vision_handler.py
        for row in data:
            # Compat layer: normally the envelope should just return the data
            # directly, but older versions of TorchServe didn't have envelope.
            image = row.get("data") or row.get("body")
            if isinstance(image, str):
                # The image was sent as a base64-encoded string.
                image = base64.b64decode(image)

            if isinstance(image, (bytearray, bytes)):
                # Encoded image file: decode it and force three channels so
                # grayscale / RGBA uploads still stack into a [N, 3, H, W] batch.
                image = Image.open(io.BytesIO(image)).convert("RGB")
                image = to_tensor(image)
            else:
                # The image was sent as a nested list. ToTensor() rejects
                # tensors, so it must not be applied on this path.
                # NOTE(review): presumably the list is already channel-first
                # (3, H, W) -- confirm against the clients.
                image = torch.FloatTensor(image)

            # resize to [img_size, img_size]
            images.append(resize(image))

        # Stack the equal-size tensors into one batch tensor of shape
        # BATCH_SIZE x 3 x IMG_SIZE x IMG_SIZE.
        return torch.stack(images).to(self.device)

    def postprocess(self, inference_output):
        """Turns raw model output into per-image lists of detections.

        Args:
            inference_output: Model output; its first element is passed to
                ``non_max_suppression``.

        Returns:
            List[List[dict]]: For each image, a list of dicts with the keys
            "x1", "y1", "x2", "y2" (box corners divided by ``img_size``),
            "confidence" and "class" (the mapped label, or the class index
            when no label is available).
        """
        # perform NMS (non-maximum suppression) on model outputs
        pred = non_max_suppression(inference_output[0])

        # The label mapping may be None (no index_to_name.json was packaged
        # with the model); fall back to returning the bare class index.
        mapping = getattr(self, "mapping", None) or {}

        detections = []
        for image_detections in pred:  # axis 0: for each image
            formatted = []
            for det in image_detections:  # axis 1: for each detection
                # x1,y1,x2,y2 in normalized image coordinates (i.e. 0.0-1.0)
                x1, y1, x2, y2 = (det[:4] / self.img_size).tolist()
                # index of predicted class
                class_idx = int(det[5].item())
                formatted.append({
                    "x1": x1,
                    "y1": y1,
                    "x2": x2,
                    "y2": y2,
                    "confidence": det[4].item(),
                    # label of predicted class; class idx if missing
                    "class": mapping.get(str(class_idx), class_idx),
                })
            detections.append(formatted)
        return detections
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    Args:
        prediction: Tensor indexed as [image, box, 5 + nc], where the last
            axis holds (center x, center y, width, height, objectness conf)
            followed by one score per class.
        conf_thres: Minimum confidence, in [0, 1], for a box to be kept.
        iou_thres: IoU threshold, in [0, 1], passed to the NMS op.
        classes: Optional collection of class indices to keep.
        agnostic: If True, suppress boxes across classes instead of per class.
        multi_label: If True (and there is more than one class), a box may
            yield one detection per class scoring above ``conf_thres``.
        labels: Optional per-image a-priori labels; each entry is a tensor
            whose column 0 is the class index and columns 1-4 the box.
        max_det: Maximum number of detections kept per image.

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    max_wh = 4096  # (pixels) maximum box width and height, used as class offset
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    # One independent empty result per image (no aliasing between slots).
    output = [torch.zeros((0, 6), device=prediction.device)
              for _ in range(prediction.shape[0])]

    for xi, x in enumerate(prediction):  # image index, image inference
        # Boolean indexing copies, so the in-place ops below never touch
        # the caller's ``prediction`` tensor.
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            # sort by confidence
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Batched NMS: shifting each class by a multiple of max_wh keeps
        # boxes of different classes from overlapping, so one NMS call
        # suppresses per class.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # boxes (offset by class), scores
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = torchvision.ops.box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]

    return output
def xywh2xyxy(x):
    """Convert boxes from [x, y, w, h] to [x1, y1, x2, y2].

    (x, y) is the box center; xy1 is the top-left corner and xy2 the
    bottom-right corner.

    Args:
        x: torch.Tensor, or array-like convertible with ``np.asarray``,
            whose last axis holds (x, y, w, h). Any number of leading
            dimensions is accepted (a single box, n x 4, batch x n x 4, ...).

    Returns:
        A new tensor/array with the same shape and dtype as the input,
        holding (x1, y1, x2, y2) along the last axis. The input is not
        modified.
    """
    if isinstance(x, torch.Tensor):
        y = x.clone()
    else:
        x = np.asarray(x)
        y = np.copy(x)

    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment