diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py
index 49e3684..61720c3 100644
--- a/examples/example_spine_instance.py
+++ b/examples/example_spine_instance.py
@@ -4,7 +4,7 @@
 from auxiliary.turbopath import turbopath
 
 from panoptica import MatchedInstancePair, Panoptic_Evaluator
-from panoptica.metrics import MatchingMetrics, ListMetric, ListMetricMode
+from panoptica.metrics import Metric, MetricMode
 
 directory = turbopath(__file__).parent
 
@@ -17,16 +17,15 @@
 evaluator = Panoptic_Evaluator(
     expected_input=MatchedInstancePair,
-    eval_metrics=[MatchingMetrics.clDSC, MatchingMetrics.DSC],
-    decision_metric=MatchingMetrics.DSC,
+    eval_metrics=[Metric.DSC, Metric.IOU],
+    decision_metric=Metric.DSC,
     decision_threshold=0.5,
 )
 
 with cProfile.Profile() as pr:
     if __name__ == "__main__":
-        result, debug_data = evaluator.evaluate(sample)
-
+        result, debug_data = evaluator.evaluate(sample, verbose=True)
         print(result)
 
     pr.dump_stats(directory + "/instance_example.log")
diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py
index d522336..19ae99b 100644
--- a/examples/example_spine_semantic.py
+++ b/examples/example_spine_semantic.py
@@ -9,7 +9,7 @@
     Panoptic_Evaluator,
     SemanticPair,
 )
-from panoptica.metrics import MatchingMetrics
+from panoptica.metrics import Metric
 
 directory = turbopath(__file__).parent
 
diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py
index 820f08e..071af97 100644
--- a/panoptica/_functionals.py
+++ b/panoptica/_functionals.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from panoptica.metrics import _compute_instance_iou, _MatchingMetric
+from panoptica.metrics import _compute_instance_iou, Metric
 from panoptica.utils.constants import CCABackend
 from panoptica.utils.numpy_utils import _get_bbox_nd
 
@@ -41,7 +41,7 @@ def _calc_matching_metric_of_overlapping_labels(
     prediction_arr: np.ndarray,
     reference_arr: np.ndarray,
     ref_labels: tuple[int, ...],
-    matching_metric: _MatchingMetric,
+    matching_metric: Metric,
 ) -> list[tuple[float, tuple[int, int]]]:
     """Calculates the MatchingMetric for all overlapping labels (fast!)
@@ -62,7 +62,7 @@ def _calc_matching_metric_of_overlapping_labels(
         )
     ]
     with Pool() as pool:
-        mm_values = pool.starmap(matching_metric._metric_function, instance_pairs)
+        mm_values = pool.starmap(matching_metric.value._metric_function, instance_pairs)
 
     mm_pairs = [
         (i, (instance_pairs[idx][2], instance_pairs[idx][3]))
diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py
index d14fbea..7421a6d 100644
--- a/panoptica/instance_evaluator.py
+++ b/panoptica/instance_evaluator.py
@@ -1,23 +1,16 @@
-import concurrent.futures
-import gc
 from multiprocessing import Pool
-
 import numpy as np
 
-from panoptica.metrics import (
-    _MatchingMetric,
-)
 from panoptica.panoptic_result import PanopticaResult
-from panoptica.timing import measure_time
 from panoptica.utils import EdgeCaseHandler
 from panoptica.utils.processing_pair import MatchedInstancePair
-from panoptica.metrics import MatchingMetrics, ListMetric
+from panoptica.metrics import Metric
 
 
 def evaluate_matched_instance(
     matched_instance_pair: MatchedInstancePair,
-    eval_metrics: list[_MatchingMetric] = [MatchingMetrics.DSC, MatchingMetrics.IOU, MatchingMetrics.ASSD],
-    decision_metric: _MatchingMetric | None = MatchingMetrics.IOU,
+    eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
+    decision_metric: Metric | None = Metric.IOU,
     decision_threshold: float | None = None,
     edge_case_handler: EdgeCaseHandler | None = None,
     **kwargs,
@@ -41,11 +34,10 @@ def evaluate_matched_instance(
         edge_case_handler = EdgeCaseHandler()
     if decision_metric is not None:
         assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics"
-        assert decision_metric.name in ListMetric.__members__, f"decision metric {decision_metric} not a member of ListMetric"
         assert decision_threshold is not None, "decision metric set but no threshold"
     # Initialize variables for True Positives (tp)
     tp = len(matched_instance_pair.matched_instances)
-    score_dict: dict[ListMetric, list[float]] = {ListMetric[m.name]: [] for m in eval_metrics}
+    score_dict: dict[Metric, list[float]] = {m: [] for m in eval_metrics}
 
     reference_arr, prediction_arr = (
         matched_instance_pair._reference_arr,
@@ -55,12 +47,12 @@ def evaluate_matched_instance(
     instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels]
     with Pool() as pool:
-        metric_dicts = pool.starmap(_evaluate_instance, instance_pairs)
+        metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs)
 
     for metric_dict in metric_dicts:
         if decision_metric is None or (
             decision_threshold is not None
-            and decision_metric.score_beats_threshold(metric_dict[ListMetric[decision_metric.name]], decision_threshold)
+            and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold)
         ):
             for k, v in metric_dict.items():
                 score_dict[k].append(v)
@@ -81,8 +73,8 @@ def _evaluate_instance(
     reference_arr: np.ndarray,
     prediction_arr: np.ndarray,
     ref_idx: int,
-    eval_metrics: list[_MatchingMetric],
-) -> dict[ListMetric, float]:
+    eval_metrics: list[Metric],
+) -> dict[Metric, float]:
     """
     Evaluate a single instance.
@@ -97,17 +89,12 @@ def _evaluate_instance(
     """
     ref_arr = reference_arr == ref_idx
     pred_arr = prediction_arr == ref_idx
-    result: dict[ListMetric, float] = {}
+    result: dict[Metric, float] = {}
     if ref_arr.sum() == 0 or pred_arr.sum() == 0:
         return result
     else:
         for metric in eval_metrics:
-            try:
-                metric_name = ListMetric[metric.name]
-            except Exception as e:
-                print(f"metric with name {metric} not defined in ListMetric. Add it to it or remove it.")
-                raise e
-            value = metric._metric_function(ref_arr, pred_arr)
-            result[metric_name] = value
+            metric_value = metric(ref_arr, pred_arr)
+            result[metric] = metric_value
         return result
diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py
index 6139de0..32a927c 100644
--- a/panoptica/instance_matcher.py
+++ b/panoptica/instance_matcher.py
@@ -6,7 +6,7 @@
     _calc_matching_metric_of_overlapping_labels,
     _map_labels,
 )
-from panoptica.metrics import MatchingMetrics, _MatchingMetric
+from panoptica.metrics import Metric, _Metric
 from panoptica.utils.processing_pair import (
     InstanceLabelMap,
     MatchedInstancePair,
@@ -153,7 +153,7 @@ class NaiveThresholdMatching(InstanceMatchingAlgorithm):
 
     def __init__(
         self,
-        matching_metric: _MatchingMetric = MatchingMetrics.IOU,
+        matching_metric: Metric = Metric.IOU,
         matching_threshold: float = 0.5,
         allow_many_to_one: bool = False,
     ) -> None:
@@ -228,7 +228,7 @@ class MaximizeMergeMatching(InstanceMatchingAlgorithm):
 
     def __init__(
         self,
-        matching_metric: _MatchingMetric = MatchingMetrics.IOU,
+        matching_metric: Metric = Metric.IOU,
         matching_threshold: float = 0.5,
     ) -> None:
         """
diff --git a/panoptica/metrics/__init__.py b/panoptica/metrics/__init__.py
index b27c599..373bfb3 100644
--- a/panoptica/metrics/__init__.py
+++ b/panoptica/metrics/__init__.py
@@ -14,4 +14,4 @@
     _compute_centerline_dice,
     _compute_centerline_dice_coefficient,
 )
-from panoptica.metrics.metrics import MatchingMetrics, ListMetric, _MatchingMetric, ListMetricMode
+from panoptica.metrics.metrics import Metric, _Metric, MetricMode
diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py
index 054e054..b1133e4 100644
--- a/panoptica/metrics/metrics.py
+++ b/panoptica/metrics/metrics.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
-from enum import EnumMeta
-from typing import Callable
+from enum import EnumMeta, Enum
+from typing import Any, Callable
 
 import numpy as np
 
@@ -14,7 +14,9 @@
 
 
 @dataclass
-class _MatchingMetric:
+class _Metric:
+    """A Metric class containing a name, whether higher or lower values are better, and a function to calculate that metric between two instances in an array.
+    """
     name: str
     decreasing: bool
     _metric_function: Callable
@@ -27,16 +29,16 @@ def __call__(
         pred_instance_idx: int | list[int] | None = None,
         *args,
         **kwargs,
-    ):
+    ) -> int | float:
         if ref_instance_idx is not None and pred_instance_idx is not None:
             reference_arr = reference_arr.copy() == ref_instance_idx
             if isinstance(pred_instance_idx, int):
                 pred_instance_idx = [pred_instance_idx]
-            prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx)
+            prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx)  # type: ignore
         return self._metric_function(reference_arr, prediction_arr, *args, **kwargs)
 
     def __eq__(self, __value: object) -> bool:
-        if isinstance(__value, _MatchingMetric):
+        if isinstance(__value, _Metric):
             return self.name == __value.name
         elif isinstance(__value, str):
             return self.name == __value
@@ -48,6 +50,9 @@ def __str__(self) -> str:
 
     def __repr__(self) -> str:
         return str(self)
+
+    def __hash__(self) -> int:
+        return abs(hash(self.name)) % (10**8)
 
     @property
     def increasing(self):
@@ -57,37 +62,96 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float
         return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
 
 
-# Important metrics that must be calculated in the evaluator, can be set for thresholding in matching and evaluation
-class MatchingMetrics:
-    DSC: _MatchingMetric = _MatchingMetric("DSC", False, _compute_dice_coefficient)
-    IOU: _MatchingMetric = _MatchingMetric("IOU", False, _compute_iou)
-    ASSD: _MatchingMetric = _MatchingMetric("ASSD", True, _average_symmetric_surface_distance)
-    clDSC: _MatchingMetric = _MatchingMetric("clDSC", False, _compute_centerline_dice_coefficient)
+class DirectValueMeta(EnumMeta):
+    "Metaclass that allows for directly getting an enum attribute"
+    def __getattribute__(cls, name) -> _Metric:
+        value = super().__getattribute__(name)
+        if isinstance(value, cls):
+            value = value.value
+        return value
+
+
+class Metric(_Enum_Compare):
+    """Enum containing the important metrics that can be calculated in the evaluator and used for thresholding in matching and evaluation.
+
+    Never call the .value member here, use the properties directly.
+
+    Each member wraps a _Metric instance (name, orientation, and metric function).
+    """
+    DSC = _Metric("DSC", False, _compute_dice_coefficient)
+    IOU = _Metric("IOU", False, _compute_iou)
+    ASSD = _Metric("ASSD", True, _average_symmetric_surface_distance)
+    clDSC = _Metric("clDSC", False, _compute_centerline_dice_coefficient)
 
-class ListMetricMode(_Enum_Compare):
-    ALL = auto()
-    AVG = auto()
-    SUM = auto()
-    STD = auto()
+    def __call__(
+        self,
+        reference_arr: np.ndarray,
+        prediction_arr: np.ndarray,
+        ref_instance_idx: int | None = None,
+        pred_instance_idx: int | list[int] | None = None,
+        *args,
+        **kwargs,
+    ) -> int | float:
+        """Calculates the underlying metric
+
+        Args:
+            reference_arr (np.ndarray): Reference array
+            prediction_arr (np.ndarray): Prediction array
+            ref_instance_idx (int | None, optional): The index label to be evaluated for the reference. Defaults to None.
+            pred_instance_idx (int | list[int] | None, optional): The index label to be evaluated for the prediction. Defaults to None.
+        Returns:
+            int | float: The metric value
+        """
+        return self.value(reference_arr=reference_arr, prediction_arr=prediction_arr, ref_instance_idx=ref_instance_idx, pred_instance_idx=pred_instance_idx, *args, **kwargs,)
+
+    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
+        """Calculates whether a score beats a specified threshold
 
-class ListMetric(_Enum_Compare):
-    DSC = MatchingMetrics.DSC.name
-    IOU = MatchingMetrics.IOU.name
-    ASSD = MatchingMetrics.ASSD.name
-    clDSC = MatchingMetrics.clDSC.name
+        Args:
+            matching_score (float): Metric score
+            matching_threshold (float): Threshold to compare against
+
+        Returns:
+            bool: True if the matching_score beats the threshold, False otherwise.
+ """ + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) + + @property + def name(self): + return self.value.name + + @property + def decreasing(self): + return self.value.decreasing + + @property + def increasing(self): + return self.value.increasing def __hash__(self) -> int: - return abs(hash(self.value)) % (10**8) + return abs(hash(self.name)) % (10**8) + + +class MetricMode(_Enum_Compare): + """Different modalities from Metrics + + Args: + _Enum_Compare (_type_): _description_ + """ + ALL = auto() + AVG = auto() + SUM = auto() + STD = auto() if __name__ == "__main__": - print(MatchingMetrics.DSC) + print(Metric.DSC) # print(MatchingMetric.DSC.name) - print(MatchingMetrics.DSC == MatchingMetrics.DSC) - print(MatchingMetrics.DSC == "DSC") - print(MatchingMetrics.DSC.name == "DSC") + print(Metric.DSC == Metric.DSC) + print(Metric.DSC == "DSC") + print(Metric.DSC.name == "DSC") # - print(MatchingMetrics.DSC == MatchingMetrics.IOU) - print(MatchingMetrics.DSC == "IOU") + print(Metric.DSC == Metric.IOU) + print(Metric.DSC == "IOU") \ No newline at end of file diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py index 4372c46..d562de6 100644 --- a/panoptica/panoptic_evaluator.py +++ b/panoptica/panoptic_evaluator.py @@ -5,7 +5,7 @@ from panoptica.instance_approximator import InstanceApproximator from panoptica.instance_evaluator import evaluate_matched_instance from panoptica.instance_matcher import InstanceMatchingAlgorithm -from panoptica.metrics import MatchingMetrics, _MatchingMetric, ListMetric +from panoptica.metrics import Metric, _Metric, Metric from panoptica.panoptic_result import PanopticaResult from panoptica.timing import measure_time from panoptica.utils import EdgeCaseHandler @@ -25,8 +25,8 @@ def __init__( instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, - eval_metrics: list[_MatchingMetric] = [MatchingMetrics.DSC, MatchingMetrics.IOU, MatchingMetrics.ASSD], - decision_metric: _MatchingMetric | None = None, + eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + decision_metric: Metric | None = None, decision_threshold: float | None = None, log_times: bool = False, verbose: bool = False, @@ -60,6 +60,7 @@ def evaluate( self, processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, result_all: bool = True, + verbose: bool | None = None, ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]: assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}" return panoptic_evaluate( @@ -72,7 +73,7 @@ def evaluate( decision_threshold=self.__decision_threshold, result_all=result_all, log_times=self.__log_times, - verbose=self.__verbose, + verbose=self.__verbose if verbose is None else verbose, ) @@ -80,8 +81,8 @@ def panoptic_evaluate( processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, - eval_metrics: list[_MatchingMetric] = [MatchingMetrics.DSC, MatchingMetrics.IOU, MatchingMetrics.ASSD], - decision_metric: _MatchingMetric | None = None, + eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + decision_metric: Metric | None = None, decision_threshold: float | None = None, 
     edge_case_handler: EdgeCaseHandler | None = None,
     log_times: bool = False,
@@ -173,7 +174,7 @@ def panoptic_evaluate(
     if isinstance(processing_pair, PanopticaResult):
         if result_all:
-            processing_pair.calculate_all(print_errors=False)
+            processing_pair.calculate_all(print_errors=verbose)
         return processing_pair, debug_data
 
     raise RuntimeError("End of panoptic pipeline reached without results")
@@ -182,7 +183,7 @@ def _handle_zero_instances_cases(
     processing_pair: UnmatchedInstancePair | MatchedInstancePair,
     edge_case_handler: EdgeCaseHandler,
-    eval_metrics: list[_MatchingMetric] = [MatchingMetrics.DSC, MatchingMetrics.IOU, MatchingMetrics.ASSD],
+    eval_metrics: list[_Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
 ) -> UnmatchedInstancePair | MatchedInstancePair | PanopticaResult:
     """
     Handle edge cases when comparing reference and prediction masks.
@@ -198,7 +199,7 @@ def _handle_zero_instances_cases(
     n_prediction_instance = processing_pair.n_prediction_instance
 
     panoptica_result_args = {
-        "list_metrics": {ListMetric[k.name]: [] for k in eval_metrics},
+        "list_metrics": {Metric[k.name]: [] for k in eval_metrics},
         "tp": 0,
         "edge_case_handler": edge_case_handler,
         "reference_arr": processing_pair.reference_arr,
diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py
index 761b154..a0e15dd 100644
--- a/panoptica/panoptic_result.py
+++ b/panoptica/panoptic_result.py
@@ -2,13 +2,15 @@
 from typing import Any, Callable
 
 import numpy as np
-from panoptica.metrics import ListMetricMode, ListMetric
+from panoptica.metrics import MetricMode, Metric
 from panoptica.metrics import _compute_dice_coefficient, _compute_centerline_dice_coefficient
 from panoptica.utils import EdgeCaseHandler
 from panoptica.utils.processing_pair import MatchedInstancePair
 
 
 class MetricCouldNotBeComputedException(Exception):
+    """Exception for when a Metric cannot be computed
+    """
     def __init__(self, *args: object) -> None:
         super().__init__(*args)
 
@@ -36,13 +38,23 @@ def __init__(
         self.long_name = long_name
         self.was_calculated = was_calculated
         self.error = error
+        self.error_obj: MetricCouldNotBeComputedException | None = None
 
     def __call__(self, result_obj: PanopticaResult) -> Any:
         if self.error:
-            raise MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed")
+            if self.error_obj is None:
+                raise MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed")
+            else:
+                raise self.error_obj
         assert not self.was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated"
         assert self.calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set"
-        return self.calc_func(result_obj)
+        try:
+            value = self.calc_func(result_obj)
+        except MetricCouldNotBeComputedException as e:
+            value = e
+            self.error = True
+            self.error_obj = e
+        return value
 
     def __str__(self) -> str:
         if self.long_name is not None:
@@ -54,16 +66,16 @@ def __str__(self) -> str:
 class Evaluation_List_Metric:
     def __init__(
         self,
-        name_id: ListMetric,
+        name_id: Metric,
         empty_list_std: float | None,
         value_list: list[float] | None,  # None stands for not calculated
         is_edge_case: bool = False,
         edge_case_result: float | None = None,
     ):
-        """This represents the metrics resulting from a ListMetric (IoU, ASSD, Dice)
+        """This represents the metrics resulting from a Metric calculated between paired instances (IoU, ASSD, Dice, ...)
 
         Args:
-            name_id (ListMetric): code-name of this metric
+            name_id (Metric): code-name of this metric
             empty_list_std (float): Value for the standard deviation if the list of values is empty
             value_list (list[float] | None): List of values of that metric (only the TPs)
         """
@@ -78,10 +90,10 @@ def __init__(
         self.SUM = None if self.ALL is None else np.sum(self.ALL)
         self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
 
-    def __getitem__(self, mode: ListMetricMode | str):
+    def __getitem__(self, mode: MetricMode | str):
         if self.error:
             raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics")
-        if isinstance(mode, ListMetricMode):
+        if isinstance(mode, MetricMode):
             mode = mode.name
         if hasattr(self, mode):
             return getattr(self, mode)
@@ -98,9 +110,20 @@ def __init__(
         num_pred_instances: int,
         num_ref_instances: int,
         tp: int,
-        list_metrics: dict[ListMetric, list[float]],
+        list_metrics: dict[Metric, list[float]],
         edge_case_handler: EdgeCaseHandler,
     ):
+        """Result object for Panoptica, contains all computable metrics
+
+        Args:
+            reference_arr (np.ndarray): matched reference array
+            prediction_arr (np.ndarray): matched prediction array
+            num_pred_instances (int): number of prediction instances
+            num_ref_instances (int): number of reference instances
+            tp (int): number of true positives (matched instances)
+            list_metrics (dict[Metric, list[float]]): dictionary containing the metrics for each TP
+            edge_case_handler (EdgeCaseHandler): EdgeCaseHandler object that handles various forms of edge cases
+        """
         self._edge_case_handler = edge_case_handler
         empty_list_std = self._edge_case_handler.handle_empty_list_std()
         self._prediction_arr = prediction_arr
@@ -255,7 +278,7 @@ def __init__(
         ##################
         # List Metrics  #
         ##################
-        self._list_metrics: dict[ListMetric, Evaluation_List_Metric] = {}
+        self._list_metrics: dict[Metric, Evaluation_List_Metric] = {}
         for k, v in list_metrics.items():
             is_edge_case, edge_case_result = self._edge_case_handler.handle_zero_tp(
                 metric=k,
@@ -284,6 +307,11 @@ def _add_metric(
         return default_value
 
     def calculate_all(self, print_errors: bool = False):
+        """Calculates all metrics that can be derived from this result
+
+        Args:
+            print_errors (bool, optional): If True, will print every metric that could not be computed and its reason. Defaults to False.
+ """ metric_errors: dict[str, Exception] = {} for k, v in self._evaluation_metrics.items(): try: @@ -316,7 +344,7 @@ def __str__(self) -> str: def to_dict(self) -> dict: return self._evaluation_metrics - def get_list_metric(self, metric: ListMetric, mode: ListMetricMode): + def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: @@ -387,11 +415,11 @@ def rq(res: PanopticaResult): #region IOU def sq(res: PanopticaResult): - return res.get_list_metric(ListMetric.IOU, mode=ListMetricMode.AVG) + return res.get_list_metric(Metric.IOU, mode=MetricMode.AVG) def sq_std(res: PanopticaResult): - return res.get_list_metric(ListMetric.IOU, mode=ListMetricMode.STD) + return res.get_list_metric(Metric.IOU, mode=MetricMode.STD) def pq(res: PanopticaResult): @@ -400,11 +428,11 @@ def pq(res: PanopticaResult): #region DSC def sq_dsc(res: PanopticaResult): - return res.get_list_metric(ListMetric.DSC, mode=ListMetricMode.AVG) + return res.get_list_metric(Metric.DSC, mode=MetricMode.AVG) def sq_dsc_std(res: PanopticaResult): - return res.get_list_metric(ListMetric.DSC, mode=ListMetricMode.STD) + return res.get_list_metric(Metric.DSC, mode=MetricMode.STD) def pq_dsc(res: PanopticaResult): @@ -413,11 +441,11 @@ def pq_dsc(res: PanopticaResult): #region clDSC def sq_cldsc(res: PanopticaResult): - return res.get_list_metric(ListMetric.clDSC, mode=ListMetricMode.AVG) + return res.get_list_metric(Metric.clDSC, mode=MetricMode.AVG) def sq_cldsc_std(res: PanopticaResult): - return res.get_list_metric(ListMetric.clDSC, mode=ListMetricMode.STD) + return res.get_list_metric(Metric.clDSC, mode=MetricMode.STD) def pq_cldsc(res: PanopticaResult): @@ -426,14 +454,17 @@ def pq_cldsc(res: PanopticaResult): #region ASSD def sq_assd(res: PanopticaResult): - return res.get_list_metric(ListMetric.ASSD, mode=ListMetricMode.AVG) + return res.get_list_metric(Metric.ASSD, mode=MetricMode.AVG) def sq_assd_std(res: PanopticaResult): - return res.get_list_metric(ListMetric.ASSD, mode=ListMetricMode.STD) + return res.get_list_metric(Metric.ASSD, mode=MetricMode.STD) #endregion +#region Global def global_bin_dsc(res: PanopticaResult): + if res.tp == 0: + return 0.0 pred_binary = res._prediction_arr.copy() ref_binary = res._reference_arr.copy() pred_binary[pred_binary != 0] = 1 @@ -441,19 +472,24 @@ def global_bin_dsc(res: PanopticaResult): return _compute_dice_coefficient(ref_binary, pred_binary) def global_bin_cldsc(res: PanopticaResult): + if res.tp == 0: + return 0.0 pred_binary = res._prediction_arr.copy() ref_binary = res._reference_arr.copy() pred_binary[pred_binary != 0] = 1 ref_binary[ref_binary != 0] = 1 return _compute_centerline_dice_coefficient(ref_binary, pred_binary) +#endregion if __name__ == "__main__": c = PanopticaResult( + reference_arr=np.zeros([5,5,5]), + prediction_arr=np.zeros([5,5,5]), num_ref_instances=2, num_pred_instances=5, tp=0, - list_metrics={ListMetric.IOU: []}, + list_metrics={Metric.IOU: []}, edge_case_handler=EdgeCaseHandler(), ) diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index 67c3e41..33b4c29 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -2,7 +2,7 @@ import numpy as np -from panoptica.metrics import ListMetric, MatchingMetrics +from panoptica.metrics import Metric, Metric from panoptica.utils.constants import _Enum_Compare, auto @@ -81,17 +81,17 @@ def __str__(self) -> str: class EdgeCaseHandler: def __init__( self, - 
-        listmetric_zeroTP_handling: dict[ListMetric, MetricZeroTPEdgeCaseHandling] = {
-            ListMetric.DSC: MetricZeroTPEdgeCaseHandling(
+        listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = {
+            Metric.DSC: MetricZeroTPEdgeCaseHandling(
                 no_instances_result=EdgeCaseResult.NAN,
                 default_result=EdgeCaseResult.ZERO,
             ),
-            ListMetric.IOU: MetricZeroTPEdgeCaseHandling(
+            Metric.IOU: MetricZeroTPEdgeCaseHandling(
                 no_instances_result=EdgeCaseResult.NAN,
                 empty_prediction_result=EdgeCaseResult.ZERO,
                 default_result=EdgeCaseResult.ZERO,
             ),
-            ListMetric.ASSD: MetricZeroTPEdgeCaseHandling(
+            Metric.ASSD: MetricZeroTPEdgeCaseHandling(
                 no_instances_result=EdgeCaseResult.NAN,
                 default_result=EdgeCaseResult.INF,
             ),
@@ -99,13 +99,13 @@ def __init__(
         empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN,
     ) -> None:
         self.__listmetric_zeroTP_handling: dict[
-            ListMetric, MetricZeroTPEdgeCaseHandling
+            Metric, MetricZeroTPEdgeCaseHandling
         ] = listmetric_zeroTP_handling
         self.__empty_list_std: EdgeCaseResult = empty_list_std
 
     def handle_zero_tp(
         self,
-        metric: ListMetric,
+        metric: Metric,
         tp: int,
         num_pred_instances: int,
         num_ref_instances: int,
@@ -123,7 +123,7 @@ def handle_zero_tp(
             num_ref_instances=num_ref_instances,
         )
 
-    def get_metric_zero_tp_handle(self, metric: ListMetric):
+    def get_metric_zero_tp_handle(self, metric: Metric):
         return self.__listmetric_zeroTP_handling[metric]
 
     def handle_empty_list_std(self) -> float | None:
@@ -142,7 +142,7 @@ def __str__(self) -> str:
     print()
     # print(handler.get_metric_zero_tp_handle(ListMetric.IOU))
     r = handler.handle_zero_tp(
-        ListMetric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1
+        Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1
     )
     print(r)
diff --git a/unit_tests/test_datatype.py b/unit_tests/test_datatype.py
new file mode 100644
index 0000000..b307b8c
--- /dev/null
+++ b/unit_tests/test_datatype.py
@@ -0,0 +1,47 @@
+# Call 'python -m unittest' on this folder
+# coverage run -m unittest
+# coverage report
+# coverage html
+import unittest
+import os
+import numpy as np
+
+from panoptica.panoptic_evaluator import Panoptic_Evaluator
+from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator
+from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching
+from panoptica.panoptic_result import PanopticaResult, MetricCouldNotBeComputedException
+from panoptica.metrics import _Metric, Metric, MetricMode
+from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult
+from panoptica.utils.processing_pair import SemanticPair
+
+
+class Test_Panoptic_Evaluator(unittest.TestCase):
+    def setUp(self) -> None:
+        os.environ["PANOPTICA_CITATION_REMINDER"] = "False"
+        return super().setUp()
+
+    def test_metrics_enum(self):
+        print(Metric.DSC)
+        # print(MatchingMetric.DSC.name)
+
+        self.assertEqual(Metric.DSC, Metric.DSC)
+        self.assertEqual(Metric.DSC, "DSC")
+        self.assertEqual(Metric.DSC.name, "DSC")
+        #
+        self.assertNotEqual(Metric.DSC, Metric.IOU)
+        self.assertNotEqual(Metric.DSC, "IOU")
+
+    def test_matching_metric(self):
+        dsc_metric = Metric.DSC
+
+        self.assertTrue(dsc_metric.score_beats_threshold(0.55, 0.5))
+        self.assertFalse(dsc_metric.score_beats_threshold(0.5, 0.55))
+
+        assd_metric = Metric.ASSD
+
+        self.assertFalse(assd_metric.score_beats_threshold(0.55, 0.5))
+        self.assertTrue(assd_metric.score_beats_threshold(0.5, 0.55))
+
+
+    # TODO listmetric + Mode (STD and so on)
+
\ No newline at end of file
diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py
index 817879a..f732701 100644
--- a/unit_tests/test_panoptic_evaluator.py
+++ b/unit_tests/test_panoptic_evaluator.py
@@ -9,7 +9,7 @@
 from panoptica.panoptic_evaluator import Panoptic_Evaluator
 from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator
 from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching
-from panoptica.metrics import _MatchingMetric, MatchingMetrics
+from panoptica.metrics import _Metric, Metric
 from panoptica.utils.processing_pair import SemanticPair
 from panoptica.panoptic_result import PanopticaResult, MetricCouldNotBeComputedException
 
@@ -72,8 +72,8 @@ def test_simple_evaluation_DSC_partial(self):
         evaluator = Panoptic_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
-            instance_matcher=NaiveThresholdMatching(matching_metric=MatchingMetrics.DSC),
-            eval_metrics=[MatchingMetrics.DSC],
+            instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC),
+            eval_metrics=[Metric.DSC],
         )
 
         result, debug_data = evaluator.evaluate(sample)
@@ -99,7 +99,7 @@ def test_simple_evaluation_ASSD(self):
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(
-                matching_metric=MatchingMetrics.ASSD,
+                matching_metric=Metric.ASSD,
                 matching_threshold=1.0,
             ),
         )
@@ -123,7 +123,7 @@ def test_simple_evaluation_ASSD_negative(self):
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(
-                matching_metric=MatchingMetrics.ASSD,
+                matching_metric=Metric.ASSD,
                 matching_threshold=0.5,
             ),
         )
diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py
index e80eb79..c79e91e 100644
--- a/unit_tests/test_panoptic_result.py
+++ b/unit_tests/test_panoptic_result.py
@@ -10,7 +10,7 @@
 from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator
 from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching
 from panoptica.panoptic_result import PanopticaResult, MetricCouldNotBeComputedException
-from panoptica.metrics import _MatchingMetric, MatchingMetrics, ListMetric, ListMetricMode
+from panoptica.metrics import _Metric, Metric, MetricMode
 from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult
 from panoptica.utils.processing_pair import SemanticPair
 
@@ -27,7 +27,7 @@ def test_simple_evaluation(self):
             num_ref_instances=2,
             num_pred_instances=5,
             tp=0,
-            list_metrics={ListMetric.IOU: []},
+            list_metrics={Metric.IOU: []},
             edge_case_handler=EdgeCaseHandler(),
         )
         c.calculate_all(print_errors=True)
@@ -49,7 +49,7 @@ def test_simple_tp_fp(self):
             num_ref_instances=n_ref,
             num_pred_instances=n_pred,
             tp=tp,
-            list_metrics={ListMetric.IOU: []},
+            list_metrics={Metric.IOU: []},
             edge_case_handler=EdgeCaseHandler(),
         )
         c.calculate_all(print_errors=False)
@@ -67,7 +67,7 @@ def test_std_edge_case(self):
             num_ref_instances=2,
             num_pred_instances=5,
             tp=0,
-            list_metrics={ListMetric.IOU: []},
+            list_metrics={Metric.IOU: []},
             edge_case_handler=EdgeCaseHandler(empty_list_std=ecr),
         )
         c.calculate_all(print_errors=True)
@@ -86,7 +86,7 @@ def powerset(iterable):
             s = list(iterable)
             return list(chain.from_iterable(combinations(s, r) for r in range(len(s)+1)))
 
-        power_set = powerset([ListMetric.DSC, ListMetric.IOU, ListMetric.ASSD])
+        power_set = powerset([Metric.DSC, Metric.IOU, Metric.ASSD])
         for m in power_set[1:]:
             list_metrics: dict = {}
             for me in m:
@@ -105,7 +105,7 @@ def powerset(iterable):
             c.calculate_all(print_errors=True)
             print(c)
 
-            if ListMetric.DSC in list_metrics:
+            if Metric.DSC in list_metrics:
                 self.assertEqual(c.sq_dsc, 1.0)
                 self.assertEqual(c.sq_dsc_std, 0.0)
             else:
@@ -114,7 +114,7 @@ def powerset(iterable):
                 with self.assertRaises(MetricCouldNotBeComputedException):
                     c.sq_dsc_std
             #
-            if ListMetric.IOU in list_metrics:
+            if Metric.IOU in list_metrics:
                 self.assertEqual(c.sq, 1.0)
                 self.assertEqual(c.sq_std, 0.0)
             else:
@@ -123,7 +123,7 @@ def powerset(iterable):
                 with self.assertRaises(MetricCouldNotBeComputedException):
                     c.sq_std
             #
-            if ListMetric.ASSD in list_metrics:
+            if Metric.ASSD in list_metrics:
                 self.assertEqual(c.sq_assd, 1.0)
                 self.assertEqual(c.sq_assd_std, 0.0)
             else: