diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py
index 8bf993d..57e6a04 100644
--- a/examples/example_spine_semantic.py
+++ b/examples/example_spine_semantic.py
@@ -22,6 +22,7 @@
     expected_input=SemanticPair,
     instance_approximator=ConnectedComponentsInstanceApproximator(),
     instance_matcher=NaiveThresholdMatching(),
+    verbose=True,
 )
 
 with cProfile.Profile() as pr:
diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py
index 57432d7..960e6af 100644
--- a/panoptica/metrics/metrics.py
+++ b/panoptica/metrics/metrics.py
@@ -175,6 +175,7 @@ class MetricType(_Enum_Compare):
         _Enum_Compare (_type_): _description_
     """
 
+    NO_PRINT = auto()
     MATCHING = auto()
     GLOBAL = auto()
     INSTANCE = auto()
diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py
index 0119fbc..1daeda9 100644
--- a/panoptica/panoptic_evaluator.py
+++ b/panoptica/panoptic_evaluator.py
@@ -146,7 +146,6 @@ def panoptic_evaluate(
             instance_approximator is not None
         ), "Got SemanticPair but not InstanceApproximator"
         print("-- Got SemanticPair, will approximate instances")
-        processing_pair = instance_approximator.approximate_instances(processing_pair)
         start = perf_counter()
         processing_pair = instance_approximator.approximate_instances(processing_pair)
         if log_times:
@@ -185,6 +184,7 @@ def panoptic_evaluate(
 
     if isinstance(processing_pair, MatchedInstancePair):
         print("-- Got MatchedInstancePair, will evaluate instances")
+        start = perf_counter()
         processing_pair = evaluate_matched_instance(
             processing_pair,
             eval_metrics=eval_metrics,
diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py
index 2b4b79b..0abbba9 100644
--- a/panoptica/panoptic_result.py
+++ b/panoptica/panoptic_result.py
@@ -13,6 +13,7 @@
     MetricType,
     _compute_centerline_dice_coefficient,
     _compute_dice_coefficient,
+    _average_symmetric_surface_distance,
 )
 from panoptica.utils import EdgeCaseHandler
 
@@ -94,6 +95,20 @@ def __init__(
             fn,
             long_name="False Negatives",
         )
+        self.prec: int
+        self._add_metric(
+            "prec",
+            MetricType.NO_PRINT,
+            prec,
+            long_name="Precision (positive predictive value)",
+        )
+        self.rec: int
+        self._add_metric(
+            "rec",
+            MetricType.NO_PRINT,
+            rec,
+            long_name="Recall (sensitivity)",
+        )
         self.rq: float
         self._add_metric(
             "rq",
@@ -119,6 +134,14 @@ def __init__(
             global_bin_cldsc,
             long_name="Global Binary Centerline Dice",
         )
+        #
+        self.global_bin_assd: int
+        self._add_metric(
+            "global_bin_assd",
+            MetricType.GLOBAL,
+            global_bin_assd,
+            long_name="Global Binary Average Symmetric Surface Distance",
+        )
         # endregion
         #
         # region IOU
@@ -270,6 +293,8 @@ def calculate_all(self, print_errors: bool = False):
     def __str__(self) -> str:
         text = ""
         for metric_type in MetricType:
+            if metric_type == MetricType.NO_PRINT:
+                continue
             text += f"\n+++ {metric_type.name} +++\n"
             for k, v in self._evaluation_metrics.items():
                 if v.metric_type != metric_type:
@@ -360,6 +385,14 @@ def fn(res: PanopticaResult):
     return res.num_ref_instances - res.tp
 
 
+def prec(res: PanopticaResult):
+    return res.tp / (res.tp + res.fp)
+
+
+def rec(res: PanopticaResult):
+    return res.tp / (res.tp + res.fn)
+
+
 def rq(res: PanopticaResult):
     """
     Calculate the Recognition Quality (RQ) based on TP, FP, and FN.
@@ -456,6 +489,16 @@ def global_bin_cldsc(res: PanopticaResult):
     return _compute_centerline_dice_coefficient(ref_binary, pred_binary)
 
 
+def global_bin_assd(res: PanopticaResult):
+    if res.tp == 0:
+        return 0.0
+    pred_binary = res._prediction_arr.copy()
+    ref_binary = res._reference_arr.copy()
+    pred_binary[pred_binary != 0] = 1
+    ref_binary[ref_binary != 0] = 1
+    return _average_symmetric_surface_distance(ref_binary, pred_binary)
+
+
 # endregion
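
For context on what the new metrics compute: prec and rec are the standard instance-level precision (TP / (TP + FP)) and recall (TP / (TP + FN)), and global_bin_assd reports the average symmetric surface distance between the binarized prediction and reference. The standalone sketch below reproduces that arithmetic on toy numpy arrays; it is illustrative only, assumes scipy is available, and compute_assd is a stand-in rather than panoptica's actual _average_symmetric_surface_distance implementation.

# Standalone sketch, not part of the diff: mirrors the arithmetic behind the new
# prec, rec, and global_bin_assd metrics on toy data. compute_assd and the toy
# arrays are illustrative stand-ins, not panoptica's implementation.
import numpy as np
from scipy.ndimage import binary_erosion, distance_transform_edt


def compute_assd(reference: np.ndarray, prediction: np.ndarray) -> float:
    """Average symmetric surface distance of two binary masks (unit voxel spacing assumed)."""
    ref_border = reference & ~binary_erosion(reference)
    pred_border = prediction & ~binary_erosion(prediction)
    # Distance of every voxel to the nearest border voxel of the other mask.
    dist_to_ref = distance_transform_edt(~ref_border)
    dist_to_pred = distance_transform_edt(~pred_border)
    surface_distances = np.concatenate(
        [dist_to_ref[pred_border], dist_to_pred[ref_border]]
    )
    return float(surface_distances.mean())


# Toy binarized masks standing in for res._reference_arr != 0 / res._prediction_arr != 0.
reference = np.zeros((16, 16), dtype=bool)
prediction = np.zeros((16, 16), dtype=bool)
reference[4:12, 4:12] = True
prediction[5:13, 5:13] = True

# Instance counts as consumed by prec(res) and rec(res) above (values made up).
tp, fp, fn = 1, 0, 0
prec = tp / (tp + fp)  # precision = TP / (TP + FP)
rec = tp / (tp + fn)   # recall    = TP / (TP + FN)
print(prec, rec, compute_assd(reference, prediction))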