diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index 2717721..61720c3 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -4,7 +4,7 @@ from auxiliary.turbopath import turbopath from panoptica import MatchedInstancePair, Panoptic_Evaluator -from panoptica.metrics import Metrics +from panoptica.metrics import Metric, MetricMode directory = turbopath(__file__).parent @@ -17,16 +17,15 @@ evaluator = Panoptic_Evaluator( expected_input=MatchedInstancePair, - eval_metrics=[Metrics.ASSD, Metrics.IOU], - decision_metric=Metrics.IOU, + eval_metrics=[Metric.DSC, Metric.IOU], + decision_metric=Metric.DSC, decision_threshold=0.5, ) with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(sample) - + result, debug_data = evaluator.evaluate(sample, verbose=True) print(result) pr.dump_stats(directory + "/instance_example.log") diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index ac67b25..19ae99b 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -9,7 +9,7 @@ Panoptic_Evaluator, SemanticPair, ) -from panoptica.metrics import Metrics +from panoptica.metrics import Metric directory = turbopath(__file__).parent diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index 820f08e..071af97 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -2,7 +2,7 @@ import numpy as np -from panoptica.metrics import _compute_instance_iou, _MatchingMetric +from panoptica.metrics import _compute_instance_iou, Metric from panoptica.utils.constants import CCABackend from panoptica.utils.numpy_utils import _get_bbox_nd @@ -41,7 +41,7 @@ def _calc_matching_metric_of_overlapping_labels( prediction_arr: np.ndarray, reference_arr: np.ndarray, ref_labels: tuple[int, ...], - matching_metric: _MatchingMetric, + matching_metric: Metric, ) -> list[tuple[float, tuple[int, int]]]: """Calculates the MatchingMetric for all overlapping labels (fast!)
@@ -62,7 +62,7 @@ def _calc_matching_metric_of_overlapping_labels( ) ] with Pool() as pool: - mm_values = pool.starmap(matching_metric._metric_function, instance_pairs) + mm_values = pool.starmap(matching_metric.value._metric_function, instance_pairs) mm_pairs = [ (i, (instance_pairs[idx][2], instance_pairs[idx][3])) diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index f27fb11..028e19f 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -1,23 +1,16 @@ -import concurrent.futures -import gc from multiprocessing import Pool - import numpy as np -from panoptica.metrics import ( - _MatchingMetric, -) from panoptica.panoptic_result import PanopticaResult -from panoptica.timing import measure_time from panoptica.utils import EdgeCaseHandler from panoptica.utils.processing_pair import MatchedInstancePair -from panoptica.metrics import Metrics +from panoptica.metrics import Metric def evaluate_matched_instance( matched_instance_pair: MatchedInstancePair, - eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD], - decision_metric: _MatchingMetric | None = Metrics.IOU, + eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + decision_metric: Metric | None = Metric.IOU, decision_threshold: float | None = None, edge_case_handler: EdgeCaseHandler | None = None, **kwargs, @@ -46,9 +39,7 @@ def evaluate_matched_instance( assert decision_threshold is not None, "decision metric set but no threshold" # Initialize variables for True Positives (tp) tp = len(matched_instance_pair.matched_instances) - score_dict: dict[str | _MatchingMetric, list[float]] = { - m.name: [] for m in eval_metrics - } + score_dict: dict[Metric, list[float]] = {m: [] for m in eval_metrics} reference_arr, prediction_arr = ( matched_instance_pair._reference_arr, @@ -61,13 +52,15 @@ def evaluate_matched_instance( for ref_idx in ref_matched_labels ] with Pool() as pool: - metric_dicts = pool.starmap(_evaluate_instance, instance_pairs) + metric_dicts: list[dict[Metric, float]] = pool.starmap( + _evaluate_instance, instance_pairs + ) for metric_dict in metric_dicts: if decision_metric is None or ( decision_threshold is not None and decision_metric.score_beats_threshold( - metric_dict[decision_metric.name], decision_threshold + metric_dict[decision_metric], decision_threshold ) ): for k, v in metric_dict.items(): @@ -75,8 +68,10 @@ def evaluate_matched_instance( # Create and return the PanopticaResult object with computed metrics return PanopticaResult( - num_ref_instances=matched_instance_pair.n_reference_instance, + reference_arr=matched_instance_pair.reference_arr, + prediction_arr=matched_instance_pair.prediction_arr, num_pred_instances=matched_instance_pair.n_prediction_instance, + num_ref_instances=matched_instance_pair.n_reference_instance, tp=tp, list_metrics=score_dict, edge_case_handler=edge_case_handler, @@ -87,8 +82,8 @@ def _evaluate_instance( reference_arr: np.ndarray, prediction_arr: np.ndarray, ref_idx: int, - eval_metrics: list[_MatchingMetric], -) -> dict[str, float]: + eval_metrics: list[Metric], +) -> dict[Metric, float]: """ Evaluate a single instance. 
@@ -103,12 +98,12 @@ def _evaluate_instance( """ ref_arr = reference_arr == ref_idx pred_arr = prediction_arr == ref_idx - result: dict[str, float] = {} + result: dict[Metric, float] = {} if ref_arr.sum() == 0 or pred_arr.sum() == 0: return result else: for metric in eval_metrics: - value = metric._metric_function(ref_arr, pred_arr) - result[metric.name] = value + metric_value = metric(ref_arr, pred_arr) + result[metric] = metric_value return result diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index ba0902e..32a927c 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -6,7 +6,7 @@ _calc_matching_metric_of_overlapping_labels, _map_labels, ) -from panoptica.metrics import Metrics, _MatchingMetric +from panoptica.metrics import Metric, _Metric from panoptica.utils.processing_pair import ( InstanceLabelMap, MatchedInstancePair, @@ -153,7 +153,7 @@ class NaiveThresholdMatching(InstanceMatchingAlgorithm): def __init__( self, - matching_metric: _MatchingMetric = Metrics.IOU, + matching_metric: Metric = Metric.IOU, matching_threshold: float = 0.5, allow_many_to_one: bool = False, ) -> None: @@ -228,7 +228,7 @@ class MaximizeMergeMatching(InstanceMatchingAlgorithm): def __init__( self, - matching_metric: _MatchingMetric = Metrics.IOU, + matching_metric: Metric = Metric.IOU, matching_threshold: float = 0.5, ) -> None: """ diff --git a/panoptica/metrics/__init__.py b/panoptica/metrics/__init__.py index fd32f08..f636541 100644 --- a/panoptica/metrics/__init__.py +++ b/panoptica/metrics/__init__.py @@ -6,11 +6,12 @@ _compute_dice_coefficient, _compute_instance_volumetric_dice, ) -from panoptica.metrics.iou import _compute_instance_iou, _compute_iou -from panoptica.metrics.metrics import ( - Metrics, - ListMetric, - EvalMetric, - MetricDict, - _MatchingMetric, +from panoptica.metrics.iou import ( + _compute_instance_iou, + _compute_iou, ) +from panoptica.metrics.cldice import ( + _compute_centerline_dice, + _compute_centerline_dice_coefficient, +) +from panoptica.metrics.metrics import Metric, _Metric, MetricMode diff --git a/panoptica/metrics/cldice.py b/panoptica/metrics/cldice.py new file mode 100644 index 0000000..3924751 --- /dev/null +++ b/panoptica/metrics/cldice.py @@ -0,0 +1,57 @@ +from skimage.morphology import skeletonize, skeletonize_3d +import numpy as np + + +def cl_score(volume: np.ndarray, skeleton: np.ndarray): + """Computes the skeleton volume overlap + + Args: + volume (np.ndarray): volume + skeleton (np.ndarray): skeleton + + Returns: + _type_: skeleton overlap + """ + return np.sum(volume * skeleton) / np.sum(skeleton) + + +def _compute_centerline_dice( + ref_labels: np.ndarray, + pred_labels: np.ndarray, + ref_instance_idx: int, + pred_instance_idx: int, +) -> float: + """Compute the centerline Dice (clDice) coefficient between a specific pair of instances. + + Args: + ref_labels (np.ndarray): Reference instance labels. + pred_labels (np.ndarray): Prediction instance labels. + ref_instance_idx (int): Index of the reference instance. + pred_instance_idx (int): Index of the prediction instance. 
+ + Returns: + float: clDice coefficient + """ + ref_instance_mask = ref_labels == ref_instance_idx + pred_instance_mask = pred_labels == pred_instance_idx + return _compute_centerline_dice_coefficient( + reference=ref_instance_mask, + prediction=pred_instance_mask, + ) + + +def _compute_centerline_dice_coefficient( + reference: np.ndarray, + prediction: np.ndarray, + *args, +) -> float: + ndim = reference.ndim + assert 2 <= ndim <= 3, "clDice only implemented for 2D or 3D" + if ndim == 2: + tprec = cl_score(prediction, skeletonize(reference)) + tsens = cl_score(reference, skeletonize(prediction)) + elif ndim == 3: + tprec = cl_score(prediction, skeletonize_3d(reference)) + tsens = cl_score(reference, skeletonize_3d(prediction)) + + return 2 * tprec * tsens / (tprec + tsens) diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index b6e7117..3b2b46f 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from enum import EnumMeta -from typing import Callable +from enum import EnumMeta, Enum +from typing import Any, Callable import numpy as np @@ -8,12 +8,15 @@ _average_symmetric_surface_distance, _compute_dice_coefficient, _compute_iou, + _compute_centerline_dice_coefficient, ) -from panoptica.utils.constants import Enum, _Enum_Compare, auto +from panoptica.utils.constants import _Enum_Compare, auto @dataclass -class _MatchingMetric: +class _Metric: + """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array""" + name: str decreasing: bool _metric_function: Callable @@ -26,16 +29,18 @@ def __call__( pred_instance_idx: int | list[int] | None = None, *args, **kwargs, - ): + ) -> int | float: if ref_instance_idx is not None and pred_instance_idx is not None: reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): pred_instance_idx = [pred_instance_idx] - prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) + prediction_arr = np.isin( + prediction_arr.copy(), pred_instance_idx + ) # type:ignore return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: - if isinstance(__value, _MatchingMetric): + if isinstance(__value, _Metric): return self.name == __value.name elif isinstance(__value, str): return self.name == __value @@ -48,6 +53,9 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) + def __hash__(self) -> int: + return abs(hash(self.name)) % (10**8) + @property def increasing(self): return not self.decreasing @@ -60,65 +68,110 @@ def score_beats_threshold( ) -# class _EnumMeta(EnumMeta): -# def __getattribute__(cls, name) -> MatchingMetric: -# value = super().__getattribute__(name) -# if isinstance(value, cls): -# value = value.value -# return value +class DirectValueMeta(EnumMeta): + "Metaclass that allows for directly getting an enum attribute" + def __getattribute__(cls, name) -> _Metric: + value = super().__getattribute__(name) + if isinstance(value, cls): + value = value.value + return value -# Important metrics that must be calculated in the evaluator, can be set for thresholding in matching and evaluation -# TODO make abstract class for metric, make enum with references to these classes for referenciation and user exposure -class Metrics: - # TODO make this with meta above, and then it can function without the double name, right? 
- DSC = _MatchingMetric("DSC", False, _compute_dice_coefficient) - IOU = _MatchingMetric("IOU", False, _compute_iou) - ASSD = _MatchingMetric("ASSD", True, _average_symmetric_surface_distance) - # These are all lists of values +class Metric(_Enum_Compare): + """Enum containing important metrics that must be calculated in the evaluator, can be set for thresholding in matching and evaluation + Never call the .value member here, use the properties directly -class ListMetric(_Enum_Compare): - DSC = Metrics.DSC.name - IOU = Metrics.IOU.name - ASSD = Metrics.ASSD.name + Returns: + _type_: _description_ + """ - def __hash__(self) -> int: - return abs(hash(self.value)) % (10**8) + DSC = _Metric("DSC", False, _compute_dice_coefficient) + IOU = _Metric("IOU", False, _compute_iou) + ASSD = _Metric("ASSD", True, _average_symmetric_surface_distance) + clDSC = _Metric("clDSC", False, _compute_centerline_dice_coefficient) + + def __call__( + self, + reference_arr: np.ndarray, + prediction_arr: np.ndarray, + ref_instance_idx: int | None = None, + pred_instance_idx: int | list[int] | None = None, + *args, + **kwargs, + ) -> int | float: + """Calculates the underlaying metric + + Args: + reference_arr (np.ndarray): Reference array + prediction_arr (np.ndarray): Prediction array + ref_instance_idx (int | None, optional): The index label to be evaluated for the reference. Defaults to None. + pred_instance_idx (int | list[int] | None, optional): The index label to be evaluated for the prediction. Defaults to None. + + Returns: + int | float: The metric value + """ + return self.value( + reference_arr=reference_arr, + prediction_arr=prediction_arr, + ref_instance_idx=ref_instance_idx, + pred_instance_idx=pred_instance_idx, + *args, + **kwargs, + ) + + def score_beats_threshold( + self, matching_score: float, matching_threshold: float + ) -> bool: + """Calculates whether a score beats a specified threshold + + Args: + matching_score (float): Metric score + matching_threshold (float): Threshold to compare against + Returns: + bool: True if the matching_score beats the threshold, False otherwise. 
+ """ + return (self.increasing and matching_score >= matching_threshold) or ( + self.decreasing and matching_score <= matching_threshold + ) -# Metrics that are derived from list metrics and can be calculated later -# TODO map result properties to this enum -class EvalMetric(_Enum_Compare): - TP = auto() - FP = auto() - FN = auto() - RQ = auto() - DQ_DSC = auto() - PQ_DSC = auto() - ASSD = auto() - PQ_ASSD = auto() + @property + def name(self): + return self.value.name + + @property + def decreasing(self): + return self.value.decreasing + + @property + def increasing(self): + return self.value.increasing + + def __hash__(self) -> int: + return abs(hash(self.name)) % (10**8) -MetricDict = dict[ListMetric | EvalMetric | str, float | list[float]] +class MetricMode(_Enum_Compare): + """Different modalities from Metrics + Args: + _Enum_Compare (_type_): _description_ + """ -list_of_applicable_std_metrics: list[EvalMetric] = [ - EvalMetric.RQ, - EvalMetric.DQ_DSC, - EvalMetric.PQ_ASSD, - EvalMetric.ASSD, - EvalMetric.PQ_ASSD, -] + ALL = auto() + AVG = auto() + SUM = auto() + STD = auto() if __name__ == "__main__": - print(Metrics.DSC) + print(Metric.DSC) # print(MatchingMetric.DSC.name) - print(Metrics.DSC == Metrics.DSC) - print(Metrics.DSC == "DSC") - print(Metrics.DSC.name == "DSC") + print(Metric.DSC == Metric.DSC) + print(Metric.DSC == "DSC") + print(Metric.DSC.name == "DSC") # - print(Metrics.DSC == Metrics.IOU) - print(Metrics.DSC == "IOU") + print(Metric.DSC == Metric.IOU) + print(Metric.DSC == "IOU") diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py index ae904d4..1a10b3c 100644 --- a/panoptica/panoptic_evaluator.py +++ b/panoptica/panoptic_evaluator.py @@ -5,7 +5,7 @@ from panoptica.instance_approximator import InstanceApproximator from panoptica.instance_evaluator import evaluate_matched_instance from panoptica.instance_matcher import InstanceMatchingAlgorithm -from panoptica.metrics import Metrics, _MatchingMetric +from panoptica.metrics import Metric, _Metric, Metric from panoptica.panoptic_result import PanopticaResult from panoptica.timing import measure_time from panoptica.utils import EdgeCaseHandler @@ -27,8 +27,8 @@ def __init__( instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, - eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD], - decision_metric: _MatchingMetric | None = None, + eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + decision_metric: Metric | None = None, decision_threshold: float | None = None, log_times: bool = False, verbose: bool = False, @@ -68,6 +68,8 @@ def evaluate( | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + result_all: bool = True, + verbose: bool | None = None, ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]: assert ( type(processing_pair) == self.__expected_input @@ -80,8 +82,9 @@ def evaluate( eval_metrics=self.__eval_metrics, decision_metric=self.__decision_metric, decision_threshold=self.__decision_threshold, + result_all=result_all, log_times=self.__log_times, - verbose=self.__verbose, + verbose=self.__verbose if verbose is None else verbose, ) @@ -92,11 +95,12 @@ def panoptic_evaluate( | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, - eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD], - decision_metric: 
_MatchingMetric | None = None, + eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], + decision_metric: Metric | None = None, decision_threshold: float | None = None, edge_case_handler: EdgeCaseHandler | None = None, log_times: bool = False, + result_all: bool = True, verbose: bool = False, **kwargs, ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]: @@ -155,7 +159,9 @@ def panoptic_evaluate( # Second Phase: Instance Matching if isinstance(processing_pair, UnmatchedInstancePair): processing_pair = _handle_zero_instances_cases( - processing_pair, edge_case_handler=edge_case_handler + processing_pair, + eval_metrics=eval_metrics, + edge_case_handler=edge_case_handler, ) if isinstance(processing_pair, UnmatchedInstancePair): @@ -175,7 +181,9 @@ def panoptic_evaluate( # Third Phase: Instance Evaluation if isinstance(processing_pair, MatchedInstancePair): processing_pair = _handle_zero_instances_cases( - processing_pair, edge_case_handler=edge_case_handler + processing_pair, + eval_metrics=eval_metrics, + edge_case_handler=edge_case_handler, ) if isinstance(processing_pair, MatchedInstancePair): @@ -191,6 +199,8 @@ def panoptic_evaluate( print(f"-- Instance Evaluation took {perf_counter() - start} seconds") if isinstance(processing_pair, PanopticaResult): + if result_all: + processing_pair.calculate_all(print_errors=verbose) return processing_pair, debug_data raise RuntimeError("End of panoptic pipeline reached without results") @@ -199,6 +209,7 @@ def panoptic_evaluate( def _handle_zero_instances_cases( processing_pair: UnmatchedInstancePair | MatchedInstancePair, edge_case_handler: EdgeCaseHandler, + eval_metrics: list[_Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], ) -> UnmatchedInstancePair | MatchedInstancePair | PanopticaResult: """ Handle edge cases when comparing reference and prediction masks. 
@@ -213,32 +224,36 @@ def _handle_zero_instances_cases( n_reference_instance = processing_pair.n_reference_instance n_prediction_instance = processing_pair.n_prediction_instance + panoptica_result_args = { + "list_metrics": {Metric[k.name]: [] for k in eval_metrics}, + "tp": 0, + "edge_case_handler": edge_case_handler, + "reference_arr": processing_pair.reference_arr, + "prediction_arr": processing_pair.prediction_arr, + } + + is_edge_case = False + # Handle cases where either the reference or the prediction is empty if n_prediction_instance == 0 and n_reference_instance == 0: # Both references and predictions are empty, perfect match - return PanopticaResult( - num_ref_instances=0, - num_pred_instances=0, - tp=0, - list_metrics={}, - edge_case_handler=edge_case_handler, - ) - if n_reference_instance == 0: + n_reference_instance = 0 + n_prediction_instance = 0 + is_edge_case = True + elif n_reference_instance == 0: # All references are missing, only false positives - return PanopticaResult( - num_ref_instances=0, - num_pred_instances=n_prediction_instance, - tp=0, - list_metrics={}, - edge_case_handler=edge_case_handler, - ) - if n_prediction_instance == 0: + n_reference_instance = 0 + n_prediction_instance = n_prediction_instance + is_edge_case = True + elif n_prediction_instance == 0: # All predictions are missing, only false negatives - return PanopticaResult( - num_ref_instances=n_reference_instance, - num_pred_instances=0, - tp=0, - list_metrics={}, - edge_case_handler=edge_case_handler, - ) + n_reference_instance = n_reference_instance + n_prediction_instance = 0 + is_edge_case = True + + if is_edge_case: + panoptica_result_args["num_ref_instances"] = n_reference_instance + panoptica_result_args["num_pred_instances"] = n_prediction_instance + return PanopticaResult(**panoptica_result_args) + return processing_pair diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py index f2b38cc..38b002a 100644 --- a/panoptica/panoptic_result.py +++ b/panoptica/panoptic_result.py @@ -1,367 +1,551 @@ from __future__ import annotations -from typing import Any, List - +from typing import Any, Callable import numpy as np - -from panoptica.metrics import EvalMetric, ListMetric, MetricDict, _MatchingMetric +from panoptica.metrics import MetricMode, Metric +from panoptica.metrics import ( + _compute_dice_coefficient, + _compute_centerline_dice_coefficient, +) from panoptica.utils import EdgeCaseHandler +from panoptica.utils.processing_pair import MatchedInstancePair -class PanopticaResult: - """ - Represents the result of the Panoptic Quality (PQ) computation. +class MetricCouldNotBeComputedException(Exception): + """Exception for when a Metric cannot be computed""" + + def __init__(self, *args: object) -> None: + super().__init__(*args) - Attributes: - num_ref_instances (int): Number of reference instances. - num_pred_instances (int): Number of predicted instances. - tp (int): Number of correctly matched instances (True Positives). - fp (int): Number of extra predicted instances (False Positives). - """ +class Evaluation_Metric: def __init__( self, - num_ref_instances: int, + name_id: str, + calc_func: Callable | None, + long_name: str | None = None, + was_calculated: bool = False, + error: bool = False, + ): + """This represents a metric in the evaluation derived from other metrics or list metrics (no circular dependancies!) 
+ + Args: + name_id (str): code-name of this metric, must be same as the member variable of PanopticResult + calc_func (Callable): the function to calculate this metric based on the PanopticResult object + long_name (str | None, optional): A longer descriptive name for printing/logging purposes. Defaults to None. + was_calculated (bool, optional): Whether this metric has been calculated or not. Defaults to False. + error (bool, optional): If true, means the metric could not have been calculated (because dependancies do not exist or have this flag set to True). Defaults to False. + """ + self.id = name_id + self.calc_func = calc_func + self.long_name = long_name + self.was_calculated = was_calculated + self.error = error + self.error_obj: MetricCouldNotBeComputedException | None = None + + def __call__(self, result_obj: PanopticaResult) -> Any: + if self.error: + if self.error_obj is None: + raise MetricCouldNotBeComputedException( + f"Metric {self.id} requested, but could not be computed" + ) + else: + raise self.error_obj + assert ( + not self.was_calculated + ), f"Metric {self.id} was called to compute, but is set to have been already calculated" + assert ( + self.calc_func is not None + ), f"Metric {self.id} was called to compute, but has no calculation function set" + try: + value = self.calc_func(result_obj) + except MetricCouldNotBeComputedException as e: + value = e + self.error = True + self.error_obj = e + return value + + def __str__(self) -> str: + if self.long_name is not None: + return self.long_name + f" ({self.id})" + else: + return self.id + + +class Evaluation_List_Metric: + def __init__( + self, + name_id: Metric, + empty_list_std: float | None, + value_list: list[float] | None, # None stands for not calculated + is_edge_case: bool = False, + edge_case_result: float | None = None, + ): + """This represents the metrics resulting from a Metric calculated between paired instances (IoU, ASSD, Dice, ...) + + Args: + name_id (Metric): code-name of this metric + empty_list_std (float): Value for the standard deviation if the list of values is empty + value_list (list[float] | None): List of values of that metric (only the TPs) + """ + self.id = name_id + self.error = value_list is None + self.ALL: list[float] | None = value_list + if is_edge_case: + self.AVG: float | None = edge_case_result + self.SUM: None | float = edge_case_result + else: + self.AVG = None if self.ALL is None else np.average(self.ALL) + self.SUM = None if self.ALL is None else np.sum(self.ALL) + self.STD = ( + None + if self.ALL is None + else empty_list_std + if len(self.ALL) == 0 + else np.std(self.ALL) + ) + + def __getitem__(self, mode: MetricMode | str): + if self.error: + raise MetricCouldNotBeComputedException( + f"Metric {self.id} has not been calculated, add it to your eval_metrics" + ) + if isinstance(mode, MetricMode): + mode = mode.name + if hasattr(self, mode): + return getattr(self, mode) + else: + raise MetricCouldNotBeComputedException( + f"List_Metric {self.id} does not contain {mode} member" + ) + + +class PanopticaResult(object): + def __init__( + self, + reference_arr: np.ndarray, + prediction_arr: np.ndarray, + # TODO some metadata object containing dtype, voxel spacing, ... num_pred_instances: int, + num_ref_instances: int, tp: int, - list_metrics: dict[_MatchingMetric | str, list[float]], + list_metrics: dict[Metric, list[float]], edge_case_handler: EdgeCaseHandler, ): - """ - Initialize a PanopticaResult object. 
+ """Result object for Panoptica, contains all calculatable metrics Args: - num_ref_instances (int): Number of reference instances. - num_pred_instances (int): Number of predicted instances. - tp (int): Number of correctly matched instances (True Positives). - list_metrics: dict[MatchingMetric | str, list[float]]: TBD - edge_case_handler: EdgeCaseHandler: TBD - """ - self._tp = tp - self.edge_case_handler = edge_case_handler - self.metric_dict: MetricDict = {} - for k, v in list_metrics.items(): - if isinstance(k, _MatchingMetric): - k = k.name - self.metric_dict[k] = v - - # for k in ListMetric: - # if k.name not in self.metric_dict: - # self.metric_dict[k.name] = [] - self._num_ref_instances = num_ref_instances - self._num_pred_instances = num_pred_instances - - # TODO instead of all the properties, make a generic function inputting metric and std or not, - # and returns it if contained in dictionary, - # otherwise calls function to calculates, saves it and return - - def __str__(self): - text = ( - f"Number of instances in prediction: {self.num_pred_instances}\n" - f"Number of instances in reference: {self.num_ref_instances}\n" - f"True Positives (tp): {self.tp}\n" - f"False Positives (fp): {self.fp}\n" - f"False Negatives (fn): {self.fn}\n" - f"Recognition Quality / F1 Score (RQ): {self.rq}\n" + reference_arr (np.ndarray): matched reference arr + prediction_arr (np.ndarray): matched prediction arr + num_pred_instances (int): number of prediction instances + num_ref_instances (int): number of reference instances + tp (int): number of true positives (matched instances) + list_metrics (dict[Metric, list[float]]): dictionary containing the metrics for each TP + edge_case_handler (EdgeCaseHandler): EdgeCaseHandler object that handles various forms of edge cases + """ + self._edge_case_handler = edge_case_handler + empty_list_std = self._edge_case_handler.handle_empty_list_std() + self._prediction_arr = prediction_arr + self._reference_arr = reference_arr + ###################### + # Evaluation Metrics # + ###################### + self._evaluation_metrics: dict[str, Evaluation_Metric] = {} + # + # region Already Calculated + self.num_ref_instances: int + self._add_metric( + "num_ref_instances", + None, + long_name="Number of instances in reference", + default_value=num_ref_instances, + was_calculated=True, + ) + self.num_pred_instances: int + self._add_metric( + "num_pred_instances", + None, + long_name="Number of instances in prediction", + default_value=num_pred_instances, + was_calculated=True, ) + self.tp: int + self._add_metric( + "tp", + None, + long_name="True Positives", + default_value=tp, + was_calculated=True, + ) + # endregion + # + # region Basic + self.fp: int + self._add_metric( + "fp", + fp, + long_name="False Positives", + ) + self.fn: int + self._add_metric( + "fn", + fn, + long_name="False Negatives", + ) + self.rq: float + self._add_metric( + "rq", + rq, + long_name="Recognition Quality", + ) + # endregion + # + # region Global + self.global_bin_dsc: int + self._add_metric( + "global_bin_dsc", + global_bin_dsc, + long_name="Global Binary Dice", + ) + # + self.global_bin_cldsc: int + self._add_metric( + "global_bin_cldsc", + global_bin_cldsc, + long_name="Global Binary Centerline Dice", + ) + # endregion + # + # region IOU + self.sq: float + self._add_metric( + "sq", + sq, + long_name="Segmentation Quality IoU", + ) + self.sq_std: float + self._add_metric( + "sq_std", + sq_std, + long_name="Segmentation Quality IoU Standard Deviation", + ) + self.pq: float + 
self._add_metric( + "pq", + pq, + long_name="Panoptic Quality IoU", + ) + # endregion + # + # region DICE + self.sq_dsc: float + self._add_metric( + "sq_dsc", + sq_dsc, + long_name="Segmentation Quality Dsc", + ) + self.sq_dsc_std: float + self._add_metric( + "sq_dsc_std", + sq_dsc_std, + long_name="Segmentation Quality Dsc Standard Deviation", + ) + self.pq_dsc: float + self._add_metric( + "pq_dsc", + pq_dsc, + long_name="Panoptic Quality Dsc", + ) + # endregion + # + # region clDICE + self.sq_cldsc: float + self._add_metric( + "sq_cldsc", + sq_cldsc, + long_name="Segmentation Quality Centerline Dsc", + ) + self.sq_cldsc_std: float + self._add_metric( + "sq_cldsc_std", + sq_cldsc_std, + long_name="Segmentation Quality Centerline Dsc Standard Deviation", + ) + self.pq_cldsc: float + self._add_metric( + "pq_cldsc", + pq_cldsc, + long_name="Panoptic Quality Centerline Dsc", + ) + # endregion + # + # region ASSD + self.sq_assd: float + self._add_metric( + "sq_assd", + sq_assd, + long_name="Segmentation Quality Assd", + ) + self.sq_assd_std: float + self._add_metric( + "sq_assd_std", + sq_assd_std, + long_name="Segmentation Quality Assd Standard Deviation", + ) + # endregion - if ListMetric.IOU.name in self.metric_dict: - text += f"Segmentation Quality (SQ): {self.sq} ± {self.sq_sd}\n" - text += f"Panoptic Quality (PQ): {self.pq}\n" + ################## + # List Metrics # + ################## + self._list_metrics: dict[Metric, Evaluation_List_Metric] = {} + for k, v in list_metrics.items(): + is_edge_case, edge_case_result = self._edge_case_handler.handle_zero_tp( + metric=k, + tp=self.tp, + num_pred_instances=self.num_pred_instances, + num_ref_instances=self.num_ref_instances, + ) + self._list_metrics[k] = Evaluation_List_Metric( + k, empty_list_std, v, is_edge_case, edge_case_result + ) - if ListMetric.DSC.name in self.metric_dict: - text += f"DSC-based Segmentation Quality (DQ_DSC): {self.sq_dsc} ± {self.sq_dsc_sd}\n" - text += f"DSC-based Panoptic Quality (PQ_DSC): {self.pq_dsc}\n" + def _add_metric( + self, + name_id: str, + calc_func: Callable | None, + long_name: str | None = None, + default_value=None, + was_calculated: bool = False, + ): + setattr(self, name_id, default_value) + # assert hasattr(self, name_id), f"added metric {name_id} but it is not a member variable of this class" + if calc_func is None: + assert ( + was_calculated + ), "Tried to add a metric without a calc_function but that hasn't been calculated yet, how did you think this could works?" + eval_metric = Evaluation_Metric(name_id, calc_func, long_name, was_calculated) + self._evaluation_metrics[name_id] = eval_metric + return default_value + + def calculate_all(self, print_errors: bool = False): + """Calculates all possible metrics that can be derived - if ListMetric.ASSD.name in self.metric_dict: - text += f"Average symmetric surface distance (ASSD): {self.sq_assd} ± {self.sq_assd_sd}\n" - text += f"ASSD-based Panoptic Quality (PQ_ASSD): {self.pq_assd}" + Args: + print_errors (bool, optional): If true, will print every metric that could not be computed and its reason. Defaults to False. 
+ """ + metric_errors: dict[str, Exception] = {} + for k, v in self._evaluation_metrics.items(): + try: + v = getattr(self, k) + except Exception as e: + metric_errors[k] = e + + if print_errors: + for k, v in metric_errors.items(): + print(f"Metric {k}: {v}") + + def __str__(self) -> str: + text = "" + for k, v in self._evaluation_metrics.items(): + if k.endswith("_std"): + continue + if v.was_calculated and not v.error: + # is there standard deviation for this? + text += f"{v}: {self.__getattribute__(k)}" + k_std = k + "_std" + if ( + k_std in self._evaluation_metrics + and self._evaluation_metrics[k_std].was_calculated + and not self._evaluation_metrics[k_std].error + ): + text += f" +- {self.__getattribute__(k_std)}" + text += "\n" return text - def to_dict(self): - eval_dict = { - "num_pred_instances": self.num_pred_instances, - "num_ref_instances": self.num_ref_instances, - "tp": self.tp, - "fp": self.fp, - "fn": self.fn, - "rq": self.rq, + def to_dict(self) -> dict: + return { + k: getattr(self, v.id) + for k, v in self._evaluation_metrics.items() + if (v.error == False and v.was_calculated) } - if ListMetric.IOU.name in self.metric_dict: - eval_dict["sq"] = self.sq - eval_dict["sq_sd"] = self.sq_sd - eval_dict["pq"] = self.pq + def get_list_metric(self, metric: Metric, mode: MetricMode): + if metric in self._list_metrics: + return self._list_metrics[metric][mode] + else: + raise MetricCouldNotBeComputedException( + f"{metric} could not be found, have you set it in eval_metrics during evaluation?" + ) - if ListMetric.DSC.name in self.metric_dict: - eval_dict["sq_dsc"] = self.sq_dsc - eval_dict["sq_dsc_sd"] = self.sq_dsc_sd - eval_dict["pq_dsc"] = self.pq_dsc + def _calc_metric(self, metric_name: str, supress_error: bool = False): + if metric_name in self._evaluation_metrics: + try: + value = self._evaluation_metrics[metric_name](self) + except MetricCouldNotBeComputedException as e: + value = e + if isinstance(value, MetricCouldNotBeComputedException): + self._evaluation_metrics[metric_name].error = True + self._evaluation_metrics[metric_name].was_calculated = True + if not supress_error: + raise value + self._evaluation_metrics[metric_name].was_calculated = True + return value + else: + raise MetricCouldNotBeComputedException( + f"could not find metric with name {metric_name}" + ) - if ListMetric.ASSD.name in self.metric_dict: - eval_dict["sq_assd"] = self.sq_assd - eval_dict["sq_assd_sd"] = self.sq_assd_sd - eval_dict["pq_assd"] = self.pq_assd - return eval_dict + def __getattribute__(self, __name: str) -> Any: + attr = None + try: + attr = object.__getattribute__(self, __name) + except AttributeError as e: + if __name in self._evaluation_metrics.keys(): + pass + else: + raise e + if attr is None: + if self._evaluation_metrics[__name].error: + raise MetricCouldNotBeComputedException( + f"Requested metric {__name} that could not be computed" + ) + elif not self._evaluation_metrics[__name].was_calculated: + value = self._calc_metric(__name) + setattr(self, __name, value) + if isinstance(value, MetricCouldNotBeComputedException): + raise value + return value + else: + return attr - @property - def num_ref_instances(self) -> int: - """ - Get the number of reference instances. - Returns: - int: Number of reference instances. - """ - return self._num_ref_instances +######################### +# Calculation functions # +######################### - @property - def num_pred_instances(self) -> int: - """ - Get the number of predicted instances. 
- Returns: - int: Number of predicted instances. - """ - return self._num_pred_instances +# region Basic +def fp(res: PanopticaResult): + return res.num_pred_instances - res.tp - @property - def tp(self) -> int: - """ - Calculate the number of True Positives (TP). - Returns: - int: Number of True Positives. - """ - return self._tp +def fn(res: PanopticaResult): + return res.num_ref_instances - res.tp - @property - def fp(self) -> int: - """ - Calculate the number of False Positives (FP). - Returns: - int: Number of False Positives. - """ - return self.num_pred_instances - self.tp +def rq(res: PanopticaResult): + """ + Calculate the Recognition Quality (RQ) based on TP, FP, and FN. - @property - def fn(self) -> int: - """ - Calculate the number of False Negatives (FN). + Returns: + float: Recognition Quality (RQ). + """ + if res.tp == 0: + return 0.0 if res.num_pred_instances + res.num_ref_instances > 0 else np.nan + return res.tp / (res.tp + 0.5 * res.fp + 0.5 * res.fn) - Returns: - int: Number of False Negatives. - """ - return self.num_ref_instances - self.tp - @property - def rq(self) -> float: - """ - Calculate the Recognition Quality (RQ) based on TP, FP, and FN. +# endregion - Returns: - float: Recognition Quality (RQ). - """ - if self.tp == 0: - return ( - 0.0 if self.num_pred_instances + self.num_ref_instances > 0 else np.nan - ) - return self.tp / (self.tp + 0.5 * self.fp + 0.5 * self.fn) - @property - def sq(self) -> float: - """ - Calculate the Segmentation Quality (SQ) based on IoU values. +# region IOU +def sq(res: PanopticaResult): + return res.get_list_metric(Metric.IOU, mode=MetricMode.AVG) - Returns: - float: Segmentation Quality (SQ). - """ - is_edge_case, result = self.edge_case_handler.handle_zero_tp( - metric=ListMetric.IOU, - tp=self.tp, - num_pred_instances=self.num_pred_instances, - num_ref_instances=self.num_ref_instances, - ) - if is_edge_case: - return result - if ListMetric.IOU.name not in self.metric_dict: - print("Requested SQ but no IOU metric evaluated") - return None - return np.sum(self.metric_dict[ListMetric.IOU.name]) / self.tp - - @property - def sq_sd(self) -> float: - """ - Calculate the standard deviation of Segmentation Quality (SQ) based on IoU values. - Returns: - float: Standard deviation of Segmentation Quality (SQ). - """ - if ListMetric.IOU.name not in self.metric_dict: - print("Requested SQ_SD but no IOU metric evaluated") - return None - return ( - np.std(self.metric_dict[ListMetric.IOU.name]) - if len(self.metric_dict[ListMetric.IOU.name]) > 0 - else self.edge_case_handler.handle_empty_list_std() - ) +def sq_std(res: PanopticaResult): + return res.get_list_metric(Metric.IOU, mode=MetricMode.STD) - @property - def pq(self) -> float: - """ - Calculate the Panoptic Quality (PQ) based on SQ and RQ. - Returns: - float: Panoptic Quality (PQ). - """ - sq = self.sq - rq = self.rq - if sq is None or rq is None: - return None - else: - return sq * rq +def pq(res: PanopticaResult): + return res.sq * res.rq - @property - def sq_dsc(self) -> float: - """ - Calculate the average Dice coefficient for matched instances. Analogue to segmentation quality but based on DSC. - Returns: - float: Average Dice coefficient. 
- """ - is_edge_case, result = self.edge_case_handler.handle_zero_tp( - metric=ListMetric.DSC, - tp=self.tp, - num_pred_instances=self.num_pred_instances, - num_ref_instances=self.num_ref_instances, - ) - if is_edge_case: - return result - if ListMetric.DSC.name not in self.metric_dict: - print("Requested DSC but no DSC metric evaluated") - return None - return np.sum(self.metric_dict[ListMetric.DSC.name]) / self.tp - - @property - def sq_dsc_sd(self) -> float: - """ - Calculate the standard deviation of average Dice coefficient for matched instances. Analogue to segmentation quality but based on DSC. +# endregion - Returns: - float: Standard deviation of Average Dice coefficient. - """ - if ListMetric.DSC.name not in self.metric_dict: - print("Requested DSC_SD but no DSC metric evaluated") - return None - return ( - np.std(self.metric_dict[ListMetric.DSC.name]) - if len(self.metric_dict[ListMetric.DSC.name]) > 0 - else self.edge_case_handler.handle_empty_list_std() - ) - @property - def pq_dsc(self) -> float: - """ - Calculate the Panoptic Quality (PQ) based on DSC-based SQ and RQ. +# region DSC +def sq_dsc(res: PanopticaResult): + return res.get_list_metric(Metric.DSC, mode=MetricMode.AVG) - Returns: - float: Panoptic Quality (PQ). - """ - sq = self.sq_dsc - rq = self.rq - if sq is None or rq is None: - return None - else: - return sq * rq - @property - def sq_assd(self) -> float: - """ - Calculate the average average symmetric surface distance (ASSD) for matched instances. Analogue to segmentation quality but based on ASSD. +def sq_dsc_std(res: PanopticaResult): + return res.get_list_metric(Metric.DSC, mode=MetricMode.STD) - Returns: - float: average symmetric surface distance. (ASSD) - """ - is_edge_case, result = self.edge_case_handler.handle_zero_tp( - metric=ListMetric.ASSD, - tp=self.tp, - num_pred_instances=self.num_pred_instances, - num_ref_instances=self.num_ref_instances, - ) - if is_edge_case: - return result - if ListMetric.ASSD.name not in self.metric_dict: - print("Requested ASSD but no ASSD metric evaluated") - return None - return np.sum(self.metric_dict[ListMetric.ASSD.name]) / self.tp - - @property - def sq_assd_sd(self) -> float: - """ - Calculate the standard deviation of average symmetric surface distance (ASSD) for matched instances. Analogue to segmentation quality but based on ASSD. - Returns: - float: Standard deviation of average symmetric surface distance (ASSD). - """ - if ListMetric.ASSD.name not in self.metric_dict: - print("Requested ASSD_SD but no ASSD metric evaluated") - return None - return ( - np.std(self.metric_dict[ListMetric.ASSD.name]) - if len(self.metric_dict[ListMetric.ASSD.name]) > 0 - else self.edge_case_handler.handle_empty_list_std() - ) - @property - def pq_assd(self) -> float: - """ - Calculate the Panoptic Quality (PQ) based on ASSD-based SQ and RQ. +def pq_dsc(res: PanopticaResult): + return res.sq_dsc * res.rq - Returns: - float: Panoptic Quality (PQ). 
- """ - return self.sq_assd * self.rq +# endregion -# TODO make general getter that takes metric enum and std or not -# splits up into lists or not -# use below structure -def getter(value: int): - return value +# region clDSC +def sq_cldsc(res: PanopticaResult): + return res.get_list_metric(Metric.clDSC, mode=MetricMode.AVG) -class Test(object): - def __init__(self) -> None: - self.x: int - self.y: int - # x = property(fget=getter(value=45)) +def sq_cldsc_std(res: PanopticaResult): + return res.get_list_metric(Metric.clDSC, mode=MetricMode.STD) - def __getattribute__(self, __name: str) -> Any: - attr = None - try: - attr = object.__getattribute__(self, __name) - except AttributeError as e: - pass - if attr is None: - value = getter(5) - setattr(self, __name, value) - return value - else: - return attr - # def __getattribute__(self, name): - # if some_predicate(name): - # # ... - # else: - # # Default behaviour - # return object.__getattribute__(self, name) +def pq_cldsc(res: PanopticaResult): + return res.sq_cldsc * res.rq -if __name__ == "__main__": - c = Test() +# endregion + + +# region ASSD +def sq_assd(res: PanopticaResult): + return res.get_list_metric(Metric.ASSD, mode=MetricMode.AVG) + + +def sq_assd_std(res: PanopticaResult): + return res.get_list_metric(Metric.ASSD, mode=MetricMode.STD) - print(c.x) - c.x = 4 +# endregion - print(c.x) + +# region Global +def global_bin_dsc(res: PanopticaResult): + if res.tp == 0: + return 0.0 + pred_binary = res._prediction_arr.copy() + ref_binary = res._reference_arr.copy() + pred_binary[pred_binary != 0] = 1 + ref_binary[ref_binary != 0] = 1 + return _compute_dice_coefficient(ref_binary, pred_binary) + + +def global_bin_cldsc(res: PanopticaResult): + if res.tp == 0: + return 0.0 + pred_binary = res._prediction_arr.copy() + ref_binary = res._reference_arr.copy() + pred_binary[pred_binary != 0] = 1 + ref_binary[ref_binary != 0] = 1 + return _compute_centerline_dice_coefficient(ref_binary, pred_binary) + + +# endregion + + +if __name__ == "__main__": + c = PanopticaResult( + reference_arr=np.zeros([5, 5, 5]), + prediction_arr=np.zeros([5, 5, 5]), + num_ref_instances=2, + num_pred_instances=5, + tp=0, + list_metrics={Metric.IOU: []}, + edge_case_handler=EdgeCaseHandler(), + ) + + print(c) + + c.calculate_all(print_errors=True) + print(c) + + # print(c.sq) diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index c2f881a..33b4c29 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -2,7 +2,7 @@ import numpy as np -from panoptica.metrics import ListMetric, Metrics +from panoptica.metrics import Metric, Metric from panoptica.utils.constants import _Enum_Compare, auto @@ -81,17 +81,17 @@ def __str__(self) -> str: class EdgeCaseHandler: def __init__( self, - listmetric_zeroTP_handling: dict[ListMetric, MetricZeroTPEdgeCaseHandling] = { - ListMetric.DSC: MetricZeroTPEdgeCaseHandling( + listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = { + Metric.DSC: MetricZeroTPEdgeCaseHandling( no_instances_result=EdgeCaseResult.NAN, default_result=EdgeCaseResult.ZERO, ), - ListMetric.IOU: MetricZeroTPEdgeCaseHandling( + Metric.IOU: MetricZeroTPEdgeCaseHandling( no_instances_result=EdgeCaseResult.NAN, empty_prediction_result=EdgeCaseResult.ZERO, default_result=EdgeCaseResult.ZERO, ), - ListMetric.ASSD: MetricZeroTPEdgeCaseHandling( + Metric.ASSD: MetricZeroTPEdgeCaseHandling( no_instances_result=EdgeCaseResult.NAN, default_result=EdgeCaseResult.INF, ), @@ -99,17 
+99,19 @@ def __init__( empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: self.__listmetric_zeroTP_handling: dict[ - ListMetric, MetricZeroTPEdgeCaseHandling + Metric, MetricZeroTPEdgeCaseHandling ] = listmetric_zeroTP_handling - self.__empty_list_std = empty_list_std + self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( self, - metric: ListMetric, + metric: Metric, tp: int, num_pred_instances: int, num_ref_instances: int, ) -> tuple[bool, float | None]: + if tp != 0: + return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: raise NotImplementedError( f"Metric {metric} encountered zero TP, but no edge handling available" @@ -121,10 +123,10 @@ def handle_zero_tp( num_ref_instances=num_ref_instances, ) - def get_metric_zero_tp_handle(self, metric: ListMetric): + def get_metric_zero_tp_handle(self, metric: Metric): return self.__listmetric_zeroTP_handling[metric] - def handle_empty_list_std(self): + def handle_empty_list_std(self) -> float | None: return self.__empty_list_std.value def __str__(self) -> str: @@ -140,7 +142,7 @@ def __str__(self) -> str: print() # print(handler.get_metric_zero_tp_handle(ListMetric.IOU)) r = handler.handle_zero_tp( - ListMetric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1 + Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1 ) print(r) diff --git a/pyproject.toml b/pyproject.toml index d9fbb1e..2a30f1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ numpy = "^1.20.0" connected-components-3d = "^3.12.3" scipy = "^1.7.0" rich = "^13.6.0" +scikit-image = "^0.22.0" [tool.poetry.dev-dependencies] pytest = "^6.2.5" @@ -35,7 +36,7 @@ pytest-mock = "^3.6.0" optional = true [tool.poetry.group.docs.dependencies] -Sphinx = ">=7.0.0" +Sphinx = ">=7.0.0" sphinx-copybutton = ">=0.5.2" sphinx-rtd-theme = ">=1.3.0" -myst-parser = ">=2.0.0" \ No newline at end of file +myst-parser = ">=2.0.0" diff --git a/unit_tests/test_datatype.py b/unit_tests/test_datatype.py new file mode 100644 index 0000000..faffc56 --- /dev/null +++ b/unit_tests/test_datatype.py @@ -0,0 +1,45 @@ +# Call 'python -m unittest' on this folder +# coverage run -m unittest +# coverage report +# coverage html +import unittest +import os +import numpy as np + +from panoptica.panoptic_evaluator import Panoptic_Evaluator +from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator +from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching +from panoptica.panoptic_result import PanopticaResult, MetricCouldNotBeComputedException +from panoptica.metrics import _Metric, Metric, Metric, MetricMode +from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult +from panoptica.utils.processing_pair import SemanticPair + + +class Test_Panoptic_Evaluator(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_metrics_enum(self): + print(Metric.DSC) + # print(MatchingMetric.DSC.name) + + self.assertEqual(Metric.DSC, Metric.DSC) + self.assertEqual(Metric.DSC, "DSC") + self.assertEqual(Metric.DSC.name, "DSC") + # + self.assertNotEqual(Metric.DSC, Metric.IOU) + self.assertNotEqual(Metric.DSC, "IOU") + + def test_matching_metric(self): + dsc_metric = Metric.DSC + + self.assertTrue(dsc_metric.score_beats_threshold(0.55, 0.5)) + self.assertFalse(dsc_metric.score_beats_threshold(0.5, 0.55)) + + assd_metric = Metric.ASSD + + 
self.assertFalse(assd_metric.score_beats_threshold(0.55, 0.5)) + self.assertTrue(assd_metric.score_beats_threshold(0.5, 0.55)) + + # TODO listmetric + Mode (STD and so on) diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index fa5c2ec..f732701 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -9,8 +9,9 @@ from panoptica.panoptic_evaluator import Panoptic_Evaluator from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching -from panoptica.metrics import _MatchingMetric, Metrics +from panoptica.metrics import _Metric, Metric from panoptica.utils.processing_pair import SemanticPair +from panoptica.panoptic_result import PanopticaResult, MetricCouldNotBeComputedException class Test_Panoptic_Evaluator(unittest.TestCase): @@ -71,18 +72,19 @@ def test_simple_evaluation_DSC_partial(self): evaluator = Panoptic_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), - instance_matcher=NaiveThresholdMatching(matching_metric=Metrics.DSC), - eval_metrics=[Metrics.DSC], + instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC), + eval_metrics=[Metric.DSC], ) result, debug_data = evaluator.evaluate(sample) print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) - self.assertEqual( - result.sq, None - ) # must be none because no IOU has been calculated - self.assertEqual(result.pq, None) + with self.assertRaises(MetricCouldNotBeComputedException): + result.sq + # must be none because no IOU has been calculated + with self.assertRaises(MetricCouldNotBeComputedException): + result.pq self.assertEqual(result.rq, 1.0) def test_simple_evaluation_ASSD(self): @@ -97,7 +99,7 @@ def test_simple_evaluation_ASSD(self): expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching( - matching_metric=Metrics.ASSD, + matching_metric=Metric.ASSD, matching_threshold=1.0, ), ) @@ -121,7 +123,7 @@ def test_simple_evaluation_ASSD_negative(self): expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching( - matching_metric=Metrics.ASSD, + matching_metric=Metric.ASSD, matching_threshold=0.5, ), ) diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py new file mode 100644 index 0000000..4a87b1d --- /dev/null +++ b/unit_tests/test_panoptic_result.py @@ -0,0 +1,136 @@ +# Call 'python -m unittest' on this folder +# coverage run -m unittest +# coverage report +# coverage html +import unittest +import os +import numpy as np + +from panoptica.panoptic_evaluator import Panoptic_Evaluator +from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator +from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching +from panoptica.panoptic_result import PanopticaResult, MetricCouldNotBeComputedException +from panoptica.metrics import _Metric, Metric, Metric, MetricMode +from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult +from panoptica.utils.processing_pair import SemanticPair + + +class Test_Panoptic_Evaluator(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_simple_evaluation(self): + c = PanopticaResult( + prediction_arr=None, 
+ reference_arr=None, + num_ref_instances=2, + num_pred_instances=5, + tp=0, + list_metrics={Metric.IOU: []}, + edge_case_handler=EdgeCaseHandler(), + ) + c.calculate_all(print_errors=True) + print(c) + + self.assertEqual(c.tp, 0) + self.assertEqual(c.fp, 5) + self.assertEqual(c.fn, 2) + self.assertEqual(c.rq, 0.0) + self.assertEqual(c.pq, 0.0) + + def test_simple_tp_fp(self): + for n_ref in range(1, 10): + for n_pred in range(1, 10): + for tp in range(1, n_ref): + c = PanopticaResult( + prediction_arr=None, + reference_arr=None, + num_ref_instances=n_ref, + num_pred_instances=n_pred, + tp=tp, + list_metrics={Metric.IOU: []}, + edge_case_handler=EdgeCaseHandler(), + ) + c.calculate_all(print_errors=False) + print(c) + + self.assertEqual(c.tp, tp) + self.assertEqual(c.fp, n_pred - tp) + self.assertEqual(c.fn, n_ref - tp) + + def test_std_edge_case(self): + for ecr in EdgeCaseResult: + c = PanopticaResult( + prediction_arr=None, + reference_arr=None, + num_ref_instances=2, + num_pred_instances=5, + tp=0, + list_metrics={Metric.IOU: []}, + edge_case_handler=EdgeCaseHandler(empty_list_std=ecr), + ) + c.calculate_all(print_errors=True) + print(c) + + if c.sq_std is None: + self.assertTrue(ecr.value is None) + elif np.isnan(c.sq_std): + self.assertTrue(np.isnan(ecr.value)) + else: + self.assertEqual(c.sq_std, ecr.value) + + def test_existing_metrics(self): + from itertools import chain, combinations + + def powerset(iterable): + s = list(iterable) + return list( + chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + ) + + power_set = powerset([Metric.DSC, Metric.IOU, Metric.ASSD]) + for m in power_set[1:]: + list_metrics: dict = {} + for me in m: + list_metrics[me] = [1.0] + print(list(list_metrics.keys())) + + c = PanopticaResult( + prediction_arr=None, + reference_arr=None, + num_ref_instances=2, + num_pred_instances=5, + tp=1, + list_metrics=list_metrics, + edge_case_handler=EdgeCaseHandler(), + ) + c.calculate_all(print_errors=True) + print(c) + + if Metric.DSC in list_metrics: + self.assertEqual(c.sq_dsc, 1.0) + self.assertEqual(c.sq_dsc_std, 0.0) + else: + with self.assertRaises(MetricCouldNotBeComputedException): + c.sq_dsc + with self.assertRaises(MetricCouldNotBeComputedException): + c.sq_dsc_std + # + if Metric.IOU in list_metrics: + self.assertEqual(c.sq, 1.0) + self.assertEqual(c.sq_std, 0.0) + else: + with self.assertRaises(MetricCouldNotBeComputedException): + c.sq + with self.assertRaises(MetricCouldNotBeComputedException): + c.sq_std + # + if Metric.ASSD in list_metrics: + self.assertEqual(c.sq_assd, 1.0) + self.assertEqual(c.sq_assd_std, 0.0) + else: + with self.assertRaises(MetricCouldNotBeComputedException): + c.sq_assd + with self.assertRaises(MetricCouldNotBeComputedException): + c.sq_assd_std
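
The examples above now pass `Metric` members where the old code passed `Metrics.*` constants: enum members are callable and expose `score_beats_threshold`, `increasing`, and `decreasing` directly, without reaching into `.value`. A minimal sketch of that API, assuming the patched panoptica package is installed; the toy arrays and threshold values below are illustrative, not taken from the repository:

import numpy as np
from panoptica.metrics import Metric

# Two overlapping binary cubes stand in for a matched reference/prediction pair.
reference = np.zeros((10, 10, 10), dtype=np.uint8)
prediction = np.zeros((10, 10, 10), dtype=np.uint8)
reference[2:7, 2:7, 2:7] = 1
prediction[3:8, 3:8, 3:8] = 1

# Calling the enum member forwards to the wrapped _Metric function (here: Dice).
dsc = Metric.DSC(reference_arr=reference, prediction_arr=prediction)  # ~0.51 for these cubes

# Thresholding respects the metric direction: DSC/IOU increase, ASSD decreases.
print(Metric.DSC.score_beats_threshold(dsc, 0.5))   # True
print(Metric.ASSD.score_beats_threshold(0.3, 0.5))  # True, lower ASSD is better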
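
The new `panoptica/metrics/cldice.py` module (backed by the `scikit-image` dependency added to `pyproject.toml`) brings centerline Dice: each mask is skeletonized, and clDice combines how much of each skeleton falls inside the other mask. A small 2D sketch of the helper under the same installation assumption; the tube-like masks are synthetic:

import numpy as np
from panoptica.metrics.cldice import _compute_centerline_dice_coefficient

# Thin horizontal "vessels"; the prediction is shifted down by one voxel.
reference = np.zeros((32, 32), dtype=bool)
prediction = np.zeros((32, 32), dtype=bool)
reference[10:13, 4:28] = True
prediction[11:14, 4:28] = True

# Both skeletons still lie inside the other mask, so clDice stays close to 1.
print(_compute_centerline_dice_coefficient(reference, prediction))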
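
The rewritten `PanopticaResult` registers every derived score as an `Evaluation_Metric` and computes it lazily on first attribute access (or in bulk via `calculate_all`); scores whose per-instance lists were not evaluated now raise `MetricCouldNotBeComputedException` instead of returning `None`. A sketch mirroring the new unit tests, with made-up instance counts and a single IoU value:

import numpy as np
from panoptica.metrics import Metric, MetricMode
from panoptica.panoptic_result import MetricCouldNotBeComputedException, PanopticaResult
from panoptica.utils.edge_case_handling import EdgeCaseHandler

# One perfectly matched instance; only IoU was evaluated per instance.
reference = np.zeros((10, 10, 10), dtype=np.uint8)
reference[2:8, 2:8, 2:8] = 1
result = PanopticaResult(
    reference_arr=reference,
    prediction_arr=reference.copy(),
    num_ref_instances=1,
    num_pred_instances=1,
    tp=1,
    list_metrics={Metric.IOU: [1.0]},
    edge_case_handler=EdgeCaseHandler(),
)

result.calculate_all(print_errors=True)                     # derive everything derivable, report the rest
print(result.rq, result.sq, result.pq)                      # 1.0 1.0 1.0
print(result.get_list_metric(Metric.IOU, MetricMode.ALL))   # [1.0]

try:
    result.sq_dsc                                           # DSC was not in list_metrics
except MetricCouldNotBeComputedException as e:
    print(e)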