@Erol444
Created June 29, 2021 10:10
DepthAI Mask/No mask YoloV3 AI training example
#!/usr/bin/env python3
"""
Tiny-yolo-v3 device side decoding demo
YOLO v3 Tiny is a real-time object detection model implemented with Keras* from
this repository <https://github.com/david8862/keras-YOLOv3-model-set> and converted
to TensorFlow* framework. This model was pretrained on COCO* dataset with 80 classes.
"""
from pathlib import Path
import sys
import cv2
import depthai as dai
import numpy as np
import time
# Get argument first
nnPath = str((Path(__file__).parent / Path('models/tiny-yolo-v3_openvino_2021.2_6shave.blob')).resolve().absolute())
if len(sys.argv) > 1:
    nnPath = sys.argv[1]
if not Path(nnPath).exists():
    raise FileNotFoundError(f'Required file/s not found, please run "{sys.executable} install_requirements.py"')
# Label texts for the custom 2-class mask model
labelMap = ["mask", "no mask"]
syncNN = True
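# syncNN=True displays the NN passthrough frames (the exact frames inference
# ran on), so the drawn boxes stay aligned with their source frame.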
# Create pipeline
pipeline = dai.Pipeline()
pipeline.setOpenVINOVersion(dai.OpenVINO.Version.VERSION_2021_3)
# Define sources and outputs
camRgb = pipeline.createColorCamera()
detectionNetwork = pipeline.createYoloDetectionNetwork()
xoutRgb = pipeline.createXLinkOut()
nnOut = pipeline.createXLinkOut()
xoutRgb.setStreamName("rgb")
nnOut.setStreamName("nn")
# Properties
camRgb.setPreviewSize(416, 416)
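# 416x416 matches the input resolution the tiny YOLO v3 blob was compiled for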
camRgb.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
camRgb.setInterleaved(False)
camRgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.BGR)
camRgb.setFps(40)
# Network specific settings
detectionNetwork.setConfidenceThreshold(0.5)
detectionNetwork.setNumClasses(2)
detectionNetwork.setCoordinateSize(4)
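# Anchor (w, h) pairs are given in pixels at the 416x416 input size; the masks
# select which anchors each output scale uses ("side26" = the 26x26 grid for
# smaller objects, "side13" = the 13x13 grid for larger ones). These are the
# standard tiny YOLO v3 anchors.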
detectionNetwork.setAnchors([10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319])
detectionNetwork.setAnchorMasks({"side26": [1, 2, 3], "side13": [3, 4, 5]})
detectionNetwork.setIouThreshold(0.5)
detectionNetwork.setBlobPath(nnPath)
detectionNetwork.setNumInferenceThreads(2)
detectionNetwork.input.setBlocking(False)
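# With a non-blocking NN input, frames may be dropped when inference falls
# behind instead of back-pressuring the camera.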
# Linking
camRgb.preview.link(detectionNetwork.input)
if syncNN:
    detectionNetwork.passthrough.link(xoutRgb.input)
else:
    camRgb.preview.link(xoutRgb.input)
detectionNetwork.out.link(nnOut.input)
# Connect to device and start pipeline
with dai.Device(pipeline) as device:
    # Output queues will be used to get the rgb frames and nn data from the outputs defined above
    qRgb = device.getOutputQueue(name="rgb", maxSize=4, blocking=False)
    qDet = device.getOutputQueue(name="nn", maxSize=4, blocking=False)
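    # maxSize=4 with blocking=False keeps only the newest messages, dropping
    # older ones, so a slow host does not stall the device pipeline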
    frame = None
    detections = []
    startTime = time.monotonic()
    counter = 0
    color2 = (255, 255, 255)

    # nn data, being the bounding box locations, are in <0..1> range - they need to be normalized with frame width/height
    def frameNorm(frame, bbox):
        normVals = np.full(len(bbox), frame.shape[0])
        normVals[::2] = frame.shape[1]
        return (np.clip(np.array(bbox), 0, 1) * normVals).astype(int)
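    # e.g. frameNorm(frame, (0.25, 0.25, 0.75, 0.75)) on a 416x416 frame
    # returns [104, 104, 312, 312]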
    def displayFrame(name, frame):
        color = (255, 0, 0)
        for detection in detections:
            bbox = frameNorm(frame, (detection.xmin, detection.ymin, detection.xmax, detection.ymax))
            cv2.putText(frame, labelMap[detection.label], (bbox[0] + 10, bbox[1] + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, f"{int(detection.confidence * 100)}%", (bbox[0] + 10, bbox[1] + 40), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
        # Show the frame
        cv2.imshow(name, frame)
    while True:
        if syncNN:
            inRgb = qRgb.get()
            inDet = qDet.get()
        else:
            inRgb = qRgb.tryGet()
            inDet = qDet.tryGet()

        if inRgb is not None:
            frame = inRgb.getCvFrame()
            cv2.putText(frame, "NN fps: {:.2f}".format(counter / (time.monotonic() - startTime)),
                        (2, frame.shape[0] - 4), cv2.FONT_HERSHEY_TRIPLEX, 0.4, color2)

        if inDet is not None:
            detections = inDet.detections
            counter += 1

        if frame is not None:
            displayFrame("rgb", frame)

        if cv2.waitKey(1) == ord('q'):
            break
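To try the demo, pass the path to your compiled blob as the first argument (the
script filename main.py below is an assumption; the gist does not name the file):

    python3 main.py path/to/your_mask_model.blob

Without an argument it falls back to models/tiny-yolo-v3_openvino_2021.2_6shave.blob
next to the script. Press q in the preview window to quit.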