diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py
index 7421a6d..028e19f 100644
--- a/panoptica/instance_evaluator.py
+++ b/panoptica/instance_evaluator.py
@@ -33,7 +33,9 @@ def evaluate_matched_instance(
     if edge_case_handler is None:
         edge_case_handler = EdgeCaseHandler()
     if decision_metric is not None:
-        assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics"
+        assert decision_metric.name in [
+            v.name for v in eval_metrics
+        ], "decision metric not contained in eval_metrics"
         assert decision_threshold is not None, "decision metric set but no threshold"
     # Initialize variables for True Positives (tp)
     tp = len(matched_instance_pair.matched_instances)
@@ -45,14 +47,21 @@ def evaluate_matched_instance(
     )
     ref_matched_labels = matched_instance_pair.matched_instances

-    instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels]
+    instance_pairs = [
+        (reference_arr, prediction_arr, ref_idx, eval_metrics)
+        for ref_idx in ref_matched_labels
+    ]
     with Pool() as pool:
-        metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs)
+        metric_dicts: list[dict[Metric, float]] = pool.starmap(
+            _evaluate_instance, instance_pairs
+        )

     for metric_dict in metric_dicts:
         if decision_metric is None or (
             decision_threshold is not None
-            and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold)
+            and decision_metric.score_beats_threshold(
+                metric_dict[decision_metric], decision_threshold
+            )
         ):
             for k, v in metric_dict.items():
                 score_dict[k].append(v)
diff --git a/panoptica/metrics/__init__.py b/panoptica/metrics/__init__.py
index 373bfb3..1ebfd9c 100644
--- a/panoptica/metrics/__init__.py
+++ b/panoptica/metrics/__init__.py
@@ -7,7 +7,7 @@
     _compute_instance_volumetric_dice,
 )
 from panoptica.metrics.iou import (
-    _compute_instance_iou, 
+    _compute_instance_iou,
     _compute_iou,
 )
 from panoptica.metrics.cldice import (
diff --git a/panoptica/metrics/cldice.py b/panoptica/metrics/cldice.py
index 9443627..3924751 100644
--- a/panoptica/metrics/cldice.py
+++ b/panoptica/metrics/cldice.py
@@ -16,10 +16,10 @@ def cl_score(volume: np.ndarray, skeleton: np.ndarray):


 def _compute_centerline_dice(
-    ref_labels: np.ndarray, 
-    pred_labels: np.ndarray, 
-    ref_instance_idx: int, 
-    pred_instance_idx: int, 
+    ref_labels: np.ndarray,
+    pred_labels: np.ndarray,
+    ref_instance_idx: int,
+    pred_instance_idx: int,
 ) -> float:
     """Compute the centerline Dice (clDice) coefficient between a specific pair of instances.

@@ -38,7 +38,6 @@ def _compute_centerline_dice(
         reference=ref_instance_mask,
         prediction=pred_instance_mask,
     )
-


 def _compute_centerline_dice_coefficient(
@@ -49,10 +48,10 @@ def _compute_centerline_dice_coefficient(
     ndim = reference.ndim
     assert 2 <= ndim <= 3, "clDice only implemented for 2D or 3D"
     if ndim == 2:
-        tprec = cl_score(prediction,skeletonize(reference))
-        tsens = cl_score(reference,skeletonize(prediction))
+        tprec = cl_score(prediction, skeletonize(reference))
+        tsens = cl_score(reference, skeletonize(prediction))
     elif ndim == 3:
-        tprec = cl_score(prediction,skeletonize_3d(reference))
-        tsens = cl_score(reference,skeletonize_3d(prediction))
+        tprec = cl_score(prediction, skeletonize_3d(reference))
+        tsens = cl_score(reference, skeletonize_3d(prediction))

-    return 2 * tprec * tsens / (tprec + tsens)
\ No newline at end of file
+    return 2 * tprec * tsens / (tprec + tsens)
diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py
index b1133e4..3b2b46f 100644
--- a/panoptica/metrics/metrics.py
+++ b/panoptica/metrics/metrics.py
@@ -15,8 +15,8 @@

 @dataclass
 class _Metric:
-    """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array
-    """
+    """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array"""
+
     name: str
     decreasing: bool
     _metric_function: Callable
@@ -34,7 +34,9 @@ def __call__(
         reference_arr = reference_arr.copy() == ref_instance_idx
         if isinstance(pred_instance_idx, int):
             pred_instance_idx = [pred_instance_idx]
-        prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) #type:ignore
+        prediction_arr = np.isin(
+            prediction_arr.copy(), pred_instance_idx
+        )  # type:ignore
         return self._metric_function(reference_arr, prediction_arr, *args, **kwargs)

     def __eq__(self, __value: object) -> bool:
@@ -50,7 +52,7 @@ def __str__(self) -> str:

     def __repr__(self) -> str:
         return str(self)
-    
+
     def __hash__(self) -> int:
         return abs(hash(self.name)) % (10**8)

@@ -58,26 +60,32 @@ def __hash__(self) -> int:
     @property
     def increasing(self):
         return not self.decreasing

-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )

-class DirectValueMeta(EnumMeta): 
+class DirectValueMeta(EnumMeta):
     "Metaclass that allows for directly getting an enum attribute"
+
     def __getattribute__(cls, name) -> _Metric:
         value = super().__getattribute__(name)
         if isinstance(value, cls):
             value = value.value
         return value
-    
+

 class Metric(_Enum_Compare):
     """Enum containing important metrics that must be calculated in the evaluator, can be set for thresholding in matching and evaluation
     Never call the .value member here, use the properties directly
-    
+
     Returns:
         _type_: _description_
     """
+
     DSC = _Metric("DSC", False, _compute_dice_coefficient)
     IOU = _Metric("IOU", False, _compute_iou)
     ASSD = _Metric("ASSD", True, _average_symmetric_surface_distance)
@@ -103,9 +111,18 @@ def __call__(
         Returns:
             int | float: The metric value
         """
-        return self.value(reference_arr=reference_arr, prediction_arr=prediction_arr, ref_instance_idx=ref_instance_idx, pred_instance_idx=pred_instance_idx, *args, **kwargs,)
-    
-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
+        return self.value(
+            reference_arr=reference_arr,
+            prediction_arr=prediction_arr,
+            ref_instance_idx=ref_instance_idx,
+            pred_instance_idx=pred_instance_idx,
+            *args,
+            **kwargs,
+        )
+
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
         """Calculates whether a score beats a specified threshold

         Args:
@@ -115,12 +132,14 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float
         Returns:
             bool: True if the matching_score beats the threshold, False otherwise.
         """
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )

     @property
     def name(self):
         return self.value.name
-    
+
     @property
     def decreasing(self):
         return self.value.decreasing
@@ -139,6 +158,7 @@ class MetricMode(_Enum_Compare):
     Args:
         _Enum_Compare (_type_): _description_
     """
+
     ALL = auto()
     AVG = auto()
     SUM = auto()
@@ -154,4 +174,4 @@ class MetricMode(_Enum_Compare):
     print(Metric.DSC.name == "DSC")

     # print(Metric.DSC == Metric.IOU)
-    print(Metric.DSC == "IOU")
\ No newline at end of file
+    print(Metric.DSC == "IOU")
diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py
index d562de6..1a10b3c 100644
--- a/panoptica/panoptic_evaluator.py
+++ b/panoptica/panoptic_evaluator.py
@@ -21,7 +21,9 @@ class Panoptic_Evaluator:

     def __init__(
         self,
-        expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair,
+        expected_input: Type[SemanticPair]
+        | Type[UnmatchedInstancePair]
+        | Type[MatchedInstancePair] = MatchedInstancePair,
         instance_approximator: InstanceApproximator | None = None,
         instance_matcher: InstanceMatchingAlgorithm | None = None,
         edge_case_handler: EdgeCaseHandler | None = None,
@@ -47,9 +49,13 @@ def __init__(
         self.__decision_metric = decision_metric
         self.__decision_threshold = decision_threshold

-        self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        self.__edge_case_handler = (
+            edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        )
         if self.__decision_metric is not None:
-            assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it"
+            assert (
+                self.__decision_threshold is not None
+            ), "decision metric set but no decision threshold for it"

         # self.__log_times = log_times
         self.__verbose = verbose
@@ -58,11 +64,16 @@ def __init__(
     @measure_time
     def evaluate(
         self,
-        processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
+        processing_pair: SemanticPair
+        | UnmatchedInstancePair
+        | MatchedInstancePair
+        | PanopticaResult,
         result_all: bool = True,
         verbose: bool | None = None,
     ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
-        assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}"
+        assert (
+            type(processing_pair) == self.__expected_input
+        ), f"input not of expected type {self.__expected_input}"
         return panoptic_evaluate(
             processing_pair=processing_pair,
             edge_case_handler=self.__edge_case_handler,
@@ -78,7 +89,10 @@ def evaluate(


 def panoptic_evaluate(
-    processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
+    processing_pair: SemanticPair
+    | UnmatchedInstancePair
+    | MatchedInstancePair
+    | PanopticaResult,
     instance_approximator: InstanceApproximator | None = None,
     instance_matcher: InstanceMatchingAlgorithm | None = None,
     eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
@@ -131,7 +145,9 @@ def panoptic_evaluate(
     processing_pair.crop_data()

     if isinstance(processing_pair, SemanticPair):
-        assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator"
+        assert (
+            instance_approximator is not None
+        ), "Got SemanticPair but not InstanceApproximator"
         print("-- Got SemanticPair, will approximate instances")
         processing_pair = instance_approximator.approximate_instances(processing_pair)
         start = perf_counter()
@@ -142,11 +158,17 @@ def panoptic_evaluate(

     # Second Phase: Instance Matching
     if isinstance(processing_pair, UnmatchedInstancePair):
-        processing_pair = _handle_zero_instances_cases(processing_pair, eval_metrics=eval_metrics, edge_case_handler=edge_case_handler)
+        processing_pair = _handle_zero_instances_cases(
+            processing_pair,
+            eval_metrics=eval_metrics,
+            edge_case_handler=edge_case_handler,
+        )

     if isinstance(processing_pair, UnmatchedInstancePair):
         print("-- Got UnmatchedInstancePair, will match instances")
-        assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
+        assert (
+            instance_matcher is not None
+        ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
         start = perf_counter()
         processing_pair = instance_matcher.match_instances(
             processing_pair,
@@ -158,7 +180,11 @@ def panoptic_evaluate(

     # Third Phase: Instance Evaluation
     if isinstance(processing_pair, MatchedInstancePair):
-        processing_pair = _handle_zero_instances_cases(processing_pair, eval_metrics=eval_metrics, edge_case_handler=edge_case_handler)
+        processing_pair = _handle_zero_instances_cases(
+            processing_pair,
+            eval_metrics=eval_metrics,
+            edge_case_handler=edge_case_handler,
+        )

     if isinstance(processing_pair, MatchedInstancePair):
         print("-- Got MatchedInstancePair, will evaluate instances")
@@ -211,23 +237,23 @@ def _handle_zero_instances_cases(
     # Handle cases where either the reference or the prediction is empty
     if n_prediction_instance == 0 and n_reference_instance == 0:
         # Both references and predictions are empty, perfect match
-        n_reference_instance=0
-        n_prediction_instance=0
-        is_edge_case=True
+        n_reference_instance = 0
+        n_prediction_instance = 0
+        is_edge_case = True
     elif n_reference_instance == 0:
         # All references are missing, only false positives
-        n_reference_instance=0
-        n_prediction_instance=n_prediction_instance
-        is_edge_case=True
+        n_reference_instance = 0
+        n_prediction_instance = n_prediction_instance
+        is_edge_case = True
     elif n_prediction_instance == 0:
         # All predictions are missing, only false negatives
-        n_reference_instance=n_reference_instance
-        n_prediction_instance=0
-        is_edge_case=True
-    
+        n_reference_instance = n_reference_instance
+        n_prediction_instance = 0
+        is_edge_case = True
+
     if is_edge_case:
         panoptica_result_args["num_ref_instances"] = n_reference_instance
         panoptica_result_args["num_pred_instances"] = n_prediction_instance
         return PanopticaResult(**panoptica_result_args)
-    
+
     return processing_pair
diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py
index a0e15dd..7493fd5 100644
--- a/panoptica/panoptic_result.py
+++ b/panoptica/panoptic_result.py
@@ -3,14 +3,17 @@ from typing import Any, Callable
 import numpy as np

 from panoptica.metrics import MetricMode, Metric
-from panoptica.metrics import _compute_dice_coefficient, _compute_centerline_dice_coefficient
+from panoptica.metrics import (
+    _compute_dice_coefficient,
+    _compute_centerline_dice_coefficient,
+)
 from panoptica.utils import EdgeCaseHandler
 from panoptica.utils.processing_pair import MatchedInstancePair


 class MetricCouldNotBeComputedException(Exception):
-    """Exception for when a Metric cannot be computed
-    """
+    """Exception for when a Metric cannot be computed"""
+
     def __init__(self, *args: object) -> None:
         super().__init__(*args)

@@ -43,11 +46,17 @@ def __init__(
     def __call__(self, result_obj: PanopticaResult) -> Any:
         if self.error:
             if self.error_obj is None:
-                raise MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed")
+                raise MetricCouldNotBeComputedException(
+                    f"Metric {self.id} requested, but could not be computed"
+                )
             else:
                 raise self.error_obj
-        assert not self.was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated"
-        assert self.calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set"
+        assert (
+            not self.was_calculated
+        ), f"Metric {self.id} was called to compute, but is set to have been already calculated"
+        assert (
+            self.calc_func is not None
+        ), f"Metric {self.id} was called to compute, but has no calculation function set"
         try:
             value = self.calc_func(result_obj)
         except MetricCouldNotBeComputedException as e:
@@ -88,17 +97,27 @@ def __init__(
         else:
             self.AVG = None if self.ALL is None else np.average(self.ALL)
             self.SUM = None if self.ALL is None else np.sum(self.ALL)
-            self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
+            self.STD = (
+                None
+                if self.ALL is None
+                else empty_list_std
+                if len(self.ALL) == 0
+                else np.std(self.ALL)
+            )

     def __getitem__(self, mode: MetricMode | str):
         if self.error:
-            raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics")
+            raise MetricCouldNotBeComputedException(
+                f"Metric {self.id} has not been calculated, add it to your eval_metrics"
+            )
         if isinstance(mode, MetricMode):
             mode = mode.name
         if hasattr(self, mode):
             return getattr(self, mode)
         else:
-            raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member")
+            raise MetricCouldNotBeComputedException(
+                f"List_Metric {self.id} does not contain {mode} member"
+            )


 class PanopticaResult(object):
@@ -133,7 +152,7 @@ def __init__(
         ######################
         self._evaluation_metrics: dict[str, Evaluation_Metric] = {}
         #
-        #region Already Calculated
+        # region Already Calculated
         self.num_ref_instances: int
         self._add_metric(
             "num_ref_instances",
@@ -158,9 +177,9 @@ def __init__(
             default_value=tp,
             was_calculated=True,
         )
-        #endregion
-        # 
-        #region Basic
+        # endregion
+        #
+        # region Basic
         self.fp: int
         self._add_metric(
             "fp",
@@ -179,9 +198,9 @@ def __init__(
             rq,
             long_name="Recognition Quality",
         )
-        #endregion
-        # 
-        #region Global
+        # endregion
+        #
+        # region Global
         self.global_bin_dsc: int
         self._add_metric(
             "global_bin_dsc",
@@ -195,9 +214,9 @@ def __init__(
             global_bin_cldsc,
             long_name="Global Binary Centerline Dice",
         )
-        #endregion
+        # endregion
         #
-        #region IOU
+        # region IOU
         self.sq: float
         self._add_metric(
             "sq",
@@ -216,9 +235,9 @@ def __init__(
             pq,
             long_name="Panoptic Quality IoU",
         )
-        #endregion
+        # endregion
         #
-        #region DICE
+        # region DICE
         self.sq_dsc: float
         self._add_metric(
             "sq_dsc",
@@ -237,9 +256,9 @@ def __init__(
             pq_dsc,
             long_name="Panoptic Quality Dsc",
         )
-        #endregion
+        # endregion
         #
-        #region clDICE
+        # region clDICE
         self.sq_cldsc: float
         self._add_metric(
             "sq_cldsc",
@@ -258,9 +277,9 @@ def __init__(
             pq_cldsc,
             long_name="Panoptic Quality Centerline Dsc",
         )
-        #endregion
-        # 
-        #region ASSD
+        # endregion
+        #
+        # region ASSD
         self.sq_assd: float
         self._add_metric(
             "sq_assd",
@@ -273,7 +292,7 @@ def __init__(
             sq_assd_std,
             long_name="Segmentation Quality Assd Standard Deviation",
         )
-        #endregion
+        # endregion

         ##################
         # List Metrics  #
@@ -286,7 +305,9 @@ def __init__(
                 num_pred_instances=self.num_pred_instances,
                 num_ref_instances=self.num_ref_instances,
             )
-            self._list_metrics[k] = Evaluation_List_Metric(k, empty_list_std, v, is_edge_case, edge_case_result)
+            self._list_metrics[k] = Evaluation_List_Metric(
+                k, empty_list_std, v, is_edge_case, edge_case_result
+            )

     def _add_metric(
         self,
@@ -348,7 +369,9 @@ def get_list_metric(self, metric: Metric, mode: MetricMode):
         if metric in self._list_metrics:
             return self._list_metrics[metric][mode]
         else:
-            raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?")
+            raise MetricCouldNotBeComputedException(
+                f"{metric} could not be found, have you set it in eval_metrics during evaluation?"
+            )

     def _calc_metric(self, metric_name: str, supress_error: bool = False):
         if metric_name in self._evaluation_metrics:
@@ -364,7 +387,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False):
             self._evaluation_metrics[metric_name].was_calculated = True
             return value
         else:
-            raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}")
+            raise MetricCouldNotBeComputedException(
+                f"could not find metric with name {metric_name}"
+            )

     def __getattribute__(self, __name: str) -> Any:
         attr = None
@@ -377,7 +402,9 @@ def __getattribute__(self, __name: str) -> Any:
                 raise e
         if attr is None:
             if self._evaluation_metrics[__name].error:
-                raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed")
+                raise MetricCouldNotBeComputedException(
+                    f"Requested metric {__name} that could not be computed"
+                )
             elif not self._evaluation_metrics[__name].was_calculated:
                 value = self._calc_metric(__name)
                 setattr(self, __name, value)
@@ -392,7 +419,8 @@ def __getattribute__(self, __name: str) -> Any:
 # Calculation functions #
 #########################

-#region Basic
+
+# region Basic
 def fp(res: PanopticaResult):
     return res.num_pred_instances - res.tp

@@ -411,9 +439,12 @@ def rq(res: PanopticaResult):
     if res.tp == 0:
         return 0.0 if res.num_pred_instances + res.num_ref_instances > 0 else np.nan
     return res.tp / (res.tp + 0.5 * res.fp + 0.5 * res.fn)
-#endregion
-#region IOU
+
+# endregion
+
+
+# region IOU
 def sq(res: PanopticaResult):
     return res.get_list_metric(Metric.IOU, mode=MetricMode.AVG)

@@ -424,9 +455,12 @@ def sq_std(res: PanopticaResult):

 def pq(res: PanopticaResult):
     return res.sq * res.rq
-#endregion
-#region DSC
+
+# endregion
+
+
+# region DSC
 def sq_dsc(res: PanopticaResult):
     return res.get_list_metric(Metric.DSC, mode=MetricMode.AVG)

@@ -437,9 +471,12 @@ def sq_dsc_std(res: PanopticaResult):

 def pq_dsc(res: PanopticaResult):
     return res.sq_dsc * res.rq
-#endregion
-#region clDSC
+
+# endregion
+
+
+# region clDSC
 def sq_cldsc(res: PanopticaResult):
     return res.get_list_metric(Metric.clDSC, mode=MetricMode.AVG)

@@ -450,18 +487,24 @@ def sq_cldsc_std(res: PanopticaResult):

 def pq_cldsc(res: PanopticaResult):
     return res.sq_cldsc * res.rq
-#endregion
-#region ASSD
+
+# endregion
+
+
+# region ASSD
 def sq_assd(res: PanopticaResult):
     return res.get_list_metric(Metric.ASSD, mode=MetricMode.AVG)


 def sq_assd_std(res: PanopticaResult):
     return res.get_list_metric(Metric.ASSD, mode=MetricMode.STD)
-#endregion
-#region Global
+
+# endregion
+
+
+# region Global
 def global_bin_dsc(res: PanopticaResult):
     if res.tp == 0:
         return 0.0
@@ -471,6 +514,7 @@ def global_bin_dsc(res: PanopticaResult):
     ref_binary[ref_binary != 0] = 1
     return _compute_dice_coefficient(ref_binary, pred_binary)

+
 def global_bin_cldsc(res: PanopticaResult):
     if res.tp == 0:
         return 0.0
@@ -479,13 +523,15 @@ def global_bin_cldsc(res: PanopticaResult):
     pred_binary[pred_binary != 0] = 1
     ref_binary[ref_binary != 0] = 1
     return _compute_centerline_dice_coefficient(ref_binary, pred_binary)
-#endregion
+
+
+# endregion


 if __name__ == "__main__":
     c = PanopticaResult(
-        reference_arr=np.zeros([5,5,5]),
-        prediction_arr=np.zeros([5,5,5]),
+        reference_arr=np.zeros([5, 5, 5]),
+        prediction_arr=np.zeros([5, 5, 5]),
         num_ref_instances=2,
         num_pred_instances=5,
         tp=0,
diff --git a/unit_tests/test_datatype.py b/unit_tests/test_datatype.py
index b307b8c..faffc56 100644
--- a/unit_tests/test_datatype.py
+++ b/unit_tests/test_datatype.py
@@ -42,6 +42,4 @@ def test_matching_metric(self):

         self.assertFalse(assd_metric.score_beats_threshold(0.55, 0.5))
         self.assertTrue(assd_metric.score_beats_threshold(0.5, 0.55))
-
-        # TODO listmetric + Mode (STD and so on)
\ No newline at end of file
diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py
index c79e91e..4a87b1d 100644
--- a/unit_tests/test_panoptic_result.py
+++ b/unit_tests/test_panoptic_result.py
@@ -82,10 +82,13 @@ def test_std_edge_case(self):

     def test_existing_metrics(self):
         from itertools import chain, combinations
+
         def powerset(iterable):
             s = list(iterable)
-            return list(chain.from_iterable(combinations(s, r) for r in range(len(s)+1)))
-        
+            return list(
+                chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+            )
+
         power_set = powerset([Metric.DSC, Metric.IOU, Metric.ASSD])
         for m in power_set[1:]:
             list_metrics: dict = {}
@@ -130,4 +133,4 @@ def powerset(iterable):
         with self.assertRaises(MetricCouldNotBeComputedException):
             c.sq_assd
         with self.assertRaises(MetricCouldNotBeComputedException):
-            c.sq_assd_std
\ No newline at end of file
+            c.sq_assd_std
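
Reviewer note (not part of the patch): several hunks above reflow score_beats_threshold, whose direction handling is easy to misread after wrapping. The short Python sketch below uses only names that appear in this diff and shows the intended behaviour: increasing metrics such as Metric.DSC must reach the threshold, while decreasing metrics such as Metric.ASSD must stay at or below it. The ASSD values mirror the assertions kept in unit_tests/test_datatype.py; the DSC values are illustrative.

from panoptica.metrics import Metric

# DSC is an increasing metric (decreasing=False): higher scores are better,
# so the score has to be >= the matching threshold.
assert Metric.DSC.score_beats_threshold(matching_score=0.8, matching_threshold=0.5)
assert not Metric.DSC.score_beats_threshold(matching_score=0.4, matching_threshold=0.5)

# ASSD is a decreasing metric (decreasing=True): lower scores are better,
# so the score has to be <= the matching threshold.
assert not Metric.ASSD.score_beats_threshold(matching_score=0.55, matching_threshold=0.5)
assert Metric.ASSD.score_beats_threshold(matching_score=0.5, matching_threshold=0.55)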