Skip to content

Instantly share code, notes, and snippets.

@bresilla
Created March 18, 2024 15:16
Show Gist options
  • Save bresilla/8a67c6899a37ffece0986036ef4c549c to your computer and use it in GitHub Desktop.
Save bresilla/8a67c6899a37ffece0986036ef4c549c to your computer and use it in GitHub Desktop.
simple_tracker
#!/usr/bin/env python3
import cv2
import numpy as np
import rclpy
from cv_bridge import CvBridge
from detection_msgs.msg import BoundingBox, BoundingBoxes
from message_filters import ApproximateTimeSynchronizer, Subscriber
from rclpy.node import Node
from sensor_msgs.msg import Image
class ObjectTracker(Node):
    """Associate per-frame detections with CSRT trackers and republish them with IDs.

    Subscribes to ``/camera/image_raw`` (``sensor_msgs/Image``) and ``/bboxes``
    (``BoundingBoxes``), pairs the two streams with an approximate-time
    synchronizer, and publishes the tracked boxes on ``/tracked_bboxes``.
    """

    # Maximum centre-to-centre distance, in pixels, for a detection to be
    # considered the same object as an existing track (strictly less than).
    MATCH_DISTANCE_PX = 20.0

    def __init__(self):
        super().__init__('object_tracker')
        self.bridge = CvBridge()
        # Number of tracks ever created; also the source of new track IDs.
        self.object_count = 0
        self.image_sub = Subscriber(self, Image, '/camera/image_raw')
        self.bbox_sub = Subscriber(self, BoundingBoxes, '/bboxes')
        # Queue size 10, 0.1 s allowed stamp difference between the two topics.
        self.sync = ApproximateTimeSynchronizer([self.image_sub, self.bbox_sub], 10, 0.1)
        self.sync.registerCallback(self.callback)
        self.bbox_pub = self.create_publisher(BoundingBoxes, '/tracked_bboxes', 10)
        # Live tracks as (object_id, cv2 tracker) pairs.
        self.trackers = []

    @staticmethod
    def _box_center(box):
        """Return the centre ``(cx, cy)`` of an ``(x, y, w, h)`` box as a float array."""
        x, y, w, h = box
        return np.array((x + w / 2.0, y + h / 2.0), dtype=float)

    def _nearest_track(self, box, predictions, claimed):
        """Return the index of the closest unclaimed prediction to ``box``, or None.

        Args:
            box: ``(x, y, w, h)`` detection rectangle.
            predictions: list of ``(object_id, tracker, (x, y, w, h))`` entries.
            claimed: set of prediction indices already matched this frame.

        Only predictions whose centre is strictly closer than
        ``MATCH_DISTANCE_PX`` qualify.
        """
        center = self._box_center(box)
        best_index, best_dist = None, self.MATCH_DISTANCE_PX
        for index, (_, _, tracked_box) in enumerate(predictions):
            if index in claimed:
                continue
            dist = np.linalg.norm(center - self._box_center(tracked_box))
            if dist < best_dist:
                best_index, best_dist = index, dist
        return best_index

    def callback(self, img_msg, bbox_msg):
        """Update tracks from one synchronized image/detection pair and publish them.

        Every existing tracker is advanced exactly once on the new frame, and
        trackers that report failure are dropped. Each detection is matched to
        the nearest unclaimed track whose centre is within
        ``MATCH_DISTANCE_PX``. Unmatched detections start a new track with a
        fresh ID. Tracks not matched by any detection are discarded.

        Args:
            img_msg: ``sensor_msgs/Image``; converted to a ``bgr8`` frame.
            bbox_msg: ``BoundingBoxes`` message. Its ``bounding_boxes`` entries
                are read through ``x_offset``/``y_offset``/``width``/``height``
                (field names as used by the original code; confirm against
                the message definition).
        """
        frame = self.bridge.imgmsg_to_cv2(img_msg, "bgr8")
        detections = [
            (bbox.x_offset, bbox.y_offset, bbox.width, bbox.height)
            for bbox in bbox_msg.bounding_boxes
        ]

        # update() runs the tracker and mutates its internal state, so call it
        # once per tracker per frame; drop tracks whose update failed.
        predictions = []
        for object_id, tracker in self.trackers:
            ok, tracked_box = tracker.update(frame)
            if ok:
                predictions.append((object_id, tracker, tuple(tracked_box)))

        claimed = set()
        live = []  # (object_id, tracker, box to publish)
        for box in detections:
            index = self._nearest_track(box, predictions, claimed)
            if index is not None:
                # Each track may be claimed by at most one detection.
                claimed.add(index)
                live.append(predictions[index])
            else:
                self.object_count += 1
                tracker = cv2.TrackerCSRT_create()
                # Tracker.init takes an integer rectangle.
                tracker.init(frame, tuple(int(v) for v in box))
                live.append((self.object_count, tracker, box))

        tracked_bboxes = BoundingBoxes()
        tracked_bboxes.header = bbox_msg.header
        for object_id, _, (x, y, w, h) in live:
            bbox = BoundingBox()
            bbox.x_offset = x
            bbox.y_offset = y
            bbox.width = w
            bbox.height = h
            bbox.id = object_id
            tracked_bboxes.bounding_boxes.append(bbox)
        self.bbox_pub.publish(tracked_bboxes)
        self.trackers = [(object_id, tracker) for object_id, tracker, _ in live]
def main(args=None):
    """Run the ``object_tracker`` node until shutdown or Ctrl+C.

    Args:
        args: Optional command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = ObjectTracker()
    try:
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the node; exit without a traceback.
        pass
    finally:
        node.destroy_node()
        # The context may already have been shut down (e.g. by signal handling).
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment