From 5788e79adc4138c5792e1c2c740bf1f55a62e3aa Mon Sep 17 00:00:00 2001 From: Hendrik Date: Wed, 29 Nov 2023 16:34:50 +0100 Subject: [PATCH 1/5] enhanced the maximizeMergeMatcher, added corresponding unittests for it --- panoptica/__init__.py | 2 +- panoptica/_functionals.py | 2 +- panoptica/instance_evaluator.py | 5 +- panoptica/instance_matcher.py | 69 ++++++++++++++++++++++--- panoptica/metrics/metrics.py | 6 ++- panoptica/utils/processing_pair.py | 3 ++ unit_tests/test_panoptic_evaluator.py | 74 ++++++++++++++++++++++++++- 7 files changed, 146 insertions(+), 15 deletions(-) diff --git a/panoptica/__init__.py b/panoptica/__init__.py index a3d7805..4e4c832 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -1,5 +1,5 @@ from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator, CCABackend from panoptica.instance_matcher import NaiveThresholdMatching from panoptica.evaluator import Panoptic_Evaluator -from panoptica.result import PanopticaResult +from panoptica.panoptic_result import PanopticaResult from panoptica.utils.processing_pair import SemanticPair, UnmatchedInstancePair, MatchedInstancePair diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index ee137bd..fa990fe 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -60,7 +60,7 @@ def _calc_matching_metric_of_overlapping_labels( mm_values = pool.starmap(matching_metric.metric_function, instance_pairs) mm_pairs = [(i, (instance_pairs[idx][2], instance_pairs[idx][3])) for idx, i in enumerate(mm_values)] - mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=matching_metric.decreasing) + mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing) return mm_pairs diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index ebcdb74..2ff0d93 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -49,8 +49,9 @@ def evaluate_matched_instance( metric_dicts = pool.starmap(_evaluate_instance, instance_pairs) for metric_dict in metric_dicts: - assert decision_threshold is not None - if decision_metric is None or decision_metric.score_beats_threshold(metric_dict[decision_metric.name], decision_threshold): + if decision_metric is None or ( + decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric.name], decision_threshold) + ): for k, v in metric_dict.items(): score_dict[k].append(v) diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index eaada18..87e4e8d 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -110,8 +110,6 @@ def map_instance_labels(processing_pair: UnmatchedInstancePair, labelmap: Instan pred_labelmap = labelmap.get_one_to_one_dictionary() ref_matched_labels = list([r for r in ref_labels if r in pred_labelmap.values()]) - n_matched_instances = len(ref_matched_labels) - # assign missed instances to next unused labels sequentially missed_ref_labels = list([r for r in ref_labels if r not in ref_matched_labels]) missed_pred_labels = list([p for p in pred_labels if p not in pred_labelmap]) @@ -128,11 +126,6 @@ def map_instance_labels(processing_pair: UnmatchedInstancePair, labelmap: Instan matched_instance_pair = MatchedInstancePair( prediction_arr=prediction_arr_relabeled, reference_arr=processing_pair._reference_arr, - missed_reference_labels=missed_ref_labels, - missed_prediction_labels=missed_pred_labels, - n_prediction_instance=processing_pair.n_prediction_instance, - n_reference_instance=processing_pair.n_reference_instance, - matched_instances=ref_matched_labels, ) return matched_instance_pair @@ -221,7 +214,67 @@ class MaximizeMergeMatching(InstanceMatchingAlgorithm): AssertionError: If the specified IoU threshold is not within the valid range. """ - pass + def _match_instances( + self, + unmatched_instance_pair: UnmatchedInstancePair, + matching_metric: MatchingMetric, + matching_threshold: float, + **kwargs, + ) -> InstanceLabelMap: + """ + Perform one-to-one instance matching based on IoU values. + + Args: + unmatched_instance_pair (UnmatchedInstancePair): The unmatched instance pair to be matched. + **kwargs: Additional keyword arguments. + + Returns: + Instance_Label_Map: The result of the instance matching. + """ + ref_labels = unmatched_instance_pair._ref_labels + # pred_labels = unmatched_instance_pair._pred_labels + + # Initialize variables for True Positives (tp) and False Positives (fp) + labelmap = InstanceLabelMap() + score_ref: dict[int, float] = {} + + pred_arr, ref_arr = unmatched_instance_pair._prediction_arr, unmatched_instance_pair._reference_arr + mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=matching_metric) + + # Loop through matched instances to compute PQ components + for matching_score, (ref_label, pred_label) in mm_pairs: + if labelmap.contains_and(pred_label=pred_label, ref_label=None): + # skip if prediction label is already matched + continue + if labelmap.contains_and(None, ref_label): + pred_labels_ = labelmap.get_pred_labels_matched_to_ref(ref_label) + new_score = self.new_combination_score(pred_labels_, pred_label, ref_label, unmatched_instance_pair, matching_metric) + if new_score > score_ref[ref_label]: + labelmap.add_labelmap_entry(pred_label, ref_label) + score_ref[ref_label] = new_score + elif matching_metric.score_beats_threshold(matching_score, matching_threshold): + # Match found, increment true positive count and collect IoU and Dice values + labelmap.add_labelmap_entry(pred_label, ref_label) + score_ref[ref_label] = matching_score + # map label ref_idx to pred_idx + return labelmap + + def new_combination_score( + self, + pred_labels: list[int], + new_pred_label: int, + ref_label: int, + unmatched_instance_pair: UnmatchedInstancePair, + matching_metric: MatchingMetric, + ): + pred_labels.append(new_pred_label) + score = matching_metric( + unmatched_instance_pair.reference_arr, + prediction_arr=unmatched_instance_pair.prediction_arr, + ref_instance_idx=ref_label, + pred_instance_idx=pred_labels, + ) + return score class MatchUntilConvergenceMatching(InstanceMatchingAlgorithm): diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 2f89d28..f44e5de 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -21,13 +21,15 @@ def __call__( reference_arr: np.ndarray, prediction_arr: np.ndarray, ref_instance_idx: int | None = None, - pred_instance_idx: int | None = None, + pred_instance_idx: int | list[int] | None = None, *args, **kwargs, ): if ref_instance_idx is not None and pred_instance_idx is not None: reference_arr = reference_arr.copy() == ref_instance_idx - prediction_arr = prediction_arr.copy() == pred_instance_idx + if isinstance(pred_instance_idx, int): + pred_instance_idx = [pred_instance_idx] + prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) return self.metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 9551974..4b88591 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -303,6 +303,9 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): ) self.labelmap[p] = ref_label + def get_pred_labels_matched_to_ref(self, ref_label: int): + return [k for k, v in self.labelmap.items() if v == ref_label] + def contains_and(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index d58dbca..1415b42 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -8,7 +8,7 @@ from panoptica.evaluator import Panoptic_Evaluator from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator -from panoptica.instance_matcher import NaiveThresholdMatching +from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching from panoptica.utils.processing_pair import SemanticPair from panoptica.metrics import MatchingMetric, MatchingMetrics @@ -223,3 +223,75 @@ def test_dtype_evaluation(self): self.assertEqual(result.fp, 0) self.assertEqual(result.sq, 0.75) self.assertEqual(result.pq, 0.75) + + def test_simple_evaluation_maximize_matcher(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + + sample = SemanticPair(b, a) + + evaluator = Panoptic_Evaluator( + expected_input=SemanticPair, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=MaximizeMergeMatching(), + matching_metric=MatchingMetrics.IOU, + ) + + result, debug_data = evaluator.evaluate(sample) + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.sq, 0.75) + self.assertEqual(result.pq, 0.75) + + def test_simple_evaluation_maximize_matcher_overlaptwo(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + b[36:38, 10:20] = 3 + + sample = SemanticPair(b, a) + + evaluator = Panoptic_Evaluator( + expected_input=SemanticPair, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=MaximizeMergeMatching(), + matching_metric=MatchingMetrics.IOU, + ) + + result, debug_data = evaluator.evaluate(sample) + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.sq, 0.85) + self.assertEqual(result.pq, 0.85) + + def test_simple_evaluation_maximize_matcher_overlap(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + b[36:38, 10:20] = 3 + # match the two above to 1 and the 4 to nothing (FP) + b[39:47, 10:20] = 4 + + sample = SemanticPair(b, a) + + evaluator = Panoptic_Evaluator( + expected_input=SemanticPair, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=MaximizeMergeMatching(), + matching_metric=MatchingMetrics.IOU, + ) + + result, debug_data = evaluator.evaluate(sample) + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 1) + self.assertEqual(result.sq, 0.85) + self.assertAlmostEqual(result.pq, 0.56666666) + self.assertAlmostEqual(result.rq, 0.66666666) + self.assertAlmostEqual(result.sq_dsc, 0.9189189189189) From 3dfc0e26289d13a2ed21ecd2cc5dd1d21f409330 Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 15 Jan 2024 14:43:54 +0000 Subject: [PATCH 2/5] some comment updates --- panoptica/instance_matcher.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 5c18490..a1d4a09 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -216,12 +216,11 @@ def _match_instances( class MaximizeMergeMatching(InstanceMatchingAlgorithm): """ - Instance matching algorithm that performs many-to-one matching based on metric. Will merge if combined instance metric is greater than individual one + Instance matching algorithm that performs many-to-one matching based on metric. Will merge if combined instance metric is greater than individual one. Only matches if at least a single instance exceeds the threshold Methods: _match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs) -> Instance_Label_Map: - Perform one-to-one instance matching based on IoU values. Raises: AssertionError: If the specified IoU threshold is not within the valid range. @@ -232,10 +231,11 @@ def __init__( matching_threshold: float = 0.5, ) -> None: """ - Initialize the NaiveOneToOneMatching instance. + Initialize the MaximizeMergeMatching instance. Args: - iou_threshold (float, optional): The IoU threshold for matching instances. Defaults to 0.5. + matching_metric (_MatchingMetric): The metric to be used for matching. + matching_threshold (float, optional): The metric threshold for matching instances. Defaults to 0.5. Raises: AssertionError: If the specified IoU threshold is not within the valid range. @@ -275,7 +275,7 @@ def _match_instances( continue if labelmap.contains_ref(ref_label): pred_labels_ = labelmap.get_pred_labels_matched_to_ref(ref_label) - new_score = self.new_combination_score(pred_labels_, pred_label, ref_label, unmatched_instance_pair, self.matching_metric) + new_score = self.new_combination_score(pred_labels_, pred_label, ref_label, unmatched_instance_pair) if new_score > score_ref[ref_label]: labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = new_score @@ -292,10 +292,9 @@ def new_combination_score( new_pred_label: int, ref_label: int, unmatched_instance_pair: UnmatchedInstancePair, - matching_metric: _MatchingMetric, ): pred_labels.append(new_pred_label) - score = matching_metric( + score = self.matching_metric( unmatched_instance_pair.reference_arr, prediction_arr=unmatched_instance_pair.prediction_arr, ref_instance_idx=ref_label, From f6d3bfd01f04a54b9b6a344ca26d08e85f872805 Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 15:29:21 +0000 Subject: [PATCH 3/5] Autoformat with black --- panoptica/_functionals.py | 4 +++- panoptica/instance_matcher.py | 18 ++++++++++++++---- panoptica/utils/processing_pair.py | 4 +++- unit_tests/test_panoptic_evaluator.py | 2 +- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index 44ff817..820f08e 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -68,7 +68,9 @@ def _calc_matching_metric_of_overlapping_labels( (i, (instance_pairs[idx][2], instance_pairs[idx][3])) for idx, i in enumerate(mm_values) ] - mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing) + mm_pairs = sorted( + mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing + ) return mm_pairs diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index a1d4a09..ba0902e 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -225,6 +225,7 @@ class MaximizeMergeMatching(InstanceMatchingAlgorithm): Raises: AssertionError: If the specified IoU threshold is not within the valid range. """ + def __init__( self, matching_metric: _MatchingMetric = Metrics.IOU, @@ -265,8 +266,13 @@ def _match_instances( labelmap = InstanceLabelMap() score_ref: dict[int, float] = {} - pred_arr, ref_arr = unmatched_instance_pair._prediction_arr, unmatched_instance_pair._reference_arr - mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric) + pred_arr, ref_arr = ( + unmatched_instance_pair._prediction_arr, + unmatched_instance_pair._reference_arr, + ) + mm_pairs = _calc_matching_metric_of_overlapping_labels( + pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric + ) # Loop through matched instances to compute PQ components for matching_score, (ref_label, pred_label) in mm_pairs: @@ -275,11 +281,15 @@ def _match_instances( continue if labelmap.contains_ref(ref_label): pred_labels_ = labelmap.get_pred_labels_matched_to_ref(ref_label) - new_score = self.new_combination_score(pred_labels_, pred_label, ref_label, unmatched_instance_pair) + new_score = self.new_combination_score( + pred_labels_, pred_label, ref_label, unmatched_instance_pair + ) if new_score > score_ref[ref_label]: labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = new_score - elif self.matching_metric.score_beats_threshold(matching_score, self.matching_threshold): + elif self.matching_metric.score_beats_threshold( + matching_score, self.matching_threshold + ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = matching_score diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index ca4d7da..bde4a58 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -338,7 +338,9 @@ def contains_pred(self, pred_label: int): def contains_ref(self, ref_label: int): return ref_label in self.labelmap.values() - def contains_and(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + def contains_and( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in and ref_in diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index 7fe8351..fa5c2ec 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -9,7 +9,7 @@ from panoptica.panoptic_evaluator import Panoptic_Evaluator from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching -from panoptica.metrics import _MatchingMetric, Metrics +from panoptica.metrics import _MatchingMetric, Metrics from panoptica.utils.processing_pair import SemanticPair From 28f02ecb84f2a59e4ccbafda3360c757c7dc4e9b Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 15 Jan 2024 15:30:40 +0000 Subject: [PATCH 4/5] fixed import metrics statement --- panoptica/instance_evaluator.py | 1 + panoptica/panoptic_evaluator.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index f1974ab..f27fb11 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -11,6 +11,7 @@ from panoptica.timing import measure_time from panoptica.utils import EdgeCaseHandler from panoptica.utils.processing_pair import MatchedInstancePair +from panoptica.metrics import Metrics def evaluate_matched_instance( diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py index ae904d4..07535ae 100644 --- a/panoptica/panoptic_evaluator.py +++ b/panoptica/panoptic_evaluator.py @@ -17,7 +17,6 @@ ) from panoptica.utils.citation_reminder import citation_reminder - class Panoptic_Evaluator: def __init__( self, From b792285e15560c0736abd276fd368051a59c2adb Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 15:32:13 +0000 Subject: [PATCH 5/5] Autoformat with black --- panoptica/panoptic_evaluator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py index 07535ae..ae904d4 100644 --- a/panoptica/panoptic_evaluator.py +++ b/panoptica/panoptic_evaluator.py @@ -17,6 +17,7 @@ ) from panoptica.utils.citation_reminder import citation_reminder + class Panoptic_Evaluator: def __init__( self,