Skip to content

Instantly share code, notes, and snippets.

@mikaelhg
Created August 21, 2022 19:37
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mikaelhg/7e0f991bd685a289ae4ef30a5e084e3d to your computer and use it in GitHub Desktop.
Save mikaelhg/7e0f991bd685a289ae4ef30a5e084e3d to your computer and use it in GitHub Desktop.
Five-minute pedestrian tracking with Norfair and YOLOv5 from Torch Hub
#!/bin/env python
import argparse
from typing import List
import numpy as np
import torch
import norfair
from norfair import Detection, Tracker, Video
# Largest distance (as returned by euclidean_distance, presumably in pixels)
# at which the norfair Tracker still matches a detection to an existing
# track; main() passes it as Tracker(distance_threshold=...).
max_distance_between_points: int = 30

# yolov5 class indices that yolo_to_norfair_detections keeps. Index 0 is
# presumably "person" in the COCO label set of the pretrained yolov5
# weights — confirm if different weights are used.
_INTERESTING_CLASSES: List[int] = [0]
def yolo_to_norfair_detections(dets, classes=None) -> List[Detection]:
    """Convert single-image yolov5 results into norfair point detections.

    Every kept row becomes one norfair ``Detection``. Its single point is the
    box centre ``(x, y)`` and its score is the row's confidence value.

    Args:
        dets: Results object returned by calling a yolov5 Torch Hub model on
            one image. Only its ``xywh`` attribute is read. That attribute
            holds one entry per image, and each entry is a 2-D tensor whose
            rows are read as ``[x, y, w, h, confidence, class, ...]``.
            (Presumably yolov5's ``Detections`` class — confirm against the
            yolov5 version in use.)
        classes: Class indices to keep. ``None`` (the default) uses the
            module-level ``_INTERESTING_CLASSES``.

    Returns:
        One ``Detection`` per row whose class index is in ``classes``, in
        row order. The list is empty when nothing matches.

    Raises:
        ValueError: if ``dets`` does not hold results for exactly one image.
    """
    wanted = _INTERESTING_CLASSES if classes is None else classes

    per_image = dets.xywh
    if len(per_image) != 1:
        # Only the first image would be converted, so the others would be
        # silently dropped. An explicit raise, unlike `assert`, also
        # survives `python -O`.
        raise ValueError(
            f"expected detections for exactly one image, got {len(per_image)}"
        )

    results: List[Detection] = []
    # One host conversion for the whole tensor instead of a separate
    # `.item()` call for every scalar that is read.
    for row in per_image[0].tolist():
        if int(row[5]) not in wanted:
            continue
        centroid = np.array([row[0], row[1]])
        scores = np.array([row[4]])
        results.append(Detection(points=centroid, scores=scores))
    return results
def euclidean_distance(detection, tracked_object):
    """Distance between a detection and a tracked object's current estimate.

    This is the ``distance_function`` given to the norfair Tracker. It returns
    the norm of the element-wise difference between ``detection.points`` and
    ``tracked_object.estimate``. For single-point detections, that is the
    straight-line distance between the two centroids.
    """
    offset = detection.points - tracked_object.estimate
    return np.linalg.norm(offset)
def main(argv=None):
    """Run pedestrian tracking over every video file named on the command line.

    The yolov5 model is loaded from Torch Hub once. Each input file then gets
    a fresh norfair ``Tracker``, so tracks never carry over between videos.
    Every frame is run through the model, converted with
    ``yolo_to_norfair_detections``, and fed to the tracker. The raw
    detections and the tracked objects are drawn onto the frame, which is
    then written out through norfair's ``Video``.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour.
    """
    parser = argparse.ArgumentParser(
        description="Track pedestrians in videos with yolov5 and norfair."
    )
    parser.add_argument("files", type=str, nargs="+", help="Video files to process")
    parser.add_argument(
        "--model",
        default="yolov5m6",
        help="yolov5 Torch Hub model name (default: %(default)s)",
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=max_distance_between_points,
        help="largest distance at which a detection can match an existing "
        "track (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    model = torch.hub.load("ultralytics/yolov5", args.model)

    for input_path in args.files:
        video = Video(input_path=input_path)
        tracker = Tracker(
            distance_function=euclidean_distance,
            distance_threshold=args.distance_threshold,
        )
        for frame in video:
            # norfair's Video yields OpenCV frames (BGR channel order), but
            # yolov5 Torch Hub models expect RGB numpy images. Flip the
            # channel axis only for inference, and keep drawing on the
            # original BGR frame that gets written out.
            results = model(frame[..., ::-1])
            detections = yolo_to_norfair_detections(results)
            tracked_objects = tracker.update(detections)
            norfair.draw_points(frame, detections)
            norfair.draw_tracked_objects(frame, tracked_objects)
            video.write(frame)
# Start tracking only when run as a script; importing this module has no
# side effects beyond its definitions.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment