From 735397e1ce54e3b22974dffebf34284c5728d538 Mon Sep 17 00:00:00 2001
From: iback
Date: Wed, 7 Aug 2024 15:40:36 +0000
Subject: [PATCH 1/2] Global binary metrics are now built by a function
 creator, as they share an underlying pattern, and they now use the edge case
 handler. To allow for maximum flexibility, added an EvaluateInstancePair
 class that simplifies several functions; this way, one can jump from an
 instance approximation algorithm directly to results. Also added a
 global_metrics argument so users can decide which global metrics are
 calculated (default: Dice, DSC). Additionally, renamed eval_metrics to
 instance_metrics to distinguish it from the global_metrics argument.

---
 examples/example_spine_instance.py            |   2 +-
 examples/example_spine_instance_config.py     |   4 +-
 examples/example_spine_semantic.py            |   4 +-
 ...anoptica_evaluator_unmatched_instance.yaml |   3 +-
 panoptica/instance_evaluator.py               |  31 +---
 panoptica/metrics/metrics.py                  |  67 +++-----
 panoptica/panoptica_evaluator.py              |  85 +++++-----
 panoptica/panoptica_result.py                 | 146 ++++++------------
 panoptica/utils/edge_case_handling.py         |  50 +++---
 panoptica/utils/processing_pair.py            |  72 +++------
 unit_tests/test_panoptic_evaluator.py         |   2 +-
 unit_tests/test_panoptic_result.py            |   5 +-
 12 files changed, 182 insertions(+), 289 deletions(-)

diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py
index bcea628..8a38960 100644
--- a/examples/example_spine_instance.py
+++ b/examples/example_spine_instance.py
@@ -14,7 +14,7 @@

 evaluator = Panoptica_Evaluator(
     expected_input=InputType.MATCHED_INSTANCE,
-    eval_metrics=[Metric.DSC, Metric.IOU],
+    instance_metrics=[Metric.DSC, Metric.IOU],
     segmentation_class_groups=SegmentationClassGroups(
         {
             "vertebra": LabelGroup([i for i in range(1, 10)]),
diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py
index 56c61e4..2855add 100644
--- a/examples/example_spine_instance_config.py
+++ b/examples/example_spine_instance_config.py
@@ -10,9 +10,7 @@
 reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")
 prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

-evaluator = Panoptica_Evaluator.load_from_config_name(
-    "panoptica_evaluator_unmatched_instance"
-)
+evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance")


 with cProfile.Profile() as pr:
diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py
index a2e5a32..2793592 100644
--- a/examples/example_spine_semantic.py
+++ b/examples/example_spine_semantic.py
@@ -25,9 +25,7 @@
 with cProfile.Profile() as pr:
     if __name__ == "__main__":
-        result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[
-            "ungrouped"
-        ]
+        result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]

         print(result)

         pr.dump_stats(directory + "/semantic_example.log")
diff --git a/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml
index ad8d9b0..55fbbc3 100644
--- a/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml
+++ b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml
@@ -19,7 +19,8 @@ edge_case_handler: !EdgeCaseHandler
   !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
     empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult
      NAN, normal: !EdgeCaseResult NAN}
-eval_metrics: [!Metric 
DSC, !Metric IOU] +instance_metrics: [!Metric DSC, !Metric IOU] +global_metrics: [!Metric DSC, !Metric RVD] expected_input: !InputType UNMATCHED_INSTANCE instance_approximator: null instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU, diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index 383df0e..9b784e7 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -1,11 +1,8 @@ from multiprocessing import Pool - import numpy as np from panoptica.metrics import Metric -from panoptica.panoptica_result import PanopticaResult -from panoptica.utils import EdgeCaseHandler -from panoptica.utils.processing_pair import MatchedInstancePair +from panoptica.utils.processing_pair import MatchedInstancePair, EvaluateInstancePair def evaluate_matched_instance( @@ -13,9 +10,8 @@ def evaluate_matched_instance( eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], decision_metric: Metric | None = Metric.IOU, decision_threshold: float | None = None, - edge_case_handler: EdgeCaseHandler | None = None, **kwargs, -) -> PanopticaResult: +) -> EvaluateInstancePair: """ Map instance labels based on the provided labelmap and create a MatchedInstancePair. @@ -31,12 +27,8 @@ def evaluate_matched_instance( >>> labelmap = [([1, 2], [3, 4]), ([5], [6])] >>> result = map_instance_labels(unmatched_instance_pair, labelmap) """ - if edge_case_handler is None: - edge_case_handler = EdgeCaseHandler() if decision_metric is not None: - assert decision_metric.name in [ - v.name for v in eval_metrics - ], "decision metric not contained in eval_metrics" + assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics" assert decision_threshold is not None, "decision metric set but no threshold" # Initialize variables for True Positives (tp) tp = len(matched_instance_pair.matched_instances) @@ -48,34 +40,25 @@ def evaluate_matched_instance( ) ref_matched_labels = matched_instance_pair.matched_instances - instance_pairs = [ - (reference_arr, prediction_arr, ref_idx, eval_metrics) - for ref_idx in ref_matched_labels - ] + instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels] with Pool() as pool: - metric_dicts: list[dict[Metric, float]] = pool.starmap( - _evaluate_instance, instance_pairs - ) + metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs) for metric_dict in metric_dicts: if decision_metric is None or ( - decision_threshold is not None - and decision_metric.score_beats_threshold( - metric_dict[decision_metric], decision_threshold - ) + decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold) ): for k, v in metric_dict.items(): score_dict[k].append(v) # Create and return the PanopticaResult object with computed metrics - return PanopticaResult( + return EvaluateInstancePair( reference_arr=matched_instance_pair.reference_arr, prediction_arr=matched_instance_pair.prediction_arr, num_pred_instances=matched_instance_pair.n_prediction_instance, num_ref_instances=matched_instance_pair.n_reference_instance, tp=tp, list_metrics=score_dict, - edge_case_handler=edge_case_handler, ) diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 2d6a4c0..4f02e48 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -23,6 +23,7 @@ class _Metric: """A Metric class containing a name, 
whether higher or lower values is better, and a function to calculate that metric between two instances in an array""" name: str + long_name: str decreasing: bool _metric_function: Callable @@ -39,9 +40,7 @@ def __call__( reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): pred_instance_idx = [pred_instance_idx] - prediction_arr = np.isin( - prediction_arr.copy(), pred_instance_idx - ) # type:ignore + prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) # type:ignore return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: @@ -65,12 +64,8 @@ def __hash__(self) -> int: def increasing(self): return not self.decreasing - def score_beats_threshold( - self, matching_score: float, matching_threshold: float - ) -> bool: - return (self.increasing and matching_score >= matching_threshold) or ( - self.decreasing and matching_score <= matching_threshold - ) + def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) class DirectValueMeta(EnumMeta): @@ -91,11 +86,11 @@ class Metric(_Enum_Compare): _type_: _description_ """ - DSC = _Metric("DSC", False, _compute_instance_volumetric_dice) - IOU = _Metric("IOU", False, _compute_instance_iou) - ASSD = _Metric("ASSD", True, _compute_instance_average_symmetric_surface_distance) - clDSC = _Metric("clDSC", False, _compute_centerline_dice) - RVD = _Metric("RVD", True, _compute_instance_relative_volume_difference) + DSC = _Metric("DSC", "Dice", False, _compute_instance_volumetric_dice) + IOU = _Metric("IOU", "Intersection over Union", False, _compute_instance_iou) + ASSD = _Metric("ASSD", "Average Symmetric Surface Distance", True, _compute_instance_average_symmetric_surface_distance) + clDSC = _Metric("clDSC", "Centerline Dice", False, _compute_centerline_dice) + RVD = _Metric("RVD", "Relative Volume Difference", True, _compute_instance_relative_volume_difference) # ST = _Metric("ST", False, _compute_instance_segmentation_tendency) def __call__( @@ -127,9 +122,7 @@ def __call__( **kwargs, ) - def score_beats_threshold( - self, matching_score: float, matching_threshold: float - ) -> bool: + def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: """Calculates whether a score beats a specified threshold Args: @@ -139,9 +132,7 @@ def score_beats_threshold( Returns: bool: True if the matching_score beats the threshold, False otherwise. """ - return (self.increasing and matching_score >= matching_threshold) or ( - self.decreasing and matching_score <= matching_threshold - ) + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) @property def name(self): @@ -238,9 +229,7 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # ERROR if self._error: if self._error_obj is None: - self._error_obj = MetricCouldNotBeComputedException( - f"Metric {self.id} requested, but could not be computed" - ) + self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed") raise self._error_obj # Already calculated? 
if self._was_calculated: @@ -248,12 +237,8 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # Calculate it try: - assert ( - not self._was_calculated - ), f"Metric {self.id} was called to compute, but is set to have been already calculated" - assert ( - self._calc_func is not None - ), f"Metric {self.id} was called to compute, but has no calculation function set" + assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated" + assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set" value = self._calc_func(result_obj) except MetricCouldNotBeComputedException as e: value = e @@ -298,32 +283,20 @@ def __init__( else: self.AVG = None if self.ALL is None else np.average(self.ALL) self.SUM = None if self.ALL is None else np.sum(self.ALL) - self.MIN = ( - None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) - ) - self.MAX = ( - None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) - ) - - self.STD = ( - None - if self.ALL is None - else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) - ) + self.MIN = None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) + self.MAX = None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) + + self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) def __getitem__(self, mode: MetricMode | str): if self.error: - raise MetricCouldNotBeComputedException( - f"Metric {self.id} has not been calculated, add it to your eval_metrics" - ) + raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics") if isinstance(mode, MetricMode): mode = mode.name if hasattr(self, mode): return getattr(self, mode) else: - raise MetricCouldNotBeComputedException( - f"List_Metric {self.id} does not contain {mode} member" - ) + raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member") if __name__ == "__main__": diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index fd5617c..2ffdebf 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -15,6 +15,7 @@ UnmatchedInstancePair, _ProcessingPair, InputType, + EvaluateInstancePair, ) import numpy as np from panoptica.utils.config import SupportsConfig @@ -30,7 +31,8 @@ def __init__( instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, segmentation_class_groups: SegmentationClassGroups | None = None, - eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD], + instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD], + global_metrics: list[Metric] = [Metric.DSC], decision_metric: Metric | None = None, decision_threshold: float | None = None, log_times: bool = False, @@ -43,24 +45,29 @@ def __init__( instance_approximator (InstanceApproximator | None, optional): Determines which instance approximator is used if necessary. Defaults to None. instance_matcher (InstanceMatchingAlgorithm | None, optional): Determines which instance matching algorithm is used if necessary. Defaults to None. iou_threshold (float, optional): Iou Threshold for evaluation. Defaults to 0.5. + edge_case_handler (edge_case_handler, optional): EdgeCaseHandler to be used. 
If none, will create the default one + segmentation_class_groups (SegmentationClassGroups, optional): If not none, will evaluate per class group defined, instead of over all at the same time. + instance_metrics (list[Metric]): List of all metrics that should be calculated between all instances + global_metrics (list[Metric]): List of all metrics that should be calculated on the global binary masks + decision_metric: (Metric | None, optional): This metric is the final decision point between True Positive and False Positive. Can be left away if the matching algorithm is used (it will match by a metric and threshold already) + decision_threshold: (float | None, optional): Threshold for the decision_metric + log_times (bool): If true, will printout the times for the different phases of the pipeline. + verbose (bool): If true, will spit out more details than you want. """ self.__expected_input = expected_input # self.__instance_approximator = instance_approximator self.__instance_matcher = instance_matcher - self.__eval_metrics = eval_metrics + self.__eval_metrics = instance_metrics + self.__global_metrics = global_metrics self.__decision_metric = decision_metric self.__decision_threshold = decision_threshold self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = ( - edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() - ) + self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() if self.__decision_metric is not None: - assert ( - self.__decision_threshold is not None - ), "decision metric set but no decision threshold for it" + assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -74,6 +81,7 @@ def _yaml_repr(cls, node) -> dict: "edge_case_handler": node.__edge_case_handler, "segmentation_class_groups": node.__segmentation_class_groups, "eval_metrics": node.__eval_metrics, + "global_metrics": node.__global_metrics, "decision_metric": node.__decision_metric, "decision_threshold": node.__decision_threshold, "log_times": node.__log_times, @@ -90,9 +98,7 @@ def evaluate( verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: processing_pair = self.__expected_input(prediction_arr, reference_arr) - assert isinstance( - processing_pair, self.__expected_input.value - ), f"input not of expected type {self.__expected_input}" + assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { @@ -101,7 +107,8 @@ def evaluate( edge_case_handler=self.__edge_case_handler, instance_approximator=self.__instance_approximator, instance_matcher=self.__instance_matcher, - eval_metrics=self.__eval_metrics, + instance_metrics=self.__eval_metrics, + global_metrics=self.__global_metrics, decision_metric=self.__decision_metric, decision_threshold=self.__decision_threshold, result_all=result_all, @@ -111,12 +118,8 @@ def evaluate( ) } - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.prediction_arr, raise_error=True - ) - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.reference_arr, raise_error=True - ) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) 
result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -128,9 +131,7 @@ def evaluate( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance( - processing_pair, MatchedInstancePair - ): + if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -142,7 +143,8 @@ def evaluate( edge_case_handler=self.__edge_case_handler, instance_approximator=self.__instance_approximator, instance_matcher=self.__instance_matcher, - eval_metrics=self.__eval_metrics, + instance_metrics=self.__eval_metrics, + global_metrics=self.__global_metrics, decision_metric=self.__decision_metric, decision_threshold=decision_threshold, result_all=result_all, @@ -154,12 +156,11 @@ def evaluate( def panoptic_evaluate( - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, - eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + global_metrics: list[Metric] = [Metric.DSC], decision_metric: Metric | None = None, decision_threshold: float | None = None, edge_case_handler: EdgeCaseHandler | None = None, @@ -212,9 +213,7 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert ( - instance_approximator is not None - ), "Got SemanticPair but not InstanceApproximator" + assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -227,16 +226,15 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): processing_pair = _handle_zero_instances_cases( processing_pair, - eval_metrics=eval_metrics, + eval_metrics=instance_metrics, + global_metrics=global_metrics, edge_case_handler=edge_case_handler, ) if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert ( - instance_matcher is not None - ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, @@ -250,7 +248,8 @@ def panoptic_evaluate( if isinstance(processing_pair, MatchedInstancePair): processing_pair = _handle_zero_instances_cases( processing_pair, - eval_metrics=eval_metrics, + eval_metrics=instance_metrics, + global_metrics=global_metrics, edge_case_handler=edge_case_handler, ) @@ -260,15 +259,27 @@ def panoptic_evaluate( start = perf_counter() processing_pair = evaluate_matched_instance( processing_pair, - eval_metrics=eval_metrics, + eval_metrics=instance_metrics, decision_metric=decision_metric, decision_threshold=decision_threshold, - edge_case_handler=edge_case_handler, ) if log_times: print(f"-- Instance 
Evaluation took {perf_counter() - start} seconds") + if isinstance(processing_pair, EvaluateInstancePair): + processing_pair = PanopticaResult( + reference_arr=processing_pair.reference_arr, + prediction_arr=processing_pair.prediction_arr, + num_pred_instances=processing_pair.num_pred_instances, + num_ref_instances=processing_pair.num_ref_instances, + tp=processing_pair.tp, + list_metrics=processing_pair.list_metrics, + global_metrics=global_metrics, + edge_case_handler=edge_case_handler, + ) + if isinstance(processing_pair, PanopticaResult): + processing_pair._global_metrics = global_metrics if result_all: processing_pair.calculate_all(print_errors=verbose_calc) return processing_pair, debug_data @@ -279,6 +290,7 @@ def panoptic_evaluate( def _handle_zero_instances_cases( processing_pair: UnmatchedInstancePair | MatchedInstancePair, edge_case_handler: EdgeCaseHandler, + global_metrics: list[Metric], eval_metrics: list[_Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], ) -> UnmatchedInstancePair | MatchedInstancePair | PanopticaResult: """ @@ -322,6 +334,7 @@ def _handle_zero_instances_cases( is_edge_case = True if is_edge_case: + panoptica_result_args["global_metrics"] = global_metrics panoptica_result_args["num_ref_instances"] = n_reference_instance panoptica_result_args["num_pred_instances"] = n_prediction_instance return PanopticaResult(**panoptica_result_args) diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index b831e04..e87a33b 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -24,12 +24,12 @@ def __init__( self, reference_arr: np.ndarray, prediction_arr: np.ndarray, - # TODO some metadata object containing dtype, voxel spacing, ... num_pred_instances: int, num_ref_instances: int, tp: int, list_metrics: dict[Metric, list[float]], edge_case_handler: EdgeCaseHandler, + global_metrics: list[Metric] = [], ): """Result object for Panoptica, contains all calculatable metrics @@ -46,6 +46,7 @@ def __init__( empty_list_std = self._edge_case_handler.handle_empty_list_std().value self._prediction_arr = prediction_arr self._reference_arr = reference_arr + self._global_metrics: list[Metric] = global_metrics ###################### # Evaluation Metrics # ###################### @@ -119,39 +120,6 @@ def __init__( ) # endregion # - # region Global - self.global_bin_dsc: int - self._add_metric( - "global_bin_dsc", - MetricType.GLOBAL, - global_bin_dsc, - long_name="Global Binary Dice", - ) - # - self.global_bin_cldsc: int - self._add_metric( - "global_bin_cldsc", - MetricType.GLOBAL, - global_bin_cldsc, - long_name="Global Binary Centerline Dice", - ) - # - self.global_bin_assd: int - self._add_metric( - "global_bin_assd", - MetricType.GLOBAL, - global_bin_assd, - long_name="Global Binary Average Symmetric Surface Distance", - ) - # - self.global_bin_rvd: int - self._add_metric( - "global_bin_rvd", - MetricType.GLOBAL, - global_bin_rvd, - long_name="Global Binary Relative Volume Difference", - ) - # endregion # # region IOU self.sq: float @@ -259,19 +227,35 @@ def __init__( ) # endregion + # region Global + # Just for autocomplete + self.global_bin_dsc: int + self.global_bin_iou: int + self.global_bin_cldsc: int + self.global_bin_assd: int + self.global_bin_rvd: int + # endregion + ################## # List Metrics # ################## self._list_metrics: dict[Metric, Evaluation_List_Metric] = {} - for k, v in list_metrics.items(): - is_edge_case, edge_case_result = self._edge_case_handler.handle_zero_tp( - metric=k, - tp=self.tp, - 
num_pred_instances=self.num_pred_instances, - num_ref_instances=self.num_ref_instances, - ) - self._list_metrics[k] = Evaluation_List_Metric( - k, empty_list_std, v, is_edge_case, edge_case_result + # Loop over all available metric, add it to evaluation_list_metric if available, but also add the global references + for m in Metric: + if m in list_metrics: + is_edge_case, edge_case_result = self._edge_case_handler.handle_zero_tp( + metric=m, + tp=self.tp, + num_pred_instances=self.num_pred_instances, + num_ref_instances=self.num_ref_instances, + ) + self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result) + # even if not available, set the global vars + self._add_metric( + f"global_bin_{m.name.lower()}", + MetricType.GLOBAL, + _build_global_bin_metric_function(m), + long_name="Global Binary " + m.value.long_name, ) def _add_metric( @@ -321,6 +305,8 @@ def __str__(self) -> str: for metric_type in MetricType: if metric_type == MetricType.NO_PRINT: continue + if metric_type == MetricType.GLOBAL and len(self._global_metrics) == 0: + continue text += f"\n+++ {metric_type.name} +++\n" for k, v in self._evaluation_metrics.items(): if v.metric_type != metric_type: @@ -341,19 +327,13 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: - return { - k: getattr(self, v.id) - for k, v in self._evaluation_metrics.items() - if (v._error == False and v._was_calculated) - } + return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException( - f"{metric} could not be found, have you set it in eval_metrics during evaluation?" 
- ) + raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -369,9 +349,7 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException( - f"could not find metric with name {metric_name}" - ) + raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") def __getattribute__(self, __name: str) -> Any: attr = None @@ -384,9 +362,7 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException( - f"Requested metric {__name} that could not be computed" - ) + raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) @@ -506,45 +482,25 @@ def sq_rvd_std(res: PanopticaResult): # endregion -# region Global -def global_bin_dsc(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _compute_dice_coefficient(ref_binary, pred_binary) - - -def global_bin_cldsc(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _compute_centerline_dice_coefficient(ref_binary, pred_binary) - - -def global_bin_assd(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _average_symmetric_surface_distance(ref_binary, pred_binary) - +def _build_global_bin_metric_function(metric: Metric): + + def function_template(res: PanopticaResult): + if metric not in res._global_metrics: + raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") + if res.tp == 0: + is_edgecase, result = res._edge_case_handler.handle_zero_tp(metric, res.tp, res.num_pred_instances, res.num_ref_instances) + if is_edgecase: + return result + pred_binary = res._prediction_arr.copy() + ref_binary = res._reference_arr.copy() + pred_binary[pred_binary != 0] = 1 + ref_binary[ref_binary != 0] = 1 + return metric( + reference_arr=res._reference_arr, + prediction_arr=res._prediction_arr, + ) -def global_bin_rvd(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _compute_relative_volume_difference(ref_binary, pred_binary) + return function_template # endregion diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index 1a865f0..d6a0e02 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -41,26 +41,12 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( - empty_prediction_result - if empty_prediction_result 
is not None - else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( - empty_reference_result - if empty_reference_result is not None - else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( - no_instances_result if no_instances_result is not None else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( - normal if normal is not None else default_result - ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result - def __call__( - self, tp: int, num_pred_instances, num_ref_instances - ) -> tuple[bool, float | None]: + def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -131,9 +117,7 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[ - Metric, MetricZeroTPEdgeCaseHandling - ] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -143,12 +127,24 @@ def handle_zero_tp( num_pred_instances: int, num_ref_instances: int, ) -> tuple[bool, float | None]: + """_summary_ + + Args: + metric (Metric): _description_ + tp (int): _description_ + num_pred_instances (int): _description_ + num_ref_instances (int): _description_ + + Raises: + NotImplementedError: _description_ + + Returns: + tuple[bool, float | None]: if edge case, and its edge case value + """ if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError( - f"Metric {metric} encountered zero TP, but no edge handling available" - ) + raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") return self.__listmetric_zeroTP_handling[metric]( tp=tp, @@ -185,9 +181,7 @@ def _yaml_repr(cls, node) -> dict: print() # print(handler.get_metric_zero_tp_handle(ListMetric.IOU)) - r = handler.handle_zero_tp( - Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1 - ) + r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1) print(r) iou_test = MetricZeroTPEdgeCaseHandling( diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index c64b1e9..b89d5ad 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -5,6 +5,8 @@ from panoptica._functionals import _get_paired_crop from panoptica.utils import _count_unique_without_zeros, _unique_without_zeros from panoptica.utils.constants import _Enum_Compare +from dataclasses import dataclass +from panoptica.metrics import Metric uint_type: type = np.unsignedinteger int_type: type = np.integer @@ -23,9 +25,7 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] 
n_dim: int - def __init__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None - ) -> None: + def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: """Initializes a general Processing Pair Args: @@ -38,12 +38,8 @@ def __init__( self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple( - _unique_without_zeros(reference_arr) - ) # type:ignore - self._pred_labels: tuple[int, ...] = tuple( - _unique_without_zeros(prediction_arr) - ) # type:ignore + self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore + self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore self.crop: tuple[slice, ...] = None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] = reference_arr.shape @@ -60,41 +56,25 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - ( - print( - f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" - ) - if verbose - else None - ) + (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) self.is_cropped = True def uncrop_data(self, verbose: bool = False): if self.is_cropped == False: return - assert ( - self.uncropped_shape is not None - ), "Calling uncrop_data() without having cropped first" + assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - ( - print( - f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" - ) - if verbose - else None - ) + (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) self._reference_arr = reference_arr self.is_cropped = False def set_dtype(self, type): - assert np.issubdtype( - type, int_type - ), "set_dtype: tried to set dtype to something other than integers" + assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -173,9 +153,7 @@ def copy(self): ) # type:ignore -def _check_array_integrity( - prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None -): +def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): """ Check the integrity of two numpy arrays. 
@@ -197,12 +175,8 @@ def _check_array_integrity( assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert ( - prediction_arr.shape == reference_arr.shape - ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert ( - prediction_arr.dtype == reference_arr.dtype - ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype) @@ -285,15 +259,11 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list( - [i for i in self._ref_labels if i not in self._pred_labels] - ) + missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list( - [i for i in self._pred_labels if i not in self._ref_labels] - ) + missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) self.missed_prediction_labels = missed_prediction_labels @property @@ -315,12 +285,20 @@ def copy(self): ) +@dataclass +class EvaluateInstancePair: + reference_arr: np.ndarray + prediction_arr: np.ndarray + num_pred_instances: int + num_ref_instances: int + tp: int + list_metrics: dict[Metric, list[float]] + + class InputType(_Enum_Compare): SEMANTIC = SemanticPair UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray - ) -> _ProcessingPair: + def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index 5fe9c0e..a3009ae 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -70,7 +70,7 @@ def test_simple_evaluation_DSC_partial(self): expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC), - eval_metrics=[Metric.DSC], + instance_metrics=[Metric.DSC], ) result, debug_data = evaluator.evaluate(b, a)["ungrouped"] diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py index f5087ea..883c604 100644 --- a/unit_tests/test_panoptic_result.py +++ b/unit_tests/test_panoptic_result.py @@ -28,6 +28,7 @@ def test_simple_evaluation(self): num_pred_instances=5, tp=0, list_metrics={Metric.IOU: []}, + global_metrics=[Metric.DSC], edge_case_handler=EdgeCaseHandler(), ) c.calculate_all(print_errors=True) @@ -85,9 +86,7 @@ def test_existing_metrics(self): def powerset(iterable): s = list(iterable) - return list( - chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) - ) + return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))) power_set = powerset([Metric.DSC, Metric.IOU, Metric.ASSD]) for m in power_set[1:]: From 5808e462b2f5821f9c47dae49b847f6fc61301b7 Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Wed, 
7 Aug 2024 15:42:48 +0000 Subject: [PATCH 2/2] Autoformat with black --- examples/example_spine_instance_config.py | 4 +- examples/example_spine_semantic.py | 4 +- panoptica/instance_evaluator.py | 18 ++++-- panoptica/metrics/metrics.py | 70 +++++++++++++++++------ panoptica/panoptica_evaluator.py | 43 ++++++++++---- panoptica/panoptica_result.py | 26 +++++++-- panoptica/utils/edge_case_handling.py | 36 +++++++++--- panoptica/utils/processing_pair.py | 60 ++++++++++++++----- unit_tests/test_panoptic_result.py | 4 +- 9 files changed, 205 insertions(+), 60 deletions(-) diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py index 2855add..56c61e4 100644 --- a/examples/example_spine_instance_config.py +++ b/examples/example_spine_instance_config.py @@ -10,7 +10,9 @@ reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") -evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance") +evaluator = Panoptica_Evaluator.load_from_config_name( + "panoptica_evaluator_unmatched_instance" +) with cProfile.Profile() as pr: diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 2793592..a2e5a32 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -25,7 +25,9 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"] + result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[ + "ungrouped" + ] print(result) pr.dump_stats(directory + "/semantic_example.log") diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index 9b784e7..303abf0 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -28,7 +28,9 @@ def evaluate_matched_instance( >>> result = map_instance_labels(unmatched_instance_pair, labelmap) """ if decision_metric is not None: - assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics" + assert decision_metric.name in [ + v.name for v in eval_metrics + ], "decision metric not contained in eval_metrics" assert decision_threshold is not None, "decision metric set but no threshold" # Initialize variables for True Positives (tp) tp = len(matched_instance_pair.matched_instances) @@ -40,13 +42,21 @@ def evaluate_matched_instance( ) ref_matched_labels = matched_instance_pair.matched_instances - instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels] + instance_pairs = [ + (reference_arr, prediction_arr, ref_idx, eval_metrics) + for ref_idx in ref_matched_labels + ] with Pool() as pool: - metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs) + metric_dicts: list[dict[Metric, float]] = pool.starmap( + _evaluate_instance, instance_pairs + ) for metric_dict in metric_dicts: if decision_metric is None or ( - decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold) + decision_threshold is not None + and decision_metric.score_beats_threshold( + metric_dict[decision_metric], decision_threshold + ) ): for k, v in metric_dict.items(): score_dict[k].append(v) diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 4f02e48..5cbce45 100644 --- a/panoptica/metrics/metrics.py 
+++ b/panoptica/metrics/metrics.py @@ -40,7 +40,9 @@ def __call__( reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): pred_instance_idx = [pred_instance_idx] - prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) # type:ignore + prediction_arr = np.isin( + prediction_arr.copy(), pred_instance_idx + ) # type:ignore return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: @@ -64,8 +66,12 @@ def __hash__(self) -> int: def increasing(self): return not self.decreasing - def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: - return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) + def score_beats_threshold( + self, matching_score: float, matching_threshold: float + ) -> bool: + return (self.increasing and matching_score >= matching_threshold) or ( + self.decreasing and matching_score <= matching_threshold + ) class DirectValueMeta(EnumMeta): @@ -88,9 +94,19 @@ class Metric(_Enum_Compare): DSC = _Metric("DSC", "Dice", False, _compute_instance_volumetric_dice) IOU = _Metric("IOU", "Intersection over Union", False, _compute_instance_iou) - ASSD = _Metric("ASSD", "Average Symmetric Surface Distance", True, _compute_instance_average_symmetric_surface_distance) + ASSD = _Metric( + "ASSD", + "Average Symmetric Surface Distance", + True, + _compute_instance_average_symmetric_surface_distance, + ) clDSC = _Metric("clDSC", "Centerline Dice", False, _compute_centerline_dice) - RVD = _Metric("RVD", "Relative Volume Difference", True, _compute_instance_relative_volume_difference) + RVD = _Metric( + "RVD", + "Relative Volume Difference", + True, + _compute_instance_relative_volume_difference, + ) # ST = _Metric("ST", False, _compute_instance_segmentation_tendency) def __call__( @@ -122,7 +138,9 @@ def __call__( **kwargs, ) - def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: + def score_beats_threshold( + self, matching_score: float, matching_threshold: float + ) -> bool: """Calculates whether a score beats a specified threshold Args: @@ -132,7 +150,9 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float Returns: bool: True if the matching_score beats the threshold, False otherwise. """ - return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) + return (self.increasing and matching_score >= matching_threshold) or ( + self.decreasing and matching_score <= matching_threshold + ) @property def name(self): @@ -229,7 +249,9 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # ERROR if self._error: if self._error_obj is None: - self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed") + self._error_obj = MetricCouldNotBeComputedException( + f"Metric {self.id} requested, but could not be computed" + ) raise self._error_obj # Already calculated? 
if self._was_calculated: @@ -237,8 +259,12 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # Calculate it try: - assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated" - assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set" + assert ( + not self._was_calculated + ), f"Metric {self.id} was called to compute, but is set to have been already calculated" + assert ( + self._calc_func is not None + ), f"Metric {self.id} was called to compute, but has no calculation function set" value = self._calc_func(result_obj) except MetricCouldNotBeComputedException as e: value = e @@ -283,20 +309,32 @@ def __init__( else: self.AVG = None if self.ALL is None else np.average(self.ALL) self.SUM = None if self.ALL is None else np.sum(self.ALL) - self.MIN = None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) - self.MAX = None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) - - self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) + self.MIN = ( + None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) + ) + self.MAX = ( + None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) + ) + + self.STD = ( + None + if self.ALL is None + else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) + ) def __getitem__(self, mode: MetricMode | str): if self.error: - raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics") + raise MetricCouldNotBeComputedException( + f"Metric {self.id} has not been calculated, add it to your eval_metrics" + ) if isinstance(mode, MetricMode): mode = mode.name if hasattr(self, mode): return getattr(self, mode) else: - raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member") + raise MetricCouldNotBeComputedException( + f"List_Metric {self.id} does not contain {mode} member" + ) if __name__ == "__main__": diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index 2ffdebf..f0522f8 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -31,7 +31,12 @@ def __init__( instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, segmentation_class_groups: SegmentationClassGroups | None = None, - instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD], + instance_metrics: list[Metric] = [ + Metric.DSC, + Metric.IOU, + Metric.ASSD, + Metric.RVD, + ], global_metrics: list[Metric] = [Metric.DSC], decision_metric: Metric | None = None, decision_threshold: float | None = None, @@ -65,9 +70,13 @@ def __init__( self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + self.__edge_case_handler = ( + edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + ) if self.__decision_metric is not None: - assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" + assert ( + self.__decision_threshold is not None + ), "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -98,7 +107,9 @@ def evaluate( verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: processing_pair = 
self.__expected_input(prediction_arr, reference_arr) - assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" + assert isinstance( + processing_pair, self.__expected_input.value + ), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { @@ -118,8 +129,12 @@ def evaluate( ) } - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.prediction_arr, raise_error=True + ) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.reference_arr, raise_error=True + ) result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -131,7 +146,9 @@ def evaluate( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): + if single_instance_mode and not isinstance( + processing_pair, MatchedInstancePair + ): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -156,7 +173,9 @@ def evaluate( def panoptic_evaluate( - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + processing_pair: ( + SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult + ), instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -213,7 +232,9 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" + assert ( + instance_approximator is not None + ), "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -234,7 +255,9 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert ( + instance_matcher is not None + ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index e87a33b..ef631f9 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -249,7 +249,9 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result) + self._list_metrics[m] = Evaluation_List_Metric( + m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result + ) # even if not available, set the global vars self._add_metric( f"global_bin_{m.name.lower()}", @@ -327,13 +329,19 @@ def __str__(self) -> str: return text def 
to_dict(self) -> dict: - return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} + return { + k: getattr(self, v.id) + for k, v in self._evaluation_metrics.items() + if (v._error == False and v._was_calculated) + } def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") + raise MetricCouldNotBeComputedException( + f"{metric} could not be found, have you set it in eval_metrics during evaluation?" + ) def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -349,7 +357,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") + raise MetricCouldNotBeComputedException( + f"could not find metric with name {metric_name}" + ) def __getattribute__(self, __name: str) -> Any: attr = None @@ -362,7 +372,9 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") + raise MetricCouldNotBeComputedException( + f"Requested metric {__name} that could not be computed" + ) elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) @@ -488,7 +500,9 @@ def function_template(res: PanopticaResult): if metric not in res._global_metrics: raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") if res.tp == 0: - is_edgecase, result = res._edge_case_handler.handle_zero_tp(metric, res.tp, res.num_pred_instances, res.num_ref_instances) + is_edgecase, result = res._edge_case_handler.handle_zero_tp( + metric, res.tp, res.num_pred_instances, res.num_ref_instances + ) if is_edgecase: return result pred_binary = res._prediction_arr.copy() diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index d6a0e02..b88f49e 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -41,12 +41,26 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( + empty_prediction_result + if empty_prediction_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( + empty_reference_result + if empty_reference_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( + no_instances_result if no_instances_result is not None else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( + normal if normal is not None else 
default_result + ) - def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: + def __call__( + self, tp: int, num_pred_instances, num_ref_instances + ) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -117,7 +131,9 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[ + Metric, MetricZeroTPEdgeCaseHandling + ] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -144,7 +160,9 @@ def handle_zero_tp( if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") + raise NotImplementedError( + f"Metric {metric} encountered zero TP, but no edge handling available" + ) return self.__listmetric_zeroTP_handling[metric]( tp=tp, @@ -181,7 +199,9 @@ def _yaml_repr(cls, node) -> dict: print() # print(handler.get_metric_zero_tp_handle(ListMetric.IOU)) - r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1) + r = handler.handle_zero_tp( + Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1 + ) print(r) iou_test = MetricZeroTPEdgeCaseHandling( diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index b89d5ad..e0afa86 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -25,7 +25,9 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] n_dim: int - def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: + def __init__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None + ) -> None: """Initializes a general Processing Pair Args: @@ -38,8 +40,12 @@ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore - self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore + self._ref_labels: tuple[int, ...] = tuple( + _unique_without_zeros(reference_arr) + ) # type:ignore + self._pred_labels: tuple[int, ...] = tuple( + _unique_without_zeros(prediction_arr) + ) # type:ignore self.crop: tuple[slice, ...] = None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] 
= reference_arr.shape @@ -56,25 +62,41 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) + ( + print( + f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" + ) + if verbose + else None + ) self.is_cropped = True def uncrop_data(self, verbose: bool = False): if self.is_cropped == False: return - assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" + assert ( + self.uncropped_shape is not None + ), "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) + ( + print( + f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" + ) + if verbose + else None + ) self._reference_arr = reference_arr self.is_cropped = False def set_dtype(self, type): - assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" + assert np.issubdtype( + type, int_type + ), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -153,7 +175,9 @@ def copy(self): ) # type:ignore -def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): +def _check_array_integrity( + prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None +): """ Check the integrity of two numpy arrays. 
@@ -175,8 +199,12 @@ def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + assert ( + prediction_arr.shape == reference_arr.shape + ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + assert ( + prediction_arr.dtype == reference_arr.dtype + ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype) @@ -259,11 +287,15 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) + missed_reference_labels = list( + [i for i in self._ref_labels if i not in self._pred_labels] + ) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) + missed_prediction_labels = list( + [i for i in self._pred_labels if i not in self._ref_labels] + ) self.missed_prediction_labels = missed_prediction_labels @property @@ -300,5 +332,7 @@ class InputType(_Enum_Compare): UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + def __call__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray + ) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py index 883c604..200d084 100644 --- a/unit_tests/test_panoptic_result.py +++ b/unit_tests/test_panoptic_result.py @@ -86,7 +86,9 @@ def test_existing_metrics(self): def powerset(iterable): s = list(iterable) - return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))) + return list( + chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + ) power_set = powerset([Metric.DSC, Metric.IOU, Metric.ASSD]) for m in power_set[1:]: