diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index 5b03723..820f08e 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -68,7 +68,9 @@ def _calc_matching_metric_of_overlapping_labels( (i, (instance_pairs[idx][2], instance_pairs[idx][3])) for idx, i in enumerate(mm_values) ] - mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=matching_metric.decreasing) + mm_pairs = sorted( + mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing + ) return mm_pairs diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index fa742f8..f27fb11 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -5,13 +5,13 @@ import numpy as np from panoptica.metrics import ( - Metrics, _MatchingMetric, ) from panoptica.panoptic_result import PanopticaResult from panoptica.timing import measure_time from panoptica.utils import EdgeCaseHandler from panoptica.utils.processing_pair import MatchedInstancePair +from panoptica.metrics import Metrics def evaluate_matched_instance( diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 339609f..ba0902e 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -109,8 +109,6 @@ def map_instance_labels( pred_labelmap = labelmap.get_one_to_one_dictionary() ref_matched_labels = list([r for r in ref_labels if r in pred_labelmap.values()]) - n_matched_instances = len(ref_matched_labels) - # assign missed instances to next unused labels sequentially missed_ref_labels = list([r for r in ref_labels if r not in ref_matched_labels]) missed_pred_labels = list([p for p in pred_labels if p not in pred_labelmap]) @@ -127,11 +125,6 @@ def map_instance_labels( matched_instance_pair = MatchedInstancePair( prediction_arr=prediction_arr_relabeled, reference_arr=processing_pair._reference_arr, - missed_reference_labels=missed_ref_labels, - missed_prediction_labels=missed_pred_labels, - 
n_prediction_instance=processing_pair.n_prediction_instance, - n_reference_instance=processing_pair.n_reference_instance, - matched_instances=ref_matched_labels, ) return matched_instance_pair @@ -223,18 +216,101 @@ def _match_instances( class MaximizeMergeMatching(InstanceMatchingAlgorithm): """ - Instance matching algorithm that performs many-to-one matching based on metric. Will merge if combined instance metric is greater than individual one + Instance matching algorithm that performs many-to-one matching based on metric. Will merge if combined instance metric is greater than individual one. Only matches if at least a single instance exceeds the threshold Methods: _match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs) -> Instance_Label_Map: - Perform one-to-one instance matching based on IoU values. Raises: AssertionError: If the specified IoU threshold is not within the valid range. """ - pass + def __init__( + self, + matching_metric: _MatchingMetric = Metrics.IOU, + matching_threshold: float = 0.5, + ) -> None: + """ + Initialize the MaximizeMergeMatching instance. + + Args: + matching_metric (_MatchingMetric): The metric to be used for matching. + matching_threshold (float, optional): The metric threshold for matching instances. Defaults to 0.5. + + Raises: + AssertionError: If the specified IoU threshold is not within the valid range. + """ + self.matching_metric = matching_metric + self.matching_threshold = matching_threshold + + def _match_instances( + self, + unmatched_instance_pair: UnmatchedInstancePair, + **kwargs, + ) -> InstanceLabelMap: + """ + Perform many-to-one instance matching based on IoU values. + + Args: + unmatched_instance_pair (UnmatchedInstancePair): The unmatched instance pair to be matched. + **kwargs: Additional keyword arguments. + + Returns: + Instance_Label_Map: The result of the instance matching. 
+ """ + ref_labels = unmatched_instance_pair._ref_labels + # pred_labels = unmatched_instance_pair._pred_labels + + # Initialize variables for True Positives (tp) and False Positives (fp) + labelmap = InstanceLabelMap() + score_ref: dict[int, float] = {} + + pred_arr, ref_arr = ( + unmatched_instance_pair._prediction_arr, + unmatched_instance_pair._reference_arr, + ) + mm_pairs = _calc_matching_metric_of_overlapping_labels( + pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric + ) + + # Loop through matched instances to compute PQ components + for matching_score, (ref_label, pred_label) in mm_pairs: + if labelmap.contains_pred(pred_label=pred_label): + # skip if prediction label is already matched + continue + if labelmap.contains_ref(ref_label): + pred_labels_ = labelmap.get_pred_labels_matched_to_ref(ref_label) + new_score = self.new_combination_score( + pred_labels_, pred_label, ref_label, unmatched_instance_pair + ) + if new_score > score_ref[ref_label]: + labelmap.add_labelmap_entry(pred_label, ref_label) + score_ref[ref_label] = new_score + elif self.matching_metric.score_beats_threshold( + matching_score, self.matching_threshold + ): + # Match found, increment true positive count and collect IoU and Dice values + labelmap.add_labelmap_entry(pred_label, ref_label) + score_ref[ref_label] = matching_score + # map label ref_idx to pred_idx + return labelmap + + def new_combination_score( + self, + pred_labels: list[int], + new_pred_label: int, + ref_label: int, + unmatched_instance_pair: UnmatchedInstancePair, + ): + pred_labels.append(new_pred_label) + score = self.matching_metric( + unmatched_instance_pair.reference_arr, + prediction_arr=unmatched_instance_pair.prediction_arr, + ref_instance_idx=ref_label, + pred_instance_idx=pred_labels, + ) + return score class MatchUntilConvergenceMatching(InstanceMatchingAlgorithm): diff --git a/panoptica/metrics/__init__.py b/panoptica/metrics/__init__.py index ef93513..fd32f08 100644 --- 
a/panoptica/metrics/__init__.py +++ b/panoptica/metrics/__init__.py @@ -8,9 +8,9 @@ ) from panoptica.metrics.iou import _compute_instance_iou, _compute_iou from panoptica.metrics.metrics import ( - EvalMetric, + Metrics, ListMetric, + EvalMetric, MetricDict, - Metrics, _MatchingMetric, ) diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index f282907..b6e7117 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -23,13 +23,15 @@ def __call__( reference_arr: np.ndarray, prediction_arr: np.ndarray, ref_instance_idx: int | None = None, - pred_instance_idx: int | None = None, + pred_instance_idx: int | list[int] | None = None, *args, **kwargs, ): if ref_instance_idx is not None and pred_instance_idx is not None: reference_arr = reference_arr.copy() == ref_instance_idx - prediction_arr = prediction_arr.copy() == pred_instance_idx + if isinstance(pred_instance_idx, int): + pred_instance_idx = [pred_instance_idx] + prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 93d4886..bde4a58 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -329,6 +329,15 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): ) self.labelmap[p] = ref_label + def get_pred_labels_matched_to_ref(self, ref_label: int): + return [k for k, v in self.labelmap.items() if v == ref_label] + + def contains_pred(self, pred_label: int): + return pred_label in self.labelmap + + def contains_ref(self, ref_label: int): + return ref_label in self.labelmap.values() + def contains_and( self, pred_label: int | None = None, ref_label: int | None = None ) -> bool: diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index 27a7eb9..fa5c2ec 100644 
--- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -8,8 +8,8 @@ from panoptica.panoptic_evaluator import Panoptic_Evaluator from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator -from panoptica.instance_matcher import NaiveThresholdMatching -from panoptica.metrics import Metrics +from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching +from panoptica.metrics import _MatchingMetric, Metrics from panoptica.utils.processing_pair import SemanticPair @@ -238,3 +238,72 @@ def test_dtype_evaluation(self): self.assertEqual(result.fp, 0) self.assertEqual(result.sq, 0.75) self.assertEqual(result.pq, 0.75) + + def test_simple_evaluation_maximize_matcher(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + + sample = SemanticPair(b, a) + + evaluator = Panoptic_Evaluator( + expected_input=SemanticPair, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=MaximizeMergeMatching(), + ) + + result, debug_data = evaluator.evaluate(sample) + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.sq, 0.75) + self.assertEqual(result.pq, 0.75) + + def test_simple_evaluation_maximize_matcher_overlaptwo(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + b[36:38, 10:20] = 3 + + sample = SemanticPair(b, a) + + evaluator = Panoptic_Evaluator( + expected_input=SemanticPair, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=MaximizeMergeMatching(), + ) + + result, debug_data = evaluator.evaluate(sample) + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.sq, 0.85) + self.assertEqual(result.pq, 0.85) + + def test_simple_evaluation_maximize_matcher_overlap(self): + a = 
np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + b[36:38, 10:20] = 3 + # match the two above to 1 and the 4 to nothing (FP) + b[39:47, 10:20] = 4 + + sample = SemanticPair(b, a) + + evaluator = Panoptic_Evaluator( + expected_input=SemanticPair, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=MaximizeMergeMatching(), + ) + + result, debug_data = evaluator.evaluate(sample) + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 1) + self.assertEqual(result.sq, 0.85) + self.assertAlmostEqual(result.pq, 0.56666666) + self.assertAlmostEqual(result.rq, 0.66666666) + self.assertAlmostEqual(result.sq_dsc, 0.9189189189189)