diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index bcea628..8a38960 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -14,7 +14,7 @@ evaluator = Panoptica_Evaluator( expected_input=InputType.MATCHED_INSTANCE, - eval_metrics=[Metric.DSC, Metric.IOU], + instance_metrics=[Metric.DSC, Metric.IOU], segmentation_class_groups=SegmentationClassGroups( { "vertebra": LabelGroup([i for i in range(1, 10)]), diff --git a/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml index ad8d9b0..55fbbc3 100644 --- a/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml +++ b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml @@ -19,7 +19,8 @@ edge_case_handler: !EdgeCaseHandler !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN, empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN, normal: !EdgeCaseResult NAN} -eval_metrics: [!Metric DSC, !Metric IOU] +instance_metrics: [!Metric DSC, !Metric IOU] +global_metrics: [!Metric DSC, !Metric RVD] expected_input: !InputType UNMATCHED_INSTANCE instance_approximator: null instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU, diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index 383df0e..303abf0 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -1,11 +1,8 @@ from multiprocessing import Pool - import numpy as np from panoptica.metrics import Metric -from panoptica.panoptica_result import PanopticaResult -from panoptica.utils import EdgeCaseHandler -from panoptica.utils.processing_pair import MatchedInstancePair +from panoptica.utils.processing_pair import MatchedInstancePair, EvaluateInstancePair def evaluate_matched_instance( @@ -13,9 +10,8 @@ def evaluate_matched_instance( eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], decision_metric: Metric | None = Metric.IOU, decision_threshold: float | None = None, - edge_case_handler: EdgeCaseHandler | None = None, **kwargs, -) -> PanopticaResult: +) -> EvaluateInstancePair: """ Map instance labels based on the provided labelmap and create a MatchedInstancePair. 
@@ -31,8 +27,6 @@ def evaluate_matched_instance( >>> labelmap = [([1, 2], [3, 4]), ([5], [6])] >>> result = map_instance_labels(unmatched_instance_pair, labelmap) """ - if edge_case_handler is None: - edge_case_handler = EdgeCaseHandler() if decision_metric is not None: assert decision_metric.name in [ v.name for v in eval_metrics @@ -68,14 +62,13 @@ def evaluate_matched_instance( score_dict[k].append(v) # Create and return the PanopticaResult object with computed metrics - return PanopticaResult( + return EvaluateInstancePair( reference_arr=matched_instance_pair.reference_arr, prediction_arr=matched_instance_pair.prediction_arr, num_pred_instances=matched_instance_pair.n_prediction_instance, num_ref_instances=matched_instance_pair.n_reference_instance, tp=tp, list_metrics=score_dict, - edge_case_handler=edge_case_handler, ) diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 2d6a4c0..5cbce45 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -23,6 +23,7 @@ class _Metric: """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array""" name: str + long_name: str decreasing: bool _metric_function: Callable @@ -91,11 +92,21 @@ class Metric(_Enum_Compare): _type_: _description_ """ - DSC = _Metric("DSC", False, _compute_instance_volumetric_dice) - IOU = _Metric("IOU", False, _compute_instance_iou) - ASSD = _Metric("ASSD", True, _compute_instance_average_symmetric_surface_distance) - clDSC = _Metric("clDSC", False, _compute_centerline_dice) - RVD = _Metric("RVD", True, _compute_instance_relative_volume_difference) + DSC = _Metric("DSC", "Dice", False, _compute_instance_volumetric_dice) + IOU = _Metric("IOU", "Intersection over Union", False, _compute_instance_iou) + ASSD = _Metric( + "ASSD", + "Average Symmetric Surface Distance", + True, + _compute_instance_average_symmetric_surface_distance, + ) + clDSC = _Metric("clDSC", "Centerline Dice", False, _compute_centerline_dice) + RVD = _Metric( + "RVD", + "Relative Volume Difference", + True, + _compute_instance_relative_volume_difference, + ) # ST = _Metric("ST", False, _compute_instance_segmentation_tendency) def __call__( diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index fd5617c..f0522f8 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -15,6 +15,7 @@ UnmatchedInstancePair, _ProcessingPair, InputType, + EvaluateInstancePair, ) import numpy as np from panoptica.utils.config import SupportsConfig @@ -30,7 +31,13 @@ def __init__( instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, segmentation_class_groups: SegmentationClassGroups | None = None, - eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD], + instance_metrics: list[Metric] = [ + Metric.DSC, + Metric.IOU, + Metric.ASSD, + Metric.RVD, + ], + global_metrics: list[Metric] = [Metric.DSC], decision_metric: Metric | None = None, decision_threshold: float | None = None, log_times: bool = False, @@ -43,12 +50,21 @@ def __init__( instance_approximator (InstanceApproximator | None, optional): Determines which instance approximator is used if necessary. Defaults to None. instance_matcher (InstanceMatchingAlgorithm | None, optional): Determines which instance matching algorithm is used if necessary. Defaults to None. iou_threshold (float, optional): Iou Threshold for evaluation. Defaults to 0.5. 
+            edge_case_handler (EdgeCaseHandler | None, optional): EdgeCaseHandler to be used. If None, a default handler is created.
+            segmentation_class_groups (SegmentationClassGroups | None, optional): If not None, evaluates each defined class group separately instead of all labels at once.
+            instance_metrics (list[Metric]): List of all metrics that should be calculated between matched instances.
+            global_metrics (list[Metric]): List of all metrics that should be calculated on the global binary masks.
+            decision_metric (Metric | None, optional): Metric that decides between True Positive and False Positive. Can be omitted if the matching algorithm already matches by a metric and threshold.
+            decision_threshold (float | None, optional): Threshold for the decision_metric.
+            log_times (bool): If True, prints the runtimes of the different phases of the pipeline.
+            verbose (bool): If True, prints more detailed output.
         """
         self.__expected_input = expected_input
         #
         self.__instance_approximator = instance_approximator
         self.__instance_matcher = instance_matcher
-        self.__eval_metrics = eval_metrics
+        self.__eval_metrics = instance_metrics
+        self.__global_metrics = global_metrics
         self.__decision_metric = decision_metric
         self.__decision_threshold = decision_threshold
@@ -74,6 +90,7 @@ def _yaml_repr(cls, node) -> dict:
            "edge_case_handler": node.__edge_case_handler,
            "segmentation_class_groups": node.__segmentation_class_groups,
            "eval_metrics": node.__eval_metrics,
+            "global_metrics": node.__global_metrics,
            "decision_metric": node.__decision_metric,
            "decision_threshold": node.__decision_threshold,
            "log_times": node.__log_times,
@@ -101,7 +118,8 @@ def evaluate(
            edge_case_handler=self.__edge_case_handler,
            instance_approximator=self.__instance_approximator,
            instance_matcher=self.__instance_matcher,
-            eval_metrics=self.__eval_metrics,
+            instance_metrics=self.__eval_metrics,
+            global_metrics=self.__global_metrics,
            decision_metric=self.__decision_metric,
            decision_threshold=self.__decision_threshold,
            result_all=result_all,
@@ -142,7 +160,8 @@ def evaluate(
            edge_case_handler=self.__edge_case_handler,
            instance_approximator=self.__instance_approximator,
            instance_matcher=self.__instance_matcher,
-            eval_metrics=self.__eval_metrics,
+            instance_metrics=self.__eval_metrics,
+            global_metrics=self.__global_metrics,
            decision_metric=self.__decision_metric,
            decision_threshold=decision_threshold,
            result_all=result_all,
@@ -159,7 +178,8 @@ def panoptic_evaluate(
    ),
    instance_approximator: InstanceApproximator | None = None,
    instance_matcher: InstanceMatchingAlgorithm | None = None,
-    eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
+    instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
+    global_metrics: list[Metric] = [Metric.DSC],
    decision_metric: Metric | None = None,
    decision_threshold: float | None = None,
    edge_case_handler: EdgeCaseHandler | None = None,
@@ -227,7 +247,8 @@ def panoptic_evaluate(
    if isinstance(processing_pair, UnmatchedInstancePair):
        processing_pair = _handle_zero_instances_cases(
            processing_pair,
-            eval_metrics=eval_metrics,
+            eval_metrics=instance_metrics,
+            global_metrics=global_metrics,
            edge_case_handler=edge_case_handler,
        )

@@ -250,7 +271,8 @@ def panoptic_evaluate(
    if isinstance(processing_pair, MatchedInstancePair):
        processing_pair = _handle_zero_instances_cases(
            processing_pair,
-            eval_metrics=eval_metrics,
+            eval_metrics=instance_metrics,
+            global_metrics=global_metrics,
            edge_case_handler=edge_case_handler,
        )

@@ -260,15 +282,27 @@
def panoptic_evaluate( start = perf_counter() processing_pair = evaluate_matched_instance( processing_pair, - eval_metrics=eval_metrics, + eval_metrics=instance_metrics, decision_metric=decision_metric, decision_threshold=decision_threshold, - edge_case_handler=edge_case_handler, ) if log_times: print(f"-- Instance Evaluation took {perf_counter() - start} seconds") + if isinstance(processing_pair, EvaluateInstancePair): + processing_pair = PanopticaResult( + reference_arr=processing_pair.reference_arr, + prediction_arr=processing_pair.prediction_arr, + num_pred_instances=processing_pair.num_pred_instances, + num_ref_instances=processing_pair.num_ref_instances, + tp=processing_pair.tp, + list_metrics=processing_pair.list_metrics, + global_metrics=global_metrics, + edge_case_handler=edge_case_handler, + ) + if isinstance(processing_pair, PanopticaResult): + processing_pair._global_metrics = global_metrics if result_all: processing_pair.calculate_all(print_errors=verbose_calc) return processing_pair, debug_data @@ -279,6 +313,7 @@ def panoptic_evaluate( def _handle_zero_instances_cases( processing_pair: UnmatchedInstancePair | MatchedInstancePair, edge_case_handler: EdgeCaseHandler, + global_metrics: list[Metric], eval_metrics: list[_Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], ) -> UnmatchedInstancePair | MatchedInstancePair | PanopticaResult: """ @@ -322,6 +357,7 @@ def _handle_zero_instances_cases( is_edge_case = True if is_edge_case: + panoptica_result_args["global_metrics"] = global_metrics panoptica_result_args["num_ref_instances"] = n_reference_instance panoptica_result_args["num_pred_instances"] = n_prediction_instance return PanopticaResult(**panoptica_result_args) diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index b831e04..ef631f9 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -24,12 +24,12 @@ def __init__( self, reference_arr: np.ndarray, prediction_arr: np.ndarray, - # TODO some metadata object containing dtype, voxel spacing, ... 
num_pred_instances: int, num_ref_instances: int, tp: int, list_metrics: dict[Metric, list[float]], edge_case_handler: EdgeCaseHandler, + global_metrics: list[Metric] = [], ): """Result object for Panoptica, contains all calculatable metrics @@ -46,6 +46,7 @@ def __init__( empty_list_std = self._edge_case_handler.handle_empty_list_std().value self._prediction_arr = prediction_arr self._reference_arr = reference_arr + self._global_metrics: list[Metric] = global_metrics ###################### # Evaluation Metrics # ###################### @@ -119,39 +120,6 @@ def __init__( ) # endregion # - # region Global - self.global_bin_dsc: int - self._add_metric( - "global_bin_dsc", - MetricType.GLOBAL, - global_bin_dsc, - long_name="Global Binary Dice", - ) - # - self.global_bin_cldsc: int - self._add_metric( - "global_bin_cldsc", - MetricType.GLOBAL, - global_bin_cldsc, - long_name="Global Binary Centerline Dice", - ) - # - self.global_bin_assd: int - self._add_metric( - "global_bin_assd", - MetricType.GLOBAL, - global_bin_assd, - long_name="Global Binary Average Symmetric Surface Distance", - ) - # - self.global_bin_rvd: int - self._add_metric( - "global_bin_rvd", - MetricType.GLOBAL, - global_bin_rvd, - long_name="Global Binary Relative Volume Difference", - ) - # endregion # # region IOU self.sq: float @@ -259,19 +227,37 @@ def __init__( ) # endregion + # region Global + # Just for autocomplete + self.global_bin_dsc: int + self.global_bin_iou: int + self.global_bin_cldsc: int + self.global_bin_assd: int + self.global_bin_rvd: int + # endregion + ################## # List Metrics # ################## self._list_metrics: dict[Metric, Evaluation_List_Metric] = {} - for k, v in list_metrics.items(): - is_edge_case, edge_case_result = self._edge_case_handler.handle_zero_tp( - metric=k, - tp=self.tp, - num_pred_instances=self.num_pred_instances, - num_ref_instances=self.num_ref_instances, - ) - self._list_metrics[k] = Evaluation_List_Metric( - k, empty_list_std, v, is_edge_case, edge_case_result + # Loop over all available metric, add it to evaluation_list_metric if available, but also add the global references + for m in Metric: + if m in list_metrics: + is_edge_case, edge_case_result = self._edge_case_handler.handle_zero_tp( + metric=m, + tp=self.tp, + num_pred_instances=self.num_pred_instances, + num_ref_instances=self.num_ref_instances, + ) + self._list_metrics[m] = Evaluation_List_Metric( + m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result + ) + # even if not available, set the global vars + self._add_metric( + f"global_bin_{m.name.lower()}", + MetricType.GLOBAL, + _build_global_bin_metric_function(m), + long_name="Global Binary " + m.value.long_name, ) def _add_metric( @@ -321,6 +307,8 @@ def __str__(self) -> str: for metric_type in MetricType: if metric_type == MetricType.NO_PRINT: continue + if metric_type == MetricType.GLOBAL and len(self._global_metrics) == 0: + continue text += f"\n+++ {metric_type.name} +++\n" for k, v in self._evaluation_metrics.items(): if v.metric_type != metric_type: @@ -506,45 +494,27 @@ def sq_rvd_std(res: PanopticaResult): # endregion -# region Global -def global_bin_dsc(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _compute_dice_coefficient(ref_binary, pred_binary) - - -def global_bin_cldsc(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = 
res._prediction_arr.copy()
-    ref_binary = res._reference_arr.copy()
-    pred_binary[pred_binary != 0] = 1
-    ref_binary[ref_binary != 0] = 1
-    return _compute_centerline_dice_coefficient(ref_binary, pred_binary)
-
-
-def global_bin_assd(res: PanopticaResult):
-    if res.tp == 0:
-        return 0.0
-    pred_binary = res._prediction_arr.copy()
-    ref_binary = res._reference_arr.copy()
-    pred_binary[pred_binary != 0] = 1
-    ref_binary[ref_binary != 0] = 1
-    return _average_symmetric_surface_distance(ref_binary, pred_binary)
+def _build_global_bin_metric_function(metric: Metric):
+    def function_template(res: PanopticaResult):
+        if metric not in res._global_metrics:
+            raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set")
+        if res.tp == 0:
+            is_edgecase, result = res._edge_case_handler.handle_zero_tp(
+                metric, res.tp, res.num_pred_instances, res.num_ref_instances
+            )
+            if is_edgecase:
+                return result
+        # compute the metric on the binarized (foreground vs. background) masks
+        pred_binary = res._prediction_arr.copy()
+        ref_binary = res._reference_arr.copy()
+        pred_binary[pred_binary != 0] = 1
+        ref_binary[ref_binary != 0] = 1
+        return metric(
+            reference_arr=ref_binary,
+            prediction_arr=pred_binary,
+        )
-def global_bin_rvd(res: PanopticaResult):
-    if res.tp == 0:
-        return 0.0
-    pred_binary = res._prediction_arr.copy()
-    ref_binary = res._reference_arr.copy()
-    pred_binary[pred_binary != 0] = 1
-    ref_binary[ref_binary != 0] = 1
-    return _compute_relative_volume_difference(ref_binary, pred_binary)
+    return function_template
 # endregion
diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py
index 1a865f0..b88f49e 100644
--- a/panoptica/utils/edge_case_handling.py
+++ b/panoptica/utils/edge_case_handling.py
@@ -143,6 +143,20 @@ def handle_zero_tp(
        num_pred_instances: int,
        num_ref_instances: int,
    ) -> tuple[bool, float | None]:
+        """Resolves the edge case for a metric when there are zero true positives.
+
+        Args:
+            metric (Metric): Metric whose zero-TP edge case should be resolved.
+            tp (int): Number of true positive instances.
+            num_pred_instances (int): Number of predicted instances.
+            num_ref_instances (int): Number of reference instances.
+
+        Raises:
+            NotImplementedError: If no zero-TP handling is registered for the given metric.
+
+        Returns:
+            tuple[bool, float | None]: Whether this is an edge case and, if so, its edge case value.
+        """
        if tp != 0:
            return False, EdgeCaseResult.NONE.value
        if metric not in self.__listmetric_zeroTP_handling:
diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py
index c64b1e9..e0afa86 100644
--- a/panoptica/utils/processing_pair.py
+++ b/panoptica/utils/processing_pair.py
@@ -5,6 +5,8 @@
 from panoptica._functionals import _get_paired_crop
 from panoptica.utils import _count_unique_without_zeros, _unique_without_zeros
 from panoptica.utils.constants import _Enum_Compare
+from dataclasses import dataclass
+from panoptica.metrics import Metric

 uint_type: type = np.unsignedinteger
 int_type: type = np.integer
@@ -315,6 +317,16 @@ def copy(self):
        )


+@dataclass
+class EvaluateInstancePair:
+    reference_arr: np.ndarray
+    prediction_arr: np.ndarray
+    num_pred_instances: int
+    num_ref_instances: int
+    tp: int
+    list_metrics: dict[Metric, list[float]]
+
+
 class InputType(_Enum_Compare):
    SEMANTIC = SemanticPair
    UNMATCHED_INSTANCE = UnmatchedInstancePair
diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py
index 5fe9c0e..a3009ae 100644
--- a/unit_tests/test_panoptic_evaluator.py
+++ b/unit_tests/test_panoptic_evaluator.py
@@ -70,7 +70,7 @@ def test_simple_evaluation_DSC_partial(self):
            expected_input=InputType.SEMANTIC,
            instance_approximator=ConnectedComponentsInstanceApproximator(),
instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC), - eval_metrics=[Metric.DSC], + instance_metrics=[Metric.DSC], ) result, debug_data = evaluator.evaluate(b, a)["ungrouped"] diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py index f5087ea..200d084 100644 --- a/unit_tests/test_panoptic_result.py +++ b/unit_tests/test_panoptic_result.py @@ -28,6 +28,7 @@ def test_simple_evaluation(self): num_pred_instances=5, tp=0, list_metrics={Metric.IOU: []}, + global_metrics=[Metric.DSC], edge_case_handler=EdgeCaseHandler(), ) c.calculate_all(print_errors=True)
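
Usage sketch (not part of the patch): the snippet below shows how the renamed `instance_metrics` argument and the new `global_metrics` argument would be wired together, and where the generated `global_bin_*` values end up. The toy arrays, import paths, and printed attributes are assumptions for illustration; only the parameter names, the `["ungrouped"]` result key, and the `global_bin_<metric>` naming come from the changes above.

import numpy as np

from panoptica import InputType, Panoptica_Evaluator
from panoptica.metrics import Metric

# Toy matched-instance masks: background is 0, instance labels 1 and 2
# already agree between prediction and reference (illustrative data only).
reference = np.zeros((20, 20), dtype=np.uint8)
prediction = np.zeros((20, 20), dtype=np.uint8)
reference[2:8, 2:8] = 1
prediction[3:9, 2:8] = 1
reference[12:18, 12:18] = 2
prediction[12:17, 12:18] = 2

evaluator = Panoptica_Evaluator(
    expected_input=InputType.MATCHED_INSTANCE,
    instance_metrics=[Metric.DSC, Metric.IOU],  # per-instance metrics (previously `eval_metrics`)
    global_metrics=[Metric.DSC, Metric.RVD],    # metrics computed on the global binary masks
)

result, debug_data = evaluator.evaluate(prediction, reference)["ungrouped"]

# Global binary metrics are now generated per selected Metric (global_bin_dsc, global_bin_rvd, ...);
# accessing one that was not listed in `global_metrics` raises MetricCouldNotBeComputedException.
print(result.tp, result.global_bin_dsc, result.global_bin_rvd)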