Skip to content

Instantly share code, notes, and snippets.

@jarutis
Created March 21, 2023 07:43
Show Gist options
  • Save jarutis/f57a3db7b4c37b59163a2ff5d8c8d54e to your computer and use it in GitHub Desktop.
Save jarutis/f57a3db7b4c37b59163a2ff5d8c8d54e to your computer and use it in GitHub Desktop.
YOLOv8 TorchServe model handler
"""Custom TorchServe model handler for YOLOv8 models.
"""
import base64
import io
import logging

import cv2
import numpy as np
import torch
import torchvision.transforms as tf
from PIL import Image
from ts.torch_handler.base_handler import BaseHandler
class ModelHandler(BaseHandler):
    """
    TorchServe handler for a YOLOv8 bounding-box (detection) model.

    ``preprocess`` turns each request into a 3 x img_size x img_size float
    tensor and stacks them into one batch. ``postprocess`` decodes the raw
    model output into one list of detection dicts per request.
    """

    # Image size (px). Images are resized to img_size x img_size before
    # inference (a plain resize; aspect ratio is not preserved).
    img_size = 640
    # Minimum best-class score for a candidate box to be kept (also used as
    # the score threshold passed to NMS).
    conf_threshold = 0.25
    # IoU threshold used by non-maximum suppression.
    iou_threshold = 0.45

    def preprocess(self, data):
        """Convert the request batch into a single float image tensor.

        Args:
            data (List[dict]): One entry per request. The image is read from
                the ``"data"`` key, falling back to ``"body"``. It may be raw
                encoded bytes, a base64-encoded string, or a nested list of
                numbers.

        Returns:
            Tensor: shape [BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE] on ``self.device``.
        """
        to_tensor = tf.ToTensor()
        resize = tf.Resize((self.img_size, self.img_size))
        images = []
        for row in data:
            # Compat layer: normally the envelope should just return the data
            # directly, but older versions of TorchServe didn't have envelope.
            image = row.get("data") or row.get("body")
            if isinstance(image, str):
                # Image sent as a base64-encoded string.
                image = base64.b64decode(image)
            if isinstance(image, (bytearray, bytes)):
                # Encoded image file. Convert to RGB so grayscale, palette or
                # RGBA inputs also yield exactly 3 channels; otherwise
                # torch.stack / the 3-channel model would fail on them.
                image = to_tensor(Image.open(io.BytesIO(image)).convert("RGB"))
            else:
                # NOTE(review): assumes the list is an already-decoded CHW float
                # image (TorchServe VisionHandler convention) - confirm with
                # callers. ToTensor rejects tensors, so this path only resizes.
                image = torch.FloatTensor(image)
            images.append(resize(image))
        # Every image now has the same shape, so they stack into one batch.
        return torch.stack(images).to(self.device)

    def postprocess(self, inference_output):
        """Decode raw model output into detections, one list per request.

        TorchServe expects postprocess to return exactly one entry per request
        in the batch. Each entry here is the (possibly empty) list of
        detections for that image.

        Args:
            inference_output: Model output. NOTE(review): assumed to be a
                tensor of shape [BATCH_SIZE, 4 + NUM_CLASSES, NUM_CANDIDATES];
                confirm for your exported model. If the model returns a
                tuple/list, its first element is used.

        Returns:
            List[List[dict]]: For each image, dicts with keys ``class_id``,
            ``class_name``, ``confidence``, ``box`` ([left, top, width,
            height] in the resized IMG_SIZE x IMG_SIZE input frame) and
            ``scale``.
        """
        if isinstance(inference_output, (list, tuple)):
            inference_output = inference_output[0]
        # .numpy() raises on CUDA tensors, so copy to host memory first.
        predictions = inference_output.detach().cpu().numpy()
        return [self._decode_image(prediction) for prediction in predictions]

    def _decode_image(self, prediction):
        """Turn one image's raw prediction array into a list of detection dicts."""
        # One row per candidate: [cx, cy, w, h, score_class_0, score_class_1, ...].
        candidates = prediction.T
        class_scores = candidates[:, 4:]
        # argmax/max pick the first best class, the same result cv2.minMaxLoc gave.
        class_ids = class_scores.argmax(axis=1)
        confidences = class_scores.max(axis=1)
        keep = confidences >= self.conf_threshold
        if not keep.any():
            return []
        candidates = candidates[keep]
        class_ids = class_ids[keep]
        # (cx, cy, w, h) -> (left, top, w, h), the box layout NMSBoxes expects.
        # .tolist() yields plain Python floats, which are JSON-serializable.
        boxes = np.stack(
            (
                candidates[:, 0] - 0.5 * candidates[:, 2],
                candidates[:, 1] - 0.5 * candidates[:, 3],
                candidates[:, 2],
                candidates[:, 3],
            ),
            axis=1,
        ).tolist()
        scores = confidences[keep].tolist()
        # Class-agnostic NMS: overlapping boxes of different classes also
        # suppress each other. eta is left at OpenCV's default (1.0, i.e. no
        # adaptive IoU threshold).
        kept = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold)
        # Older OpenCV returns an Nx1 array, newer a flat array (or an empty
        # tuple); flatten so each index is a plain int.
        kept_indices = np.asarray(kept, dtype=np.int64).reshape(-1).tolist()
        logger = logging.getLogger(__name__)
        detections = []
        for index in kept_indices:
            class_id = int(class_ids[index])
            detection = {
                'class_id': class_id,
                'class_name': self._class_name(class_id),
                'confidence': scores[index],
                'box': boxes[index],
                # Kept for response compatibility; 1.0 at the default img_size.
                'scale': self.img_size / 640}
            logger.debug("detection: %s", detection)
            detections.append(detection)
        return detections

    def _class_name(self, class_id):
        """Return the label for class_id, or the id as a string if unknown."""
        # NOTE(review): self.mapping is presumably set by BaseHandler from the
        # archive's index_to_name.json and may be None when no such file exists.
        mapping = getattr(self, "mapping", None)
        if not mapping:
            return str(class_id)
        return mapping.get(str(class_id), str(class_id))
@jarutis
Copy link
Author

jarutis commented Mar 21, 2023

@jarutis
Copy link
Author

jarutis commented Mar 21, 2023

Looking for a handler for segmentation in exchange :)

@belapyc
Copy link

belapyc commented Feb 4, 2024

Hi, thank you for providing this! I am trying to use it with YOLOv8 and TorchServe, and for some reason I get the following error: "number of inputs mismatched". I've tried changing the handler script, but it seems it's not doing anything.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment