Skip to content

Instantly share code, notes, and snippets.

@joek13
Created July 28, 2021 14:16
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joek13/b895db0cd50a7c71a123611885057c69 to your computer and use it in GitHub Desktop.
Save joek13/b895db0cd50a7c71a123611885057c69 to your computer and use it in GitHub Desktop.
"""Custom TorchServe model handler for YOLOv5 models.
"""
from ts.torch_handler.base_handler import BaseHandler
import numpy as np
import base64
import torch
import torchvision.transforms as tf
import torchvision
import io
from PIL import Image
class ModelHandler(BaseHandler):
    """Custom TorchServe handler for YOLOv5 object-detection models.

    ``preprocess`` decodes each request into a square float tensor,
    ``postprocess`` runs non-max suppression on the raw model output and
    formats the surviving boxes as JSON-serialisable dictionaries.
    """

    # Image size (px). Images will be resized to this resolution before inference.
    img_size = 640

    def __init__(self):
        # call superclass initializer
        super().__init__()

    def preprocess(self, data):
        """Converts input images to float tensors.

        Args:
            data (List): Input data from the request. Each row is a mapping
                whose ``"data"`` or ``"body"`` entry holds either a base64
                string, raw image bytes, or a nested list of numbers.

        Returns:
            Tensor: single Tensor of shape [BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE]
            for encoded images, moved to ``self.device``.
        """
        # ToTensor only accepts PIL images / ndarrays, so it is applied to
        # decoded images only; Resize works on tensors and is applied to all.
        to_tensor = tf.ToTensor()
        resize = tf.Resize((self.img_size, self.img_size))

        images = []
        # load images
        # taken from https://github.com/pytorch/serve/blob/master/ts/torch_handler/vision_handler.py
        # handle if images are given in base64, etc.
        for row in data:
            # Compat layer: normally the envelope should just return the data
            # directly, but older versions of Torchserve didn't have envelope.
            image = row.get("data") or row.get("body")
            if isinstance(image, str):
                # the image is a base64-encoded string of bytes
                image = base64.b64decode(image)

            if isinstance(image, (bytearray, bytes)):
                # Encoded image bytes: decode, then force 3 channels so that
                # grayscale / RGBA / palette images stack with RGB ones.
                image = Image.open(io.BytesIO(image)).convert("RGB")
                image = to_tensor(image)
            else:
                # The image is a (nested) list of numbers. No value scaling is
                # applied here.
                # NOTE(review): presumably the caller sends [3, H, W] data
                # already in the range the model expects -- confirm.
                image = torch.FloatTensor(image)

            # resize to [img_size, img_size]
            images.append(resize(image))

        # convert list of equal-size tensors to single stacked tensor
        # has shape BATCH_SIZE x 3 x IMG_SIZE x IMG_SIZE
        return torch.stack(images).to(self.device)

    def postprocess(self, inference_output):
        """Turns raw model output into per-image lists of detections.

        Args:
            inference_output: Model output; its first element is passed to
                ``non_max_suppression``.

        Returns:
            List[List[dict]]: one list per image, each entry holding the keys
            ``x1``, ``y1``, ``x2``, ``y2``, ``confidence`` and ``class``.
        """
        # perform NMS (nonmax suppression) on model outputs
        pred = non_max_suppression(inference_output[0])

        # self.mapping may be None (presumably when no label-mapping file is
        # shipped with the model -- confirm against BaseHandler); fall back to
        # an empty mapping so that the class index is returned instead.
        mapping = self.mapping or {}

        detections = []
        for image_detections in pred:  # axis 0: for each image
            image_results = []
            for det in image_detections:  # axis 1: for each detection
                # x1,y1,x2,y2 divided by the inference resolution; assumes the
                # boxes are in pixels of the resized image, giving 0.0-1.0.
                x1, y1, x2, y2 = (det[:4] / self.img_size).tolist()
                # index of predicted class
                class_idx = int(det[5].item())
                image_results.append({
                    "x1": x1,
                    "y1": y1,
                    "x2": x2,
                    "y2": y2,
                    # confidence value
                    "confidence": det[4].item(),
                    # label of predicted class; if missing, the class idx
                    "class": mapping.get(str(class_idx), class_idx),
                })
            detections.append(image_results)
        return detections
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    Args:
        prediction: Tensor indexed as [image, box, 5 + num_classes]; columns
            0-3 are the box as (center x, center y, width, height), column 4
            is the objectness score and the remaining columns are per-class
            scores.
        conf_thres: Minimum objectness and minimum final confidence, in [0, 1].
        iou_thres: IoU threshold handed to ``torchvision.ops.nms``, in [0, 1].
        classes: Optional collection of class indices to keep.
        agnostic: If True, boxes of different classes suppress each other.
        multi_label: If True (and there is more than one class), a box may
            yield one detection per class above ``conf_thres``.
        labels: Optional per-image apriori labels with rows
            [cls, x, y, w, h] that are appended as confidence-1.0 candidates.
        max_det: Maximum number of detections kept per image.

    Returns:
        list of detections, one (n,6) tensor per image [xyxy, conf, cls]

    Raises:
        AssertionError: If ``conf_thres`` or ``iou_thres`` is outside [0, 1].
    """
    # Checks (kept as asserts so callers see the same exception type)
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    max_wh = 4096  # (pixels) per-class box offset used for batched NMS
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    # One independent empty tensor per image; list multiplication would make
    # every "no detections" slot alias the same tensor object.
    output = [torch.zeros((0, 6), device=prediction.device)
              for _ in range(prediction.shape[0])]

    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # confidence (boolean indexing returns a copy)

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            # sort by confidence
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]

        # Batched NMS: shifting each class by a distinct multiple of max_wh
        # keeps boxes of different classes from overlapping.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # boxes (offset by class), scores
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = torchvision.ops.box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]

    return output
def xywh2xyxy(x):
    """Convert boxes from [x, y, w, h] to [x1, y1, x2, y2].

    (x, y) is the box center; (x1, y1) is the top-left corner and (x2, y2)
    the bottom-right corner.

    Args:
        x: ``torch.Tensor`` or NumPy array whose last axis holds
            [x, y, w, h]. Any number of leading axes is accepted, so a single
            box of shape (4,) works as well as an (n, 4) batch.

    Returns:
        A new tensor/array of the same type, shape and dtype as ``x``; the
        input is not modified.
    """
    # Work on a copy so the caller's data stays untouched.
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    half_w = x[..., 2] / 2
    half_h = x[..., 3] / 2
    y[..., 0] = x[..., 0] - half_w  # top left x
    y[..., 1] = x[..., 1] - half_h  # top left y
    y[..., 2] = x[..., 0] + half_w  # bottom right x
    y[..., 3] = x[..., 1] + half_h  # bottom right y
    return y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment