diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index bcea628..8a38960 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -14,7 +14,7 @@ evaluator = Panoptica_Evaluator( expected_input=InputType.MATCHED_INSTANCE, - eval_metrics=[Metric.DSC, Metric.IOU], + instance_metrics=[Metric.DSC, Metric.IOU], segmentation_class_groups=SegmentationClassGroups( { "vertebra": LabelGroup([i for i in range(1, 10)]), diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py index 56c61e4..2855add 100644 --- a/examples/example_spine_instance_config.py +++ b/examples/example_spine_instance_config.py @@ -10,9 +10,7 @@ reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") -evaluator = Panoptica_Evaluator.load_from_config_name( - "panoptica_evaluator_unmatched_instance" -) +evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance") with cProfile.Profile() as pr: diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index a2e5a32..2793592 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -25,9 +25,7 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[ - "ungrouped" - ] + result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"] print(result) pr.dump_stats(directory + "/semantic_example.log") diff --git a/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml index ad8d9b0..55fbbc3 100644 --- a/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml +++ b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml @@ -19,7 +19,8 @@ edge_case_handler: !EdgeCaseHandler !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN, empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN, normal: !EdgeCaseResult NAN} -eval_metrics: [!Metric DSC, !Metric IOU] +instance_metrics: [!Metric DSC, !Metric IOU] +global_metrics: [!Metric DSC, !Metric RVD] expected_input: !InputType UNMATCHED_INSTANCE instance_approximator: null instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU, diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index 383df0e..9b784e7 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -1,11 +1,8 @@ from multiprocessing import Pool - import numpy as np from panoptica.metrics import Metric -from panoptica.panoptica_result import PanopticaResult -from panoptica.utils import EdgeCaseHandler -from panoptica.utils.processing_pair import MatchedInstancePair +from panoptica.utils.processing_pair import MatchedInstancePair, EvaluateInstancePair def evaluate_matched_instance( @@ -13,9 +10,8 @@ def evaluate_matched_instance( eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], decision_metric: Metric | None = Metric.IOU, decision_threshold: float | None = None, - edge_case_handler: EdgeCaseHandler | None = None, **kwargs, -) -> PanopticaResult: +) -> EvaluateInstancePair: """ Map instance labels based on the provided labelmap and create a MatchedInstancePair. 
@@ -31,12 +27,8 @@ def evaluate_matched_instance( >>> labelmap = [([1, 2], [3, 4]), ([5], [6])] >>> result = map_instance_labels(unmatched_instance_pair, labelmap) """ - if edge_case_handler is None: - edge_case_handler = EdgeCaseHandler() if decision_metric is not None: - assert decision_metric.name in [ - v.name for v in eval_metrics - ], "decision metric not contained in eval_metrics" + assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics" assert decision_threshold is not None, "decision metric set but no threshold" # Initialize variables for True Positives (tp) tp = len(matched_instance_pair.matched_instances) @@ -48,34 +40,25 @@ def evaluate_matched_instance( ) ref_matched_labels = matched_instance_pair.matched_instances - instance_pairs = [ - (reference_arr, prediction_arr, ref_idx, eval_metrics) - for ref_idx in ref_matched_labels - ] + instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels] with Pool() as pool: - metric_dicts: list[dict[Metric, float]] = pool.starmap( - _evaluate_instance, instance_pairs - ) + metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs) for metric_dict in metric_dicts: if decision_metric is None or ( - decision_threshold is not None - and decision_metric.score_beats_threshold( - metric_dict[decision_metric], decision_threshold - ) + decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold) ): for k, v in metric_dict.items(): score_dict[k].append(v) # Create and return the PanopticaResult object with computed metrics - return PanopticaResult( + return EvaluateInstancePair( reference_arr=matched_instance_pair.reference_arr, prediction_arr=matched_instance_pair.prediction_arr, num_pred_instances=matched_instance_pair.n_prediction_instance, num_ref_instances=matched_instance_pair.n_reference_instance, tp=tp, list_metrics=score_dict, - edge_case_handler=edge_case_handler, ) diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 2d6a4c0..4f02e48 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -23,6 +23,7 @@ class _Metric: """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array""" name: str + long_name: str decreasing: bool _metric_function: Callable @@ -39,9 +40,7 @@ def __call__( reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): pred_instance_idx = [pred_instance_idx] - prediction_arr = np.isin( - prediction_arr.copy(), pred_instance_idx - ) # type:ignore + prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) # type:ignore return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: @@ -65,12 +64,8 @@ def __hash__(self) -> int: def increasing(self): return not self.decreasing - def score_beats_threshold( - self, matching_score: float, matching_threshold: float - ) -> bool: - return (self.increasing and matching_score >= matching_threshold) or ( - self.decreasing and matching_score <= matching_threshold - ) + def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) class DirectValueMeta(EnumMeta): @@ -91,11 +86,11 @@ class 
Metric(_Enum_Compare): _type_: _description_ """ - DSC = _Metric("DSC", False, _compute_instance_volumetric_dice) - IOU = _Metric("IOU", False, _compute_instance_iou) - ASSD = _Metric("ASSD", True, _compute_instance_average_symmetric_surface_distance) - clDSC = _Metric("clDSC", False, _compute_centerline_dice) - RVD = _Metric("RVD", True, _compute_instance_relative_volume_difference) + DSC = _Metric("DSC", "Dice", False, _compute_instance_volumetric_dice) + IOU = _Metric("IOU", "Intersection over Union", False, _compute_instance_iou) + ASSD = _Metric("ASSD", "Average Symmetric Surface Distance", True, _compute_instance_average_symmetric_surface_distance) + clDSC = _Metric("clDSC", "Centerline Dice", False, _compute_centerline_dice) + RVD = _Metric("RVD", "Relative Volume Difference", True, _compute_instance_relative_volume_difference) # ST = _Metric("ST", False, _compute_instance_segmentation_tendency) def __call__( @@ -127,9 +122,7 @@ def __call__( **kwargs, ) - def score_beats_threshold( - self, matching_score: float, matching_threshold: float - ) -> bool: + def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: """Calculates whether a score beats a specified threshold Args: @@ -139,9 +132,7 @@ def score_beats_threshold( Returns: bool: True if the matching_score beats the threshold, False otherwise. """ - return (self.increasing and matching_score >= matching_threshold) or ( - self.decreasing and matching_score <= matching_threshold - ) + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) @property def name(self): @@ -238,9 +229,7 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # ERROR if self._error: if self._error_obj is None: - self._error_obj = MetricCouldNotBeComputedException( - f"Metric {self.id} requested, but could not be computed" - ) + self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed") raise self._error_obj # Already calculated? 
if self._was_calculated: @@ -248,12 +237,8 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # Calculate it try: - assert ( - not self._was_calculated - ), f"Metric {self.id} was called to compute, but is set to have been already calculated" - assert ( - self._calc_func is not None - ), f"Metric {self.id} was called to compute, but has no calculation function set" + assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated" + assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set" value = self._calc_func(result_obj) except MetricCouldNotBeComputedException as e: value = e @@ -298,32 +283,20 @@ def __init__( else: self.AVG = None if self.ALL is None else np.average(self.ALL) self.SUM = None if self.ALL is None else np.sum(self.ALL) - self.MIN = ( - None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) - ) - self.MAX = ( - None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) - ) - - self.STD = ( - None - if self.ALL is None - else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) - ) + self.MIN = None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) + self.MAX = None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) + + self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) def __getitem__(self, mode: MetricMode | str): if self.error: - raise MetricCouldNotBeComputedException( - f"Metric {self.id} has not been calculated, add it to your eval_metrics" - ) + raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics") if isinstance(mode, MetricMode): mode = mode.name if hasattr(self, mode): return getattr(self, mode) else: - raise MetricCouldNotBeComputedException( - f"List_Metric {self.id} does not contain {mode} member" - ) + raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member") if __name__ == "__main__": diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index fd5617c..2ffdebf 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -15,6 +15,7 @@ UnmatchedInstancePair, _ProcessingPair, InputType, + EvaluateInstancePair, ) import numpy as np from panoptica.utils.config import SupportsConfig @@ -30,7 +31,8 @@ def __init__( instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, segmentation_class_groups: SegmentationClassGroups | None = None, - eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD], + instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD], + global_metrics: list[Metric] = [Metric.DSC], decision_metric: Metric | None = None, decision_threshold: float | None = None, log_times: bool = False, @@ -43,24 +45,29 @@ def __init__( instance_approximator (InstanceApproximator | None, optional): Determines which instance approximator is used if necessary. Defaults to None. instance_matcher (InstanceMatchingAlgorithm | None, optional): Determines which instance matching algorithm is used if necessary. Defaults to None. iou_threshold (float, optional): Iou Threshold for evaluation. Defaults to 0.5. + edge_case_handler (edge_case_handler, optional): EdgeCaseHandler to be used. 
If None, the default EdgeCaseHandler is created. + segmentation_class_groups (SegmentationClassGroups, optional): If not None, evaluation is performed per defined class group instead of over all labels at once. + instance_metrics (list[Metric]): List of all metrics that should be calculated between all matched instances + global_metrics (list[Metric]): List of all metrics that should be calculated on the global binary masks + decision_metric (Metric | None, optional): The metric that is the final decision point between True Positive and False Positive. Can be omitted if the matching algorithm already matches by a metric and threshold. + decision_threshold (float | None, optional): Threshold for the decision_metric + log_times (bool): If True, prints the runtimes of the different phases of the pipeline. + verbose (bool): If True, prints more detailed output. """ self.__expected_input = expected_input # self.__instance_approximator = instance_approximator self.__instance_matcher = instance_matcher - self.__eval_metrics = eval_metrics + self.__eval_metrics = instance_metrics + self.__global_metrics = global_metrics self.__decision_metric = decision_metric self.__decision_threshold = decision_threshold self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = ( - edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() - ) + self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() if self.__decision_metric is not None: - assert ( - self.__decision_threshold is not None - ), "decision metric set but no decision threshold for it" + assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -74,6 +81,7 @@ def _yaml_repr(cls, node) -> dict: "edge_case_handler": node.__edge_case_handler, "segmentation_class_groups": node.__segmentation_class_groups, "eval_metrics": node.__eval_metrics, + "global_metrics": node.__global_metrics, "decision_metric": node.__decision_metric, "decision_threshold": node.__decision_threshold, "log_times": node.__log_times, @@ -90,9 +98,7 @@ def evaluate( self, prediction_arr: np.ndarray, reference_arr: np.ndarray, result_all: bool = True, verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: processing_pair = self.__expected_input(prediction_arr, reference_arr) - assert isinstance( - processing_pair, self.__expected_input.value - ), f"input not of expected type {self.__expected_input}" + assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { "ungrouped": panoptic_evaluate( processing_pair=processing_pair, edge_case_handler=self.__edge_case_handler, instance_approximator=self.__instance_approximator, instance_matcher=self.__instance_matcher, - eval_metrics=self.__eval_metrics, + instance_metrics=self.__eval_metrics, + global_metrics=self.__global_metrics, decision_metric=self.__decision_metric, decision_threshold=self.__decision_threshold, result_all=result_all, @@ -111,12 +118,8 @@ def evaluate( ) } - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.prediction_arr, raise_error=True - ) - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.reference_arr, raise_error=True - ) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True)
result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -128,9 +131,7 @@ def evaluate( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance( - processing_pair, MatchedInstancePair - ): + if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -142,7 +143,8 @@ def evaluate( edge_case_handler=self.__edge_case_handler, instance_approximator=self.__instance_approximator, instance_matcher=self.__instance_matcher, - eval_metrics=self.__eval_metrics, + instance_metrics=self.__eval_metrics, + global_metrics=self.__global_metrics, decision_metric=self.__decision_metric, decision_threshold=decision_threshold, result_all=result_all, @@ -154,12 +156,11 @@ def evaluate( def panoptic_evaluate( - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, - eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + global_metrics: list[Metric] = [Metric.DSC], decision_metric: Metric | None = None, decision_threshold: float | None = None, edge_case_handler: EdgeCaseHandler | None = None, @@ -212,9 +213,7 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert ( - instance_approximator is not None - ), "Got SemanticPair but not InstanceApproximator" + assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -227,16 +226,15 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): processing_pair = _handle_zero_instances_cases( processing_pair, - eval_metrics=eval_metrics, + eval_metrics=instance_metrics, + global_metrics=global_metrics, edge_case_handler=edge_case_handler, ) if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert ( - instance_matcher is not None - ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, @@ -250,7 +248,8 @@ def panoptic_evaluate( if isinstance(processing_pair, MatchedInstancePair): processing_pair = _handle_zero_instances_cases( processing_pair, - eval_metrics=eval_metrics, + eval_metrics=instance_metrics, + global_metrics=global_metrics, edge_case_handler=edge_case_handler, ) @@ -260,15 +259,27 @@ def panoptic_evaluate( start = perf_counter() processing_pair = evaluate_matched_instance( processing_pair, - eval_metrics=eval_metrics, + eval_metrics=instance_metrics, decision_metric=decision_metric, decision_threshold=decision_threshold, - edge_case_handler=edge_case_handler, ) if log_times: print(f"-- Instance 
Evaluation took {perf_counter() - start} seconds") + if isinstance(processing_pair, EvaluateInstancePair): + processing_pair = PanopticaResult( + reference_arr=processing_pair.reference_arr, + prediction_arr=processing_pair.prediction_arr, + num_pred_instances=processing_pair.num_pred_instances, + num_ref_instances=processing_pair.num_ref_instances, + tp=processing_pair.tp, + list_metrics=processing_pair.list_metrics, + global_metrics=global_metrics, + edge_case_handler=edge_case_handler, + ) + if isinstance(processing_pair, PanopticaResult): + processing_pair._global_metrics = global_metrics if result_all: processing_pair.calculate_all(print_errors=verbose_calc) return processing_pair, debug_data @@ -279,6 +290,7 @@ def panoptic_evaluate( def _handle_zero_instances_cases( processing_pair: UnmatchedInstancePair | MatchedInstancePair, edge_case_handler: EdgeCaseHandler, + global_metrics: list[Metric], eval_metrics: list[_Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], ) -> UnmatchedInstancePair | MatchedInstancePair | PanopticaResult: """ @@ -322,6 +334,7 @@ def _handle_zero_instances_cases( is_edge_case = True if is_edge_case: + panoptica_result_args["global_metrics"] = global_metrics panoptica_result_args["num_ref_instances"] = n_reference_instance panoptica_result_args["num_pred_instances"] = n_prediction_instance return PanopticaResult(**panoptica_result_args) diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index b831e04..e87a33b 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -24,12 +24,12 @@ def __init__( self, reference_arr: np.ndarray, prediction_arr: np.ndarray, - # TODO some metadata object containing dtype, voxel spacing, ... num_pred_instances: int, num_ref_instances: int, tp: int, list_metrics: dict[Metric, list[float]], edge_case_handler: EdgeCaseHandler, + global_metrics: list[Metric] = [], ): """Result object for Panoptica, contains all calculatable metrics @@ -46,6 +46,7 @@ def __init__( empty_list_std = self._edge_case_handler.handle_empty_list_std().value self._prediction_arr = prediction_arr self._reference_arr = reference_arr + self._global_metrics: list[Metric] = global_metrics ###################### # Evaluation Metrics # ###################### @@ -119,39 +120,6 @@ def __init__( ) # endregion # - # region Global - self.global_bin_dsc: int - self._add_metric( - "global_bin_dsc", - MetricType.GLOBAL, - global_bin_dsc, - long_name="Global Binary Dice", - ) - # - self.global_bin_cldsc: int - self._add_metric( - "global_bin_cldsc", - MetricType.GLOBAL, - global_bin_cldsc, - long_name="Global Binary Centerline Dice", - ) - # - self.global_bin_assd: int - self._add_metric( - "global_bin_assd", - MetricType.GLOBAL, - global_bin_assd, - long_name="Global Binary Average Symmetric Surface Distance", - ) - # - self.global_bin_rvd: int - self._add_metric( - "global_bin_rvd", - MetricType.GLOBAL, - global_bin_rvd, - long_name="Global Binary Relative Volume Difference", - ) - # endregion # # region IOU self.sq: float @@ -259,19 +227,35 @@ def __init__( ) # endregion + # region Global + # Just for autocomplete + self.global_bin_dsc: int + self.global_bin_iou: int + self.global_bin_cldsc: int + self.global_bin_assd: int + self.global_bin_rvd: int + # endregion + ################## # List Metrics # ################## self._list_metrics: dict[Metric, Evaluation_List_Metric] = {} - for k, v in list_metrics.items(): - is_edge_case, edge_case_result = self._edge_case_handler.handle_zero_tp( - metric=k, - tp=self.tp, - 
num_pred_instances=self.num_pred_instances, - num_ref_instances=self.num_ref_instances, - ) - self._list_metrics[k] = Evaluation_List_Metric( - k, empty_list_std, v, is_edge_case, edge_case_result + # Loop over all available metric, add it to evaluation_list_metric if available, but also add the global references + for m in Metric: + if m in list_metrics: + is_edge_case, edge_case_result = self._edge_case_handler.handle_zero_tp( + metric=m, + tp=self.tp, + num_pred_instances=self.num_pred_instances, + num_ref_instances=self.num_ref_instances, + ) + self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result) + # even if not available, set the global vars + self._add_metric( + f"global_bin_{m.name.lower()}", + MetricType.GLOBAL, + _build_global_bin_metric_function(m), + long_name="Global Binary " + m.value.long_name, ) def _add_metric( @@ -321,6 +305,8 @@ def __str__(self) -> str: for metric_type in MetricType: if metric_type == MetricType.NO_PRINT: continue + if metric_type == MetricType.GLOBAL and len(self._global_metrics) == 0: + continue text += f"\n+++ {metric_type.name} +++\n" for k, v in self._evaluation_metrics.items(): if v.metric_type != metric_type: @@ -341,19 +327,13 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: - return { - k: getattr(self, v.id) - for k, v in self._evaluation_metrics.items() - if (v._error == False and v._was_calculated) - } + return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException( - f"{metric} could not be found, have you set it in eval_metrics during evaluation?" 
- ) + raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -369,9 +349,7 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException( - f"could not find metric with name {metric_name}" - ) + raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") def __getattribute__(self, __name: str) -> Any: attr = None @@ -384,9 +362,7 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException( - f"Requested metric {__name} that could not be computed" - ) + raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) @@ -506,45 +482,25 @@ def sq_rvd_std(res: PanopticaResult): # endregion -# region Global -def global_bin_dsc(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _compute_dice_coefficient(ref_binary, pred_binary) - - -def global_bin_cldsc(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _compute_centerline_dice_coefficient(ref_binary, pred_binary) - - -def global_bin_assd(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _average_symmetric_surface_distance(ref_binary, pred_binary) - +def _build_global_bin_metric_function(metric: Metric): + + def function_template(res: PanopticaResult): + if metric not in res._global_metrics: + raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") + if res.tp == 0: + is_edgecase, result = res._edge_case_handler.handle_zero_tp(metric, res.tp, res.num_pred_instances, res.num_ref_instances) + if is_edgecase: + return result + pred_binary = res._prediction_arr.copy() + ref_binary = res._reference_arr.copy() + pred_binary[pred_binary != 0] = 1 + ref_binary[ref_binary != 0] = 1 + return metric( + reference_arr=ref_binary, + prediction_arr=pred_binary, ) -def global_bin_rvd(res: PanopticaResult): - if res.tp == 0: - return 0.0 - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return _compute_relative_volume_difference(ref_binary, pred_binary) + return function_template # endregion diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index 1a865f0..d6a0e02 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -41,26 +41,12 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( - empty_prediction_result - if empty_prediction_result
is not None - else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( - empty_reference_result - if empty_reference_result is not None - else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( - no_instances_result if no_instances_result is not None else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( - normal if normal is not None else default_result - ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result - def __call__( - self, tp: int, num_pred_instances, num_ref_instances - ) -> tuple[bool, float | None]: + def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -131,9 +117,7 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[ - Metric, MetricZeroTPEdgeCaseHandling - ] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -143,12 +127,24 @@ def handle_zero_tp( num_pred_instances: int, num_ref_instances: int, ) -> tuple[bool, float | None]: + """_summary_ + + Args: + metric (Metric): _description_ + tp (int): _description_ + num_pred_instances (int): _description_ + num_ref_instances (int): _description_ + + Raises: + NotImplementedError: _description_ + + Returns: + tuple[bool, float | None]: if edge case, and its edge case value + """ if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError( - f"Metric {metric} encountered zero TP, but no edge handling available" - ) + raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") return self.__listmetric_zeroTP_handling[metric]( tp=tp, @@ -185,9 +181,7 @@ def _yaml_repr(cls, node) -> dict: print() # print(handler.get_metric_zero_tp_handle(ListMetric.IOU)) - r = handler.handle_zero_tp( - Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1 - ) + r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1) print(r) iou_test = MetricZeroTPEdgeCaseHandling( diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index c64b1e9..b89d5ad 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -5,6 +5,8 @@ from panoptica._functionals import _get_paired_crop from panoptica.utils import _count_unique_without_zeros, _unique_without_zeros from panoptica.utils.constants import _Enum_Compare +from dataclasses import dataclass +from panoptica.metrics import Metric uint_type: type = np.unsignedinteger int_type: type = np.integer @@ -23,9 +25,7 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] 
n_dim: int - def __init__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None - ) -> None: + def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: """Initializes a general Processing Pair Args: @@ -38,12 +38,8 @@ def __init__( self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple( - _unique_without_zeros(reference_arr) - ) # type:ignore - self._pred_labels: tuple[int, ...] = tuple( - _unique_without_zeros(prediction_arr) - ) # type:ignore + self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore + self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore self.crop: tuple[slice, ...] = None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] = reference_arr.shape @@ -60,41 +56,25 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - ( - print( - f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" - ) - if verbose - else None - ) + (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) self.is_cropped = True def uncrop_data(self, verbose: bool = False): if self.is_cropped == False: return - assert ( - self.uncropped_shape is not None - ), "Calling uncrop_data() without having cropped first" + assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - ( - print( - f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" - ) - if verbose - else None - ) + (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) self._reference_arr = reference_arr self.is_cropped = False def set_dtype(self, type): - assert np.issubdtype( - type, int_type - ), "set_dtype: tried to set dtype to something other than integers" + assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -173,9 +153,7 @@ def copy(self): ) # type:ignore -def _check_array_integrity( - prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None -): +def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): """ Check the integrity of two numpy arrays. 
@@ -197,12 +175,8 @@ def _check_array_integrity( assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert ( - prediction_arr.shape == reference_arr.shape - ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert ( - prediction_arr.dtype == reference_arr.dtype - ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype) @@ -285,15 +259,11 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list( - [i for i in self._ref_labels if i not in self._pred_labels] - ) + missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list( - [i for i in self._pred_labels if i not in self._ref_labels] - ) + missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) self.missed_prediction_labels = missed_prediction_labels @property @@ -315,12 +285,20 @@ def copy(self): ) +@dataclass +class EvaluateInstancePair: + reference_arr: np.ndarray + prediction_arr: np.ndarray + num_pred_instances: int + num_ref_instances: int + tp: int + list_metrics: dict[Metric, list[float]] + + class InputType(_Enum_Compare): SEMANTIC = SemanticPair UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray - ) -> _ProcessingPair: + def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index 5fe9c0e..a3009ae 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -70,7 +70,7 @@ def test_simple_evaluation_DSC_partial(self): expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC), - eval_metrics=[Metric.DSC], + instance_metrics=[Metric.DSC], ) result, debug_data = evaluator.evaluate(b, a)["ungrouped"] diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py index f5087ea..883c604 100644 --- a/unit_tests/test_panoptic_result.py +++ b/unit_tests/test_panoptic_result.py @@ -28,6 +28,7 @@ def test_simple_evaluation(self): num_pred_instances=5, tp=0, list_metrics={Metric.IOU: []}, + global_metrics=[Metric.DSC], edge_case_handler=EdgeCaseHandler(), ) c.calculate_all(print_errors=True) @@ -85,9 +86,7 @@ def test_existing_metrics(self): def powerset(iterable): s = list(iterable) - return list( - chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) - ) + return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))) power_set = powerset([Metric.DSC, Metric.IOU, Metric.ASSD]) for m in power_set[1:]:
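Taken together, the changes above rename the evaluator's per-instance metric list to `instance_metrics` and add a separate `global_metrics` list that is computed on the binarized global masks. A minimal usage sketch of the updated API, assuming the module paths of the files touched in this diff and toy input masks (both are placeholders, not part of the change):

```python
# Usage sketch for the renamed/added evaluator arguments (toy arrays; import paths assumed from this diff).
import numpy as np

from panoptica.metrics import Metric
from panoptica.panoptica_evaluator import Panoptica_Evaluator
from panoptica.utils.processing_pair import InputType

# Two matched instance masks: instance 1 overlaps well, instance 2 only partially.
reference_mask = np.zeros((32, 32), dtype=np.uint8)
prediction_mask = np.zeros((32, 32), dtype=np.uint8)
reference_mask[2:10, 2:10] = 1
prediction_mask[3:10, 3:10] = 1
reference_mask[20:30, 20:30] = 2
prediction_mask[24:30, 24:30] = 2

evaluator = Panoptica_Evaluator(
    expected_input=InputType.MATCHED_INSTANCE,
    instance_metrics=[Metric.DSC, Metric.IOU],  # formerly `eval_metrics`
    global_metrics=[Metric.DSC, Metric.RVD],    # new: computed on the binarized global masks
)

if __name__ == "__main__":
    # Without segmentation_class_groups, results are keyed under "ungrouped".
    result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
    print(result)
    # Global binary metrics are only available for the metrics listed in global_metrics.
    print(result.global_bin_dsc)
```

When `segmentation_class_groups` is passed, `evaluate()` instead returns one `(result, debug_data)` tuple per group name rather than the single `"ungrouped"` entry.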
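The `_Metric` container also gains a `long_name` field, which `PanopticaResult` uses to label the dynamically registered `global_bin_*` metrics ("Global Binary " + `long_name`). A small sketch of the new field and the existing threshold helper; the expected outputs in the comments are inferred from the definitions above:

```python
from panoptica.metrics import Metric

# long_name is the new second constructor argument of _Metric.
print(Metric.DSC.value.long_name)   # Dice
print(Metric.ASSD.value.long_name)  # Average Symmetric Surface Distance

# Increasing metrics (decreasing=False) beat a threshold with scores >= it,
# decreasing metrics such as ASSD with scores <= it.
print(Metric.IOU.score_beats_threshold(0.6, 0.5))   # True
print(Metric.ASSD.score_beats_threshold(0.6, 0.5))  # False
```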