Gist Erol444/98b3fb8418f441eabc59c41da3803e84 — save to your computer or open it in GitHub Desktop.
DepthAI mask / no-mask YoloV3 AI training example.
Note: this file may contain bidirectional Unicode text that can be interpreted or compiled
differently than it appears below; to review it, open the file in an editor that reveals
hidden Unicode characters.
#!/usr/bin/env python3
"""
Tiny-yolo-v3 device side decoding demo.

The YOLO v3 Tiny architecture is a real-time object detection model implemented
with Keras* from <https://github.com/david8862/keras-YOLOv3-model-set> and
converted to the TensorFlow* framework.

NOTE(review): this header originally described the stock model pretrained on the
80-class COCO* dataset, but the pipeline below is configured for a 2-class
mask / "no mask" detector (see `labelMap` and `setNumClasses(2)`).
"""
from pathlib import Path
import sys
import cv2
import depthai as dai
import numpy as np
import time
# Resolve the neural-network blob path: default to the bundled model, or take
# an explicit path as the first CLI argument.
nnPath = str((Path(__file__).parent / Path('models/tiny-yolo-v3_openvino_2021.2_6shave.blob')).resolve().absolute())
if len(sys.argv) > 1:
    nnPath = sys.argv[1]

if not Path(nnPath).exists():
    # `sys` is already imported at the top of the file; the redundant
    # `import sys` that used to sit here has been removed.
    raise FileNotFoundError(f'Required file/s not found, please run "{sys.executable} install_requirements.py"')

# Labels for the 2-class mask detector (index 0 -> "mask", index 1 -> "no mask").
labelMap = [
    "mask", "no mask"
]

# When True, display the NN passthrough frames so drawn boxes stay in sync
# with the exact frames inference ran on.
syncNN = True
# Create pipeline
pipeline = dai.Pipeline()
# NOTE(review): the default blob filename above is tagged "openvino_2021.2" while
# the pipeline is pinned to 2021.3 — confirm the blob and runtime versions match.
pipeline.setOpenVINOVersion(dai.OpenVINO.Version.VERSION_2021_3)

# Define sources and outputs
camRgb = pipeline.createColorCamera()
detectionNetwork = pipeline.createYoloDetectionNetwork()
xoutRgb = pipeline.createXLinkOut()
nnOut = pipeline.createXLinkOut()

xoutRgb.setStreamName("rgb")
nnOut.setStreamName("nn")

# Properties
camRgb.setPreviewSize(416, 416)  # preview size feeds the NN input directly
camRgb.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
camRgb.setInterleaved(False)
camRgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.BGR)
camRgb.setFps(40)

# Network specific settings
detectionNetwork.setConfidenceThreshold(0.5)
detectionNetwork.setNumClasses(2)  # mask / "no mask" — must agree with labelMap
detectionNetwork.setCoordinateSize(4)
# Anchor boxes and the per-output-scale masks used for on-device YOLO decoding;
# "side26"/"side13" presumably name the 26x26 and 13x13 output grids — confirm
# against the exported model.
detectionNetwork.setAnchors(np.array([10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319]))
detectionNetwork.setAnchorMasks({"side26": np.array([1, 2, 3]), "side13": np.array([3, 4, 5])})
detectionNetwork.setIouThreshold(0.5)
detectionNetwork.setBlobPath(nnPath)
detectionNetwork.setNumInferenceThreads(2)
detectionNetwork.input.setBlocking(False)

# Linking
camRgb.preview.link(detectionNetwork.input)
if syncNN:
    # Passthrough frames are the exact frames the NN ran on, keeping boxes in sync.
    detectionNetwork.passthrough.link(xoutRgb.input)
else:
    camRgb.preview.link(xoutRgb.input)
detectionNetwork.out.link(nnOut.input)
# Connect to device and start pipeline
with dai.Device(pipeline) as device:
    # Output queues will be used to get the rgb frames and nn data from the outputs defined above
    qRgb = device.getOutputQueue(name="rgb", maxSize=4, blocking=False)
    qDet = device.getOutputQueue(name="nn", maxSize=4, blocking=False)

    # Shared state mutated by the display loop below.
    frame = None                   # latest camera frame (None until the first one arrives)
    detections = []                # latest list of NN detections
    startTime = time.monotonic()   # reference time for the FPS estimate
    counter = 0                    # NN results received since startTime
    color2 = (255, 255, 255)       # colour of the FPS text overlay
# nn data, being the bounding box locations, are in <0..1> range - they need to
# be normalized with frame width/height
def frameNorm(frame, bbox):
    # Even indices are x coordinates (scale by width), odd indices are y
    # coordinates (scale by height).
    positions = np.arange(len(bbox))
    scale = np.where(positions % 2 == 0, frame.shape[1], frame.shape[0])
    clipped = np.clip(np.array(bbox), 0, 1)
    return (clipped * scale).astype(int)
def displayFrame(name, frame):
    """Draw the current `detections` (module-level state) onto `frame` and show it."""
    box_color = (255, 0, 0)
    for det in detections:
        # Convert the normalized box to pixel coordinates on this frame.
        x1, y1, x2, y2 = frameNorm(frame, (det.xmin, det.ymin, det.xmax, det.ymax))
        cv2.putText(frame, labelMap[det.label], (x1 + 10, y1 + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
        cv2.putText(frame, f"{int(det.confidence * 100)}%", (x1 + 10, y1 + 40), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
        cv2.rectangle(frame, (x1, y1), (x2, y2), box_color, 2)
    # Show the frame
    cv2.imshow(name, frame)
while True:
    if syncNN:
        # Blocking gets: the frame and its detections arrive as a matched pair.
        inRgb = qRgb.get()
        inDet = qDet.get()
    else:
        # Non-blocking: either may be None when nothing new has arrived yet.
        inRgb = qRgb.tryGet()
        inDet = qDet.tryGet()

    if inRgb is not None:
        frame = inRgb.getCvFrame()
        # Overlay the measured NN throughput in the bottom-left corner.
        cv2.putText(frame, "NN fps: {:.2f}".format(counter / (time.monotonic() - startTime)),
                    (2, frame.shape[0] - 4), cv2.FONT_HERSHEY_TRIPLEX, 0.4, color2)

    if inDet is not None:
        detections = inDet.detections
        counter += 1  # counts NN results, not displayed frames

    if frame is not None:
        displayFrame("rgb", frame)

    # Quit on 'q'.
    if cv2.waitKey(1) == ord('q'):
        break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.