From 85ebf83805d22c3d1c57280f47b0e823e360cbce Mon Sep 17 00:00:00 2001 From: Alex Robbins Date: Fri, 6 Dec 2024 14:17:50 -0700 Subject: [PATCH] Fix reid track pruning bug Issue 325 --- norfair/tracker.py | 7 ++++++ tests/test_tracker.py | 53 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/norfair/tracker.py b/norfair/tracker.py index 275eceee..99ab5ad4 100644 --- a/norfair/tracker.py +++ b/norfair/tracker.py @@ -631,6 +631,13 @@ def hit(self, detection: "Detection", period: int = 1): self.is_initializing = False self._acquire_ids() + # Reset reid_hit_counter if we are are successfully tracking this object. + # If hit_counter was 0 when Tracker.update was called and ReID is being used, + # we preemptively set reid_hit_counter to reid_hit_counter_max. But if the object + # is hit, we need to reset it. + if self.hit_counter_is_positive: + self.reid_hit_counter = None + # We use a kalman filter in which we consider each coordinate on each point as a sensor. # This is a hacky way to update only certain sensors (only x, y coordinates for # points which were detected). diff --git a/tests/test_tracker.py b/tests/test_tracker.py index 6e3b9e8a..770585a2 100644 --- a/tests/test_tracker.py +++ b/tests/test_tracker.py @@ -353,6 +353,59 @@ def dist(new_obj, tracked_obj): assert tracked_objects[0].id != obj_id +def test_reid_hit_counter_reset(): + # + # test that reid hit counter resets to None if it had started counting down but + # then the track was hit with an incoming detection + # + + # simple reid distance + def dist(new_obj, tracked_obj): + return np.linalg.norm(new_obj.estimate - tracked_obj.estimate) + + hit_counter_max = 2 + reid_hit_counter_max = 2 + + tracker = Tracker( + distance_function="euclidean", + distance_threshold=1, + hit_counter_max=hit_counter_max, + initialization_delay=1, + reid_distance_function=dist, + reid_distance_threshold=5, + reid_hit_counter_max=reid_hit_counter_max, + ) + + # check that hit counters initialize correctly + tracked_objects = tracker.update([Detection(points=np.array([[1, 1]]))]) + tracked_objects = tracker.update([Detection(points=np.array([[1, 1]]))]) + assert len(tracked_objects) == 1 + assert tracked_objects[0].hit_counter == 2 + assert tracked_objects[0].reid_hit_counter == None + + # check that object is still alive when hit_counter goes to 0 + obj_id = tracked_objects[0].id + for _ in range(hit_counter_max): + tracked_objects = tracker.update() + assert len(tracked_objects) == 1 + assert tracked_objects[0].hit_counter == 0 + assert tracked_objects[0].reid_hit_counter is None + + # check that object is alive and reid_hit_counter is None after being matched again + tracked_objects = tracker.update([Detection(points=np.array([[1, 1]]))]) + assert len(tracked_objects) == 1 + assert tracked_objects[0].hit_counter == 1 + assert tracked_objects[0].reid_hit_counter is None + + # check that after reid_hit_counter_max more updates, object still exists + for _ in range(reid_hit_counter_max + 2): + tracked_objects = tracker.update([Detection(points=np.array([[1, 1]]))]) + assert len(tracked_objects) == 1 + assert tracked_objects[0].hit_counter == 2 + assert tracked_objects[0].reid_hit_counter is None + assert tracked_objects[0].id == obj_id + + # TODO tests list: # - detections with different labels # - partial matches where some points are missing