Last active
February 20, 2021 20:08
-
-
Save smellslikeml/635da79198e230621f29f9ac36a17e53 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import numpy as np | |
from collections import deque | |
# https://github.com/elliottzheng/face-detection | |
from face_detection import RetinaFace | |
# https://pyscenedetect.readthedocs.io/projects/Manual/en/latest/api/scene_manager.html#scenemanager-example | |
from scenedetect import VideoManager | |
from scenedetect import SceneManager | |
from scenedetect.detectors import ContentDetector | |
class SceneTracker(object):
    """Track detected objects (faces) across frames within a scene.

    Maintains one OpenCV CSRT tracker per verified detection, keyed by an
    id derived from the tracker object's repr, plus a bounded history of
    Haar-cascade verification verdicts per tracker.
    """

    def __init__(self, min_size, expiration):
        # Haar-cascade detection parameters used by obj_verify().
        self.minNeighbors = 1
        self.scaleFactor = 1.05
        # Max allowed deviation of the box aspect ratio from square (1:1).
        self.deviationFactor = 0.4
        # Minimum box side (pixels) for a detection to be kept.
        self.min_size = int(min_size)
        # Capacity of each tracker's verification-history deque.
        self.expiration = expiration

    def box_verify(self, box):
        """Clamp *box* (x1, y1, x2, y2) to non-negative ints and keep it
        only if it is large enough and roughly square.

        Returns the clamped [x1, y1, x2, y2], or [] when rejected.
        """
        box = [int(max(0, v)) for v in box]
        if not np.sum(box):
            # Degenerate all-zero box.  (Original fell through and
            # returned None here; [] keeps the falsy contract uniform.)
            return []
        x1, y1, x2, y2 = box
        # NOTE(review): 'h' is the horizontal extent and 'w' the vertical
        # one — the names look swapped, but the original orientation of
        # the h/w ratio test is preserved exactly.
        h = x2 - x1
        w = y2 - y1
        # Short-circuit keeps h/w safe: w == 0 never reaches the division.
        if w > self.min_size and np.abs(h / w - 1) < self.deviationFactor:
            return box
        return []

    def obj_verify(self, patch, obj_classifier=None):
        """Run a Haar-cascade face check on *patch* (BGR image).

        Returns True when detectMultiScale produced hits (it returns an
        ndarray on success and an empty tuple otherwise).  The default
        classifier is built lazily here — the original constructed it in
        the default argument, loading the cascade XML at import time.
        """
        if obj_classifier is None:
            obj_classifier = cv2.CascadeClassifier("haarcascade_frontalface_default.xml")
        g = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
        det = obj_classifier.detectMultiScale(
            g,
            scaleFactor=self.scaleFactor,
            minNeighbors=self.minNeighbors,
            minSize=(self.min_size, self.min_size),
        )
        return isinstance(det, np.ndarray)

    def init(self, idx, frame, boxes, tracker=None):
        """Reset the tracker pool and start one tracker per verified box.

        Populates self.tracker_dict: tracker_id -> (tracker,
        [(frame_idx, box), ...], deque of verification booleans).
        *tracker* defaults to cv2.TrackerCSRT_create, resolved lazily so
        the class body does not touch cv2 at definition time.
        """
        if tracker is None:
            tracker = cv2.TrackerCSRT_create
        self.tracker_dict = {}
        for box in boxes:
            b = self.box_verify(box)
            if not b:
                continue
            x1, y1, x2, y2 = b
            t = tracker()
            # Unique id from the tracker's repr, e.g. '<... 0x7f...>' -> '0x7f...'.
            tracker_id = str(t).split()[1][:-1]
            try:
                # NOTE(review): OpenCV tracker init expects an (x, y, w, h)
                # rect; the raw (x1, y1, x2, y2) box is passed here as in
                # the original — confirm against the detector's box format.
                t.init(frame, box)
                q = deque(maxlen=self.expiration)
                q.append(self.obj_verify(frame[y1:y2, x1:x2, :]))
                if len(q) == self.expiration:
                    # History already full (expiration == 1): discard.
                    # BUGFIX: original did `tracker_dict.pop(tracker_id)`
                    # — a NameError (missing self.), and the key was never
                    # inserted anyway; skipping is the intended effect.
                    continue
                self.tracker_dict[tracker_id] = (t, [(idx, box)], q)
            except cv2.error:
                # Tracker init can fail on degenerate patches; drop them.
                pass
        return

    def update(self, idx, frame):
        """Advance every live tracker to *frame*; record the new box and
        its cascade verification verdict when the box passes box_verify."""
        for tracker_id in list(self.tracker_dict.keys()):
            success, box = self.tracker_dict[tracker_id][0].update(frame)
            b = self.box_verify(box)
            if b:
                x1, y1, x2, y2 = b
                self.tracker_dict[tracker_id][1].append((idx, box))
                self.tracker_dict[tracker_id][2].append(
                    self.obj_verify(frame[y1:y2, x1:x2, :]))
        return
class Video(object):
    """Scene-aware face detection and tracking over a single video file."""

    def __init__(self, source, model=None, expiration=5, down_scale=100, skip_rate=100):
        # BUGFIX: the original built RetinaFace() in the default argument,
        # forcing the detector to load at import time; build it lazily.
        self.show = True
        self.source = source
        self.scene_index = {}
        self.skip_rate = skip_rate
        self.down_scale = down_scale
        self.expiration = expiration
        self.model = model if model is not None else RetinaFace()

    def scene_idx(self, threshold=30.0):
        """Run PySceneDetect content detection over the source.

        Caches and returns {scene_no: (start_frame, end_frame)}.
        """
        video_manager = VideoManager([self.source])
        scene_manager = SceneManager()
        scene_manager.add_detector(ContentDetector(threshold=threshold))
        video_manager.set_downscale_factor(self.down_scale)
        video_manager.start()
        scene_manager.detect_scenes(frame_source=video_manager)
        self.scene_index = {
            n: (start.get_frames(), end.get_frames())
            for n, (start, end) in enumerate(scene_manager.get_scene_list())
        }
        return self.scene_index

    def load_video(self, skip_rate):
        """Decode every skip_rate-th frame plus each scene's first frame.

        Records the 1-based indices of retained frames in self.frame_idx
        and returns the frames stacked in a numpy array.
        """
        res = []
        cap = cv2.VideoCapture(self.source)
        idx = 0
        # BUGFIX: the original used a generator expression here, so each
        # `idx in scene_start` membership test consumed it and subsequent
        # checks scanned a shrinking remainder; a set is correct and O(1).
        scene_start = {scene[0] for scene in self.scene_index.values()}
        self.frame_idx = []
        while True:
            idx += 1
            if not cap.grab():
                # End of stream — nothing left to decode.
                break
            # BUGFIX: honour the skip_rate argument (the original read
            # self.skip_rate, silently ignoring the parameter, which made
            # the denser sampling in index_objs a no-op).
            if not idx % skip_rate or idx in scene_start:
                # BUGFIX: do not pass grab()'s bool to retrieve().
                ret, frame = cap.retrieve()
                if ret:
                    res.append(frame)
                    self.frame_idx.append(idx)
        cap.release()
        return np.array(res)

    def find_objs(self, scene_aware=True):
        """Run the face detector on the coarsely sampled frames.

        Returns {frame_idx: [box, ...]} with each box clamped to
        non-negative ints.  When scene_aware, scene boundaries are
        computed first so their start frames are sampled too.
        """
        self.obj_detections = {}
        if scene_aware:
            self.scene_idx()
        frames = self.load_video(self.skip_rate)
        objs = self.model(frames)
        for ii, idx in enumerate(self.frame_idx):
            try:
                # objs[ii] is the per-frame detection list; element [0] of
                # each entry is the box (landmarks/score are discarded).
                self.obj_detections[idx] = [
                    [int(max(x, 0)) for x in det[0]] for det in objs[ii]
                ]
            except IndexError:
                # Detector returned fewer results than frames; skip.
                pass
        return self.obj_detections

    def index_objs(self, skip_rate=10):
        """Detect faces per scene, track them between detection frames,
        and print one line per verified track: source, scene number,
        tracker id, and the tracked (frame_idx, box) path.
        """
        scene_no = 0
        trackers = None
        objs = self.find_objs()
        frames = self.load_video(skip_rate)
        num_frames, height, width, channels = frames.shape
        frame_idx = self.frame_idx
        scenes = list(self.scene_index.values())
        if not scenes:
            # BUGFIX: original popped from an empty list (IndexError)
            # when scene detection found nothing.
            return
        scene = scenes.pop(0)
        for ii, idx in enumerate(frame_idx):
            frame = frames[ii]
            if idx in objs:
                # Fresh detections replace the tracker pool.
                bboxes = objs[idx]
                trackers = SceneTracker(min(height, width) / 100, self.expiration)
                trackers.init(idx, frame, bboxes)
            elif trackers is not None:
                trackers.update(idx, frame)
            # Flush every scene whose end we have passed.
            while idx >= scene[1] and idx <= frame_idx[-1]:
                # BUGFIX: guard against no detections having occurred yet
                # (original dereferenced {}.tracker_dict -> AttributeError).
                if trackers is not None:
                    for tracker_id in list(trackers.tracker_dict.keys()):
                        # BUGFIX: any() is safe on an empty history, where
                        # np.max raised on a zero-length sequence.
                        if any(trackers.tracker_dict[tracker_id][2]):
                            print(self.source, scene_no, tracker_id,
                                  trackers.tracker_dict[tracker_id][1])
                            if self.show:
                                cv2.imshow('face', frame)
                                if cv2.waitKey(1) & 0xFF == ord('q'):
                                    break
                if not scenes:
                    # BUGFIX: last scene flushed — original popped from an
                    # empty list here.
                    return
                scene = scenes.pop(0)
                scene_no += 1
if __name__ == '__main__':
    import sys
    from time import time

    # BUGFIX: dropped the unused `import os`; added an argv guard so a
    # missing argument prints usage instead of raising IndexError.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <video_path>")

    video_path = sys.argv[1]
    vid = Video(video_path, skip_rate=100)
    start_time = time()
    vid.index_objs(skip_rate=30)
    # Report wall-clock time taken to index the whole video.
    print(time() - start_time)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment