diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index eee9a49..5344d20 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -3,7 +3,6 @@ import numpy as np -from panoptica.metrics.iou import _compute_instance_iou from panoptica.utils.constants import CCABackend from panoptica.utils.numpy_utils import _get_bbox_nd diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index ef631f9..60b0e43 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -11,10 +11,6 @@ MetricCouldNotBeComputedException, MetricMode, MetricType, - _compute_centerline_dice_coefficient, - _compute_dice_coefficient, - _average_symmetric_surface_distance, - _compute_relative_volume_difference, ) from panoptica.utils import EdgeCaseHandler @@ -44,8 +40,6 @@ def __init__( """ self._edge_case_handler = edge_case_handler empty_list_std = self._edge_case_handler.handle_empty_list_std().value - self._prediction_arr = prediction_arr - self._reference_arr = reference_arr self._global_metrics: list[Metric] = global_metrics ###################### # Evaluation Metrics # @@ -253,12 +247,42 @@ def __init__( m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result ) # even if not available, set the global vars + default_value = None + was_calculated = False + if m in self._global_metrics: + default_value = self._calc_global_bin_metric( + m, prediction_arr, reference_arr + ) + was_calculated = True + self._add_metric( f"global_bin_{m.name.lower()}", MetricType.GLOBAL, - _build_global_bin_metric_function(m), + lambda x: MetricCouldNotBeComputedException( + f"Global Metric {m} not set" + ), long_name="Global Binary " + m.value.long_name, + default_value=default_value, + was_calculated=was_calculated, + ) + + def _calc_global_bin_metric(self, metric: Metric, prediction_arr, reference_arr): + if metric not in self._global_metrics: + raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") + if self.tp == 0: + is_edgecase, result = self._edge_case_handler.handle_zero_tp( + metric, self.tp, self.num_pred_instances, self.num_ref_instances ) + if is_edgecase: + return result + pred_binary = prediction_arr + ref_binary = reference_arr + pred_binary[pred_binary != 0] = 1 + ref_binary[ref_binary != 0] = 1 + return metric( + reference_arr=ref_binary, + prediction_arr=pred_binary, + ) def _add_metric( self, @@ -292,6 +316,7 @@ def calculate_all(self, print_errors: bool = False): print_errors (bool, optional): If true, will print every metric that could not be computed and its reason. Defaults to False. """ metric_errors: dict[str, Exception] = {} + for k, v in self._evaluation_metrics.items(): try: v = getattr(self, k) @@ -302,6 +327,13 @@ def calculate_all(self, print_errors: bool = False): for k, v in metric_errors.items(): print(f"Metric {k}: {v}") + def _calc(self, k, v): + try: + v = getattr(self, k) + return False, v + except Exception as e: + return True, e + def __str__(self) -> str: text = "" for metric_type in MetricType: @@ -366,6 +398,8 @@ def __getattribute__(self, __name: str) -> Any: try: attr = object.__getattribute__(self, __name) except AttributeError as e: + if __name == "_evaluation_metrics": + raise e if __name in self._evaluation_metrics.keys(): pass else: @@ -514,12 +548,11 @@ def function_template(res: PanopticaResult): prediction_arr=res._prediction_arr, ) - return function_template + return lambda x: function_template(x) # endregion - if __name__ == "__main__": c = PanopticaResult( reference_arr=np.zeros([5, 5, 5]), diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index a3009ae..5ab23a8 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -40,6 +40,28 @@ def test_simple_evaluation(self): self.assertEqual(result.fp, 0) self.assertEqual(result.sq, 0.75) self.assertEqual(result.pq, 0.75) + self.assertAlmostEqual(result.global_bin_dsc, 0.8571428571428571) + + def test_simple_evaluation_instance_multiclass(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:30, 10:20] = 1 + a[30:40, 10:20] = 3 + b[20:35, 10:20] = 2 + + evaluator = Panoptica_Evaluator( + expected_input=InputType.UNMATCHED_INSTANCE, + instance_matcher=NaiveThresholdMatching(), + ) + + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + print(result) + self.assertAlmostEqual(result.global_bin_dsc, 0.8571428571428571) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.fn, 1) + self.assertAlmostEqual(result.sq, 0.6666666666666666) + self.assertAlmostEqual(result.pq, 0.4444444444444444) def test_simple_evaluation_DSC(self): a = np.zeros([50, 50], dtype=np.uint16)