Skip to content

Instantly share code, notes, and snippets.

@masouduut94
Last active May 25, 2020 20:17
Show Gist options
  • Save masouduut94/74cd1fbef4b96dddd0aba99aea7cf4ab to your computer and use it in GitHub Desktop.
Save masouduut94/74cd1fbef4b96dddd0aba99aea7cf4ab to your computer and use it in GitHub Desktop.
Pedestrian detection service for a video stream provided via an RTSP link
import warnings
from os import getenv
import sys
from os.path import dirname, abspath
# Make the parent of this file's directory importable. This runs before the
# project imports below (deep_sort, detector, utils), which presumably live
# there -- TODO confirm against the repository layout.
sys.path.append(dirname(dirname(abspath(__file__))))
import torch
from deep_sort import build_tracker
from detector import build_detector
import cv2
from utils.draw import compute_color_for_labels
from concurrent.futures import ThreadPoolExecutor
from redis import Redis
# Module-level Redis client for the server at 127.0.0.1.
# RealTimeTracking.run stores each JPEG-encoded frame on it under the key 'frame'.
redis_cache = Redis('127.0.0.1')
class RealTimeTracking(object):
    """
    Read frames from a video source on a background worker and run
    pedestrian detection and tracking on the most recent frame.

    A single worker thread keeps overwriting ``self.frame`` with the latest
    frame read from ``cv2.VideoCapture(args.input)``, so ``run`` always works
    on the newest frame rather than a backlog (the stated aim is to
    compensate for network packet loss on an rtsp link). Each processed
    frame is JPEG-encoded and stored in Redis under the key ``'frame'``;
    per the original notes it is then served to clients through flask.

    Args:
        cfg: deepsort dict and yolo-model cfg from server_cfg file
        args: parse_args inputs; ``args.use_cuda`` and ``args.input`` are
            read here.
    """

    # Pause used while waiting for the capture to open or for a first frame,
    # so the waiting loops do not spin at full CPU.
    _IDLE_SLEEP_SECONDS = 0.01

    def __init__(self, cfg, args):
        self.cfg = cfg
        self.args = args

        # CUDA is used only when requested AND actually available.
        use_cuda = self.args.use_cuda and torch.cuda.is_available()
        if not use_cuda:
            warnings.warn(UserWarning("Running in cpu mode!"))

        self.detector = build_detector(cfg, use_cuda=use_cuda)
        self.deepsort = build_tracker(cfg, use_cuda=use_cuda)
        self.class_names = self.detector.class_names

        # Create a VideoCapture object for the configured input.
        self.vdo = cv2.VideoCapture(self.args.input)
        # Result of the most recent read(); both stay None until the worker
        # has completed its first read.
        self.status, self.frame = None, None
        self.im_width = int(self.vdo.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.im_height = int(self.vdo.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.output_frame = None

        # One worker dedicated to keeping self.frame fresh.
        self.thread = ThreadPoolExecutor(max_workers=1)
        self.thread.submit(self.update)

    def update(self):
        """
        Repeatedly read frames from the capture and publish the latest one.

        Runs on the worker thread started in ``__init__`` and never returns.
        Each successful ``isOpened()`` check is followed by a ``read()`` whose
        result overwrites ``self.status`` and ``self.frame``.
        """
        from time import sleep  # local import keeps the block self-contained

        while True:
            if self.vdo.isOpened():
                (self.status, self.frame) = self.vdo.read()
            else:
                # Nothing to read: wait briefly instead of busy-spinning.
                sleep(self._IDLE_SLEEP_SECONDS)

    def run(self):
        """
        Process frames until the ``in_progress`` environment variable is
        ``'off'``.

        For every iteration the latest frame is copied, annotated by
        ``detection``, JPEG-encoded and stored in Redis under ``'frame'``.
        Iterations where no frame is available are skipped.
        """
        from time import sleep  # local import keeps the block self-contained

        print('streaming started ...')
        while getenv('in_progress') != 'off':
            # Snapshot the attribute once: the worker thread may replace it
            # (possibly with None) between a check and the copy.
            latest = self.frame
            if latest is None:
                # No frame yet, or the last read() returned none. Checked
                # explicitly so that real AttributeErrors raised further
                # down are no longer silently swallowed.
                sleep(self._IDLE_SLEEP_SECONDS)
                continue

            frame = latest.copy()
            self.detection(frame=frame)
            frame_to_bytes = cv2.imencode('.jpg', frame)[1].tobytes()
            redis_cache.set('frame', frame_to_bytes)
        print('streaming stopped ...')

    def detection(self, frame):
        """
        Detect and track pedestrians, drawing their boxes onto ``frame``.

        Args:
            frame: image to analyse; it is converted from BGR to RGB for the
                detector and tracker, and annotated in place when tracks
                are returned.
        """
        im = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # do detection
        bbox_xywh, cls_conf, cls_ids = self.detector(im)
        if bbox_xywh is None:
            return

        # select person class (class id 0)
        mask = cls_ids == 0
        bbox_xywh = bbox_xywh[mask]
        # bbox dilation just in case bbox too small: scales column 3 onward
        # (presumably the height in an x, y, w, h layout -- TODO confirm).
        bbox_xywh[:, 3:] *= 1.2
        cls_conf = cls_conf[mask]

        # do tracking
        outputs = self.deepsort.update(bbox_xywh, cls_conf, im)

        # draw boxes for visualization
        if len(outputs) > 0:
            self.draw_boxes(img=frame, output=outputs)

    @staticmethod
    def draw_boxes(img, output, offset=(0, 0)):
        """
        Draw a labelled bounding box on ``img`` for every tracked pedestrian.

        Args:
            img: image drawn on in place.
            output: iterable of 5-element rows ``(x1, y1, x2, y2, identity)``;
                every element is converted with ``int``.
            offset: ``(dx, dy)`` added to the box corners before drawing.

        Returns:
            The same ``img`` object that was passed in.
        """
        for box in output:
            x1, y1, x2, y2, identity = [int(value) for value in box]
            x1 += offset[0]
            x2 += offset[0]
            y1 += offset[1]
            y2 += offset[1]

            # box text and bar
            color = compute_color_for_labels(identity)
            label = str(identity)
            t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
            # Filled bar behind the label, sized to the rendered text.
            cv2.rectangle(img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
            cv2.putText(img, label, (x1, y1 + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
        return img
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment