diff --git a/benchmark/modules_speedtest.py b/benchmark/modules_speedtest.py index cdfe395..6ff5555 100644 --- a/benchmark/modules_speedtest.py +++ b/benchmark/modules_speedtest.py @@ -9,6 +9,8 @@ ConnectedComponentsInstanceApproximator, NaiveThresholdMatching, SemanticPair, + UnmatchedInstancePair, + MatchedInstancePair, ) from panoptica.instance_evaluator import evaluate_matched_instance from time import perf_counter @@ -80,16 +82,21 @@ def test_input(processing_pair: SemanticPair): processing_pair.crop_data() # start1 = perf_counter() - processing_pair = instance_approximator.approximate_instances(processing_pair) + unmatched_instance_pair = instance_approximator.approximate_instances( + semantic_pair=processing_pair + ) time1 = perf_counter() - start1 # start2 = perf_counter() - processing_pair = instance_matcher.match_instances(processing_pair) + matched_instance_pair = instance_matcher.match_instances( + unmatched_instance_pair=unmatched_instance_pair + ) time2 = perf_counter() - start2 # start3 = perf_counter() - processing_pair = evaluate_matched_instance( - processing_pair, iou_threshold=iou_threshold + result = evaluate_matched_instance( + matched_instance_pair, + decision_threshold=iou_threshold, ) time3 = perf_counter() - start3 return time1, time2, time3 diff --git a/examples/example_cfos_3d.py b/examples/example_cfos_3d.py deleted file mode 100644 index d86aba2..0000000 --- a/examples/example_cfos_3d.py +++ /dev/null @@ -1,28 +0,0 @@ -from auxiliary.nifti.io import read_nifti - -from panoptica import ( - SemanticPair, - Panoptic_Evaluator, - ConnectedComponentsInstanceApproximator, - CCABackend, - NaiveThresholdMatching, -) - -pred_masks = read_nifti( - input_nifti_path="/home/florian/flow/cfos_analysis/data/ablation/2021-11-25_23-50-56_2021-10-25_19-38-31_tr_dice_bce_11/patchvolume_695_2.nii.gz" -) -ref_masks = read_nifti( - input_nifti_path="/home/florian/flow/cfos_analysis/data/reference/patchvolume_695_2/patchvolume_695_2_binary.nii.gz", -) - -sample = SemanticPair(pred_masks, ref_masks) - -evaluator = Panoptic_Evaluator( - expected_input=SemanticPair, - instance_approximator=ConnectedComponentsInstanceApproximator(), - instance_matcher=NaiveThresholdMatching(), - iou_threshold=0.5, -) - -result, debug_data = evaluator.evaluate(sample) -print(result) diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index 361702d..2717721 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -4,6 +4,7 @@ from auxiliary.turbopath import turbopath from panoptica import MatchedInstancePair, Panoptic_Evaluator +from panoptica.metrics import Metrics directory = turbopath(__file__).parent @@ -16,13 +17,16 @@ evaluator = Panoptic_Evaluator( expected_input=MatchedInstancePair, - instance_approximator=None, - instance_matcher=None, - iou_threshold=0.5, + eval_metrics=[Metrics.ASSD, Metrics.IOU], + decision_metric=Metrics.IOU, + decision_threshold=0.5, ) + + with cProfile.Profile() as pr: if __name__ == "__main__": result, debug_data = evaluator.evaluate(sample) + print(result) pr.dump_stats(directory + "/instance_example.log") diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 63b8d51..ac67b25 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -3,13 +3,13 @@ from auxiliary.nifti.io import read_nifti from auxiliary.turbopath import turbopath - from panoptica import ( ConnectedComponentsInstanceApproximator, NaiveThresholdMatching, 
Panoptic_Evaluator, SemanticPair, ) +from panoptica.metrics import Metrics directory = turbopath(__file__).parent @@ -18,12 +18,13 @@ sample = SemanticPair(pred_masks, ref_masks) + evaluator = Panoptic_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), - iou_threshold=0.5, ) + with cProfile.Profile() as pr: if __name__ == "__main__": result, debug_data = evaluator.evaluate(sample) diff --git a/panoptica/__init__.py b/panoptica/__init__.py index 98a8887..9e17409 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -3,9 +3,9 @@ CCABackend, ) from panoptica.instance_matcher import NaiveThresholdMatching -from panoptica.evaluator import Panoptic_Evaluator -from panoptica.result import PanopticaResult -from panoptica.utils.datatypes import ( +from panoptica.panoptic_evaluator import Panoptic_Evaluator +from panoptica.panoptic_result import PanopticaResult +from panoptica.utils.processing_pair import ( SemanticPair, UnmatchedInstancePair, MatchedInstancePair, diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index 50fdcd0..5b03723 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -1,8 +1,10 @@ +from multiprocessing import Pool + import numpy as np -from panoptica.metrics import _compute_instance_iou + +from panoptica.metrics import _compute_instance_iou, _MatchingMetric from panoptica.utils.constants import CCABackend from panoptica.utils.numpy_utils import _get_bbox_nd -from multiprocessing import Pool def _calc_overlapping_labels( @@ -35,6 +37,42 @@ def _calc_overlapping_labels( ] +def _calc_matching_metric_of_overlapping_labels( + prediction_arr: np.ndarray, + reference_arr: np.ndarray, + ref_labels: tuple[int, ...], + matching_metric: _MatchingMetric, +) -> list[tuple[float, tuple[int, int]]]: + """Calculates the MatchingMetric for all overlapping labels (fast!) + + Args: + prediction_arr (np.ndarray): Numpy array containing the prediction labels. + reference_arr (np.ndarray): Numpy array containing the reference labels. + ref_labels (list[int]): List of unique reference labels. 
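For orientation, a minimal standalone sketch (not part of this patch) of what the new _calc_matching_metric_of_overlapping_labels helper produces: one score per overlapping (ref_label, pred_label) pair, best pairs first. Plain IoU stands in for the configured _MatchingMetric, and the multiprocessing.Pool parallelization is omitted.

import numpy as np

def overlapping_pair_scores(prediction_arr: np.ndarray, reference_arr: np.ndarray) -> list[tuple[float, tuple[int, int]]]:
    """IoU for every (ref_label, pred_label) pair that overlaps, best pairs first."""
    scores = []
    for ref_label in np.unique(reference_arr[reference_arr > 0]):
        ref_mask = reference_arr == ref_label
        for pred_label in np.unique(prediction_arr[ref_mask]):
            if pred_label == 0:
                continue  # background is never a candidate match
            pred_mask = prediction_arr == pred_label
            iou = np.logical_and(ref_mask, pred_mask).sum() / np.logical_or(ref_mask, pred_mask).sum()
            scores.append((float(iou), (int(ref_label), int(pred_label))))
    # IoU is an increasing metric, so "best first" means descending order
    return sorted(scores, key=lambda x: x[0], reverse=True)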
+ + Returns: + list[tuple[float, tuple[int, int]]]: List of pairs in style: (iou, (ref_label, pred_label)) + """ + instance_pairs = [ + (reference_arr == i[0], prediction_arr == i[1], i[0], i[1]) + for i in _calc_overlapping_labels( + prediction_arr=prediction_arr, + reference_arr=reference_arr, + ref_labels=ref_labels, + ) + ] + with Pool() as pool: + mm_values = pool.starmap(matching_metric._metric_function, instance_pairs) + + mm_pairs = [ + (i, (instance_pairs[idx][2], instance_pairs[idx][3])) + for idx, i in enumerate(mm_values) + ] + mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=matching_metric.decreasing) + + return mm_pairs + + def _calc_iou_of_overlapping_labels( prediction_arr: np.ndarray, reference_arr: np.ndarray, diff --git a/panoptica/instance_approximator.py b/panoptica/instance_approximator.py index c24c5cd..0a905b5 100644 --- a/panoptica/instance_approximator.py +++ b/panoptica/instance_approximator.py @@ -1,5 +1,5 @@ from abc import abstractmethod, ABC -from panoptica.utils.datatypes import ( +from panoptica.utils.processing_pair import ( SemanticPair, UnmatchedInstancePair, MatchedInstancePair, diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index 8f537fa..fa742f8 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -1,19 +1,26 @@ import concurrent.futures -from panoptica.utils.datatypes import MatchedInstancePair -from panoptica.result import PanopticaResult +import gc +from multiprocessing import Pool + +import numpy as np + from panoptica.metrics import ( - _compute_iou, - _compute_dice_coefficient, - _average_symmetric_surface_distance, + Metrics, + _MatchingMetric, ) +from panoptica.panoptic_result import PanopticaResult from panoptica.timing import measure_time -import numpy as np -import gc -from multiprocessing import Pool +from panoptica.utils import EdgeCaseHandler +from panoptica.utils.processing_pair import MatchedInstancePair def evaluate_matched_instance( - matched_instance_pair: MatchedInstancePair, iou_threshold: float, **kwargs + matched_instance_pair: MatchedInstancePair, + eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD], + decision_metric: _MatchingMetric | None = Metrics.IOU, + decision_threshold: float | None = None, + edge_case_handler: EdgeCaseHandler | None = None, + **kwargs, ) -> PanopticaResult: """ Map instance labels based on the provided labelmap and create a MatchedInstancePair. 
@@ -30,66 +37,49 @@ def evaluate_matched_instance( >>> labelmap = [([1, 2], [3, 4]), ([5], [6])] >>> result = map_instance_labels(unmatched_instance_pair, labelmap) """ + if edge_case_handler is None: + edge_case_handler = EdgeCaseHandler() + if decision_metric is not None: + assert decision_metric.name in [ + v.name for v in eval_metrics + ], "decision metric not contained in eval_metrics" + assert decision_threshold is not None, "decision metric set but no threshold" # Initialize variables for True Positives (tp) - tp, dice_list, iou_list, assd_list = 0, [], [], [] + tp = len(matched_instance_pair.matched_instances) + score_dict: dict[str | _MatchingMetric, list[float]] = { + m.name: [] for m in eval_metrics + } reference_arr, prediction_arr = ( matched_instance_pair._reference_arr, matched_instance_pair._prediction_arr, ) - ref_labels = matched_instance_pair._ref_labels - - # instance_pairs = _calc_overlapping_labels( - # prediction_arr=prediction_arr, - # reference_arr=reference_arr, - # ref_labels=ref_labels, - # ) - # instance_pairs = [(ra, pa, rl, iou_threshold) for (ra, pa, rl, pl) in instance_pairs] + ref_matched_labels = matched_instance_pair.matched_instances instance_pairs = [ - (reference_arr, prediction_arr, ref_idx, iou_threshold) - for ref_idx in ref_labels + (reference_arr, prediction_arr, ref_idx, eval_metrics) + for ref_idx in ref_matched_labels ] with Pool() as pool: - metric_values = pool.starmap(_evaluate_instance, instance_pairs) - - for tp_i, dice_i, iou_i, assd_i in metric_values: - tp += tp_i - if dice_i is not None and iou_i is not None and assd_i is not None: - dice_list.append(dice_i) - iou_list.append(iou_i) - assd_list.append(assd_i) - - # Use concurrent.futures.ThreadPoolExecutor for parallelization - # with concurrent.futures.ThreadPoolExecutor() as executor: - # futures = [ - # executor.submit( - # _evaluate_instance, - # reference_arr, - # prediction_arr, - # ref_idx, - # iou_threshold, - # ) - # for ref_idx in ref_labels - # ] - # - # for future in concurrent.futures.as_completed(futures): - # tp_i, dice_i, iou_i, assd_i = future.result() - # tp += tp_i - # if dice_i is not None and iou_i is not None and assd_i is not None: - # dice_list.append(dice_i) - # iou_list.append(iou_i) - # assd_list.append(assd_i) - # del future - # gc.collect() + metric_dicts = pool.starmap(_evaluate_instance, instance_pairs) + + for metric_dict in metric_dicts: + if decision_metric is None or ( + decision_threshold is not None + and decision_metric.score_beats_threshold( + metric_dict[decision_metric.name], decision_threshold + ) + ): + for k, v in metric_dict.items(): + score_dict[k].append(v) + # Create and return the PanopticaResult object with computed metrics return PanopticaResult( num_ref_instances=matched_instance_pair.n_reference_instance, num_pred_instances=matched_instance_pair.n_prediction_instance, tp=tp, - dice_list=dice_list, - iou_list=iou_list, - assd_list=assd_list, + list_metrics=score_dict, + edge_case_handler=edge_case_handler, ) @@ -97,8 +87,8 @@ def _evaluate_instance( reference_arr: np.ndarray, prediction_arr: np.ndarray, ref_idx: int, - iou_threshold: float, -) -> tuple[int, float | None, float | None, float | None]: + eval_metrics: list[_MatchingMetric], +) -> dict[str, float]: """ Evaluate a single instance. 
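A simplified, standalone sketch (not the patch code itself) of the bookkeeping the reworked evaluate_matched_instance performs: each matched label yields a dict of metric scores, and the whole dict is only kept when the decision metric beats the decision threshold. The sketch assumes a higher-is-better decision metric such as IoU; the real code delegates the direction handling to _MatchingMetric.score_beats_threshold.

def collect_scores(metric_dicts, metric_names, decision_metric=None, decision_threshold=None):
    score_dict = {name: [] for name in metric_names}
    for metric_dict in metric_dicts:
        keep = decision_metric is None or (
            decision_threshold is not None
            and metric_dict[decision_metric] >= decision_threshold
        )
        if keep:
            for name, value in metric_dict.items():
                score_dict[name].append(value)
    return score_dict

# two matched instances, IoU as decision metric with threshold 0.5:
print(collect_scores(
    [{"IOU": 0.75, "DSC": 0.86}, {"IOU": 0.30, "DSC": 0.46}],
    ["IOU", "DSC"],
    decision_metric="IOU",
    decision_threshold=0.5,
))  # {'IOU': [0.75], 'DSC': [0.86]}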
@@ -113,27 +103,12 @@ def _evaluate_instance( """ ref_arr = reference_arr == ref_idx pred_arr = prediction_arr == ref_idx + result: dict[str, float] = {} if ref_arr.sum() == 0 or pred_arr.sum() == 0: - tp = 0 - dice = None - iou = None - assd = None + return result else: - iou: float | None = _compute_iou( - reference=ref_arr, - prediction=pred_arr, - ) - if iou > iou_threshold: - tp = 1 - dice = _compute_dice_coefficient( - reference=ref_arr, - prediction=pred_arr, - ) - assd = _average_symmetric_surface_distance(pred_arr, ref_arr) - else: - tp = 0 - dice = None - iou = None - assd = None - - return tp, dice, iou, assd + for metric in eval_metrics: + value = metric._metric_function(ref_arr, pred_arr) + result[metric.name] = value + + return result diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index e0114ac..339609f 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -3,17 +3,15 @@ import numpy as np from panoptica._functionals import ( - _calc_iou_matrix, + _calc_matching_metric_of_overlapping_labels, _map_labels, - _calc_iou_of_overlapping_labels, ) -from panoptica.utils.datatypes import ( +from panoptica.metrics import Metrics, _MatchingMetric +from panoptica.utils.processing_pair import ( InstanceLabelMap, MatchedInstancePair, UnmatchedInstancePair, ) -from panoptica.timing import measure_time -from scipy.optimize import linear_sum_assignment class InstanceMatchingAlgorithm(ABC): @@ -43,7 +41,9 @@ class InstanceMatchingAlgorithm(ABC): @abstractmethod def _match_instances( - self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs + self, + unmatched_instance_pair: UnmatchedInstancePair, + **kwargs, ) -> InstanceLabelMap: """ Abstract method to be implemented by subclasses for instance matching. @@ -58,7 +58,9 @@ def _match_instances( pass def match_instances( - self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs + self, + unmatched_instance_pair: UnmatchedInstancePair, + **kwargs, ) -> MatchedInstancePair: """ Perform instance matching on the given UnmatchedInstancePair. @@ -70,7 +72,10 @@ def match_instances( Returns: MatchedInstancePair: The result of the instance matching. """ - instance_labelmap = self._match_instances(unmatched_instance_pair, **kwargs) + instance_labelmap = self._match_instances( + unmatched_instance_pair, + **kwargs, + ) # print("instance_labelmap:", instance_labelmap) return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap) @@ -126,7 +131,7 @@ def map_instance_labels( missed_prediction_labels=missed_pred_labels, n_prediction_instance=processing_pair.n_prediction_instance, n_reference_instance=processing_pair.n_reference_instance, - n_matched_instances=n_matched_instances, + matched_instances=ref_matched_labels, ) return matched_instance_pair @@ -154,7 +159,10 @@ class NaiveThresholdMatching(InstanceMatchingAlgorithm): """ def __init__( - self, iou_threshold: float = 0.5, allow_many_to_one: bool = False + self, + matching_metric: _MatchingMetric = Metrics.IOU, + matching_threshold: float = 0.5, + allow_many_to_one: bool = False, ) -> None: """ Initialize the NaiveOneToOneMatching instance. @@ -165,11 +173,14 @@ def __init__( Raises: AssertionError: If the specified IoU threshold is not within the valid range. 
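Hypothetical usage sketch of the new NaiveThresholdMatching constructor arguments shown above (the instance-labelled arrays are placeholders): the fixed iou_threshold is replaced by a matching metric plus a matching threshold.

from panoptica import NaiveThresholdMatching, UnmatchedInstancePair
from panoptica.metrics import Metrics

matcher = NaiveThresholdMatching(matching_metric=Metrics.IOU, matching_threshold=0.5)
# unmatched_pair = UnmatchedInstancePair(prediction_arr, reference_arr)  # instance-labelled arrays
# matched_pair = matcher.match_instances(unmatched_pair)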
""" - self.iou_threshold = iou_threshold self.allow_many_to_one = allow_many_to_one + self.matching_metric = matching_metric + self.matching_threshold = matching_threshold def _match_instances( - self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs + self, + unmatched_instance_pair: UnmatchedInstancePair, + **kwargs, ) -> InstanceLabelMap: """ Perform one-to-one instance matching based on IoU values. @@ -182,7 +193,6 @@ def _match_instances( Instance_Label_Map: The result of the instance matching. """ ref_labels = unmatched_instance_pair._ref_labels - pred_labels = unmatched_instance_pair._pred_labels # Initialize variables for True Positives (tp) and False Positives (fp) labelmap = InstanceLabelMap() @@ -191,18 +201,20 @@ def _match_instances( unmatched_instance_pair._prediction_arr, unmatched_instance_pair._reference_arr, ) - iou_pairs = _calc_iou_of_overlapping_labels( - pred_arr, ref_arr, ref_labels, pred_labels + mm_pairs = _calc_matching_metric_of_overlapping_labels( + pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric ) # Loop through matched instances to compute PQ components - for iou, (ref_label, pred_label) in iou_pairs: + for matching_score, (ref_label, pred_label) in mm_pairs: if ( labelmap.contains_or(pred_label, ref_label) and not self.allow_many_to_one ): continue # -> doesnt make speed difference - if iou >= 0.5: + if self.matching_metric.score_beats_threshold( + matching_score, self.matching_threshold + ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) # map label ref_idx to pred_idx @@ -211,74 +223,27 @@ def _match_instances( class MaximizeMergeMatching(InstanceMatchingAlgorithm): """ - Instance matching algorithm that performs many-to-one matching based on IoU values. Will merge if combined instance IOU is greater than individual one + Instance matching algorithm that performs many-to-one matching based on metric. Will merge if combined instance metric is greater than individual one - Attributes: - iou_threshold (float): The IoU threshold for matching instances. Methods: - __init__(self, iou_threshold: float = 0.5) -> None: - Initialize the NaiveOneToOneMatching instance. _match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs) -> Instance_Label_Map: Perform one-to-one instance matching based on IoU values. Raises: AssertionError: If the specified IoU threshold is not within the valid range. - - Example: - >>> matcher = NaiveOneToOneMatching(iou_threshold=0.6) - >>> unmatched_instance_pair = UnmatchedInstancePair(...) - >>> result = matcher.match_instances(unmatched_instance_pair) """ - def __init__( - self, iou_threshold: float = 0.5, allow_many_to_one: bool = False - ) -> None: - """ - Initialize the NaiveOneToOneMatching instance. + pass - Args: - iou_threshold (float, optional): The IoU threshold for matching instances. Defaults to 0.5. - Raises: - AssertionError: If the specified IoU threshold is not within the valid range. - """ - self.iou_threshold = iou_threshold - self.allow_many_to_one = allow_many_to_one +class MatchUntilConvergenceMatching(InstanceMatchingAlgorithm): + # Match like the naive matcher (so each to their best reference) and then again and again until no overlapping labels are left + pass - def _match_instances( - self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs - ) -> InstanceLabelMap: - """ - Perform one-to-one instance matching based on IoU values. 
- Args: - unmatched_instance_pair (UnmatchedInstancePair): The unmatched instance pair to be matched. - **kwargs: Additional keyword arguments. - - Returns: - Instance_Label_Map: The result of the instance matching. - """ - ref_labels = unmatched_instance_pair._ref_labels - pred_labels = unmatched_instance_pair._pred_labels - - # Initialize variables for True Positives (tp) and False Positives (fp) - labelmap = InstanceLabelMap() - - pred_arr, ref_arr = ( - unmatched_instance_pair._prediction_arr, - unmatched_instance_pair._reference_arr, - ) - iou_pairs = _calc_iou_of_overlapping_labels( - pred_arr, ref_arr, ref_labels, pred_labels - ) - - # Loop through matched instances to compute PQ components - for iou, (ref_label, pred_label) in iou_pairs: - if labelmap.contains_or(None, ref_label): - continue # -> doesnt make speed difference - if iou >= 0.5: - # Match found, increment true positive count and collect IoU and Dice values - labelmap.add_labelmap_entry(pred_label, ref_label) - # map label ref_idx to pred_idx - return labelmap +class DesperateMarriageMatching(InstanceMatchingAlgorithm): + # Match as many predictions to references as possible, doesn't need threshold + # Option for many-to-one or one-to-one + # https://github.com/koseii2122/The-Stable-Matching-Algorithm + pass diff --git a/panoptica/metrics/__init__.py b/panoptica/metrics/__init__.py index 1f8ef14..ef93513 100644 --- a/panoptica/metrics/__init__.py +++ b/panoptica/metrics/__init__.py @@ -1,9 +1,16 @@ from panoptica.metrics.assd import ( - _average_symmetric_surface_distance, _average_surface_distance, + _average_symmetric_surface_distance, ) from panoptica.metrics.dice import ( _compute_dice_coefficient, _compute_instance_volumetric_dice, ) from panoptica.metrics.iou import _compute_instance_iou, _compute_iou +from panoptica.metrics.metrics import ( + EvalMetric, + ListMetric, + MetricDict, + Metrics, + _MatchingMetric, +) diff --git a/panoptica/metrics/assd.py b/panoptica/metrics/assd.py index b4f7925..a08bba2 100644 --- a/panoptica/metrics/assd.py +++ b/panoptica/metrics/assd.py @@ -5,41 +5,46 @@ def _average_symmetric_surface_distance( - result, reference, + prediction, voxelspacing=None, connectivity=1, + *args, ) -> float: assd = np.mean( ( - _average_surface_distance(result, reference, voxelspacing, connectivity), - _average_surface_distance(reference, result, voxelspacing, connectivity), + _average_surface_distance( + prediction, reference, voxelspacing, connectivity + ), + _average_surface_distance( + reference, prediction, voxelspacing, connectivity + ), ) ) return float(assd) -def _average_surface_distance(result, reference, voxelspacing=None, connectivity=1): - sds = __surface_distances(result, reference, voxelspacing, connectivity) +def _average_surface_distance(reference, prediction, voxelspacing=None, connectivity=1): + sds = __surface_distances(reference, prediction, voxelspacing, connectivity) asd = sds.mean() return asd -def __surface_distances(result, reference, voxelspacing=None, connectivity=1): +def __surface_distances(reference, prediction, voxelspacing=None, connectivity=1): """ The distances between the surface voxel of binary objects in result and their nearest partner surface voxel of a binary object in reference. 
""" - result = np.atleast_1d(result.astype(bool)) + prediction = np.atleast_1d(prediction.astype(bool)) reference = np.atleast_1d(reference.astype(bool)) if voxelspacing is not None: - voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim) + voxelspacing = _ni_support._normalize_sequence(voxelspacing, prediction.ndim) voxelspacing = np.asarray(voxelspacing, dtype=np.float64) if not voxelspacing.flags.contiguous: voxelspacing = voxelspacing.copy() # binary structure - footprint = generate_binary_structure(result.ndim, connectivity) + footprint = generate_binary_structure(prediction.ndim, connectivity) # test for emptiness # if 0 == np.count_nonzero(result): @@ -48,11 +53,11 @@ def __surface_distances(result, reference, voxelspacing=None, connectivity=1): # raise RuntimeError("The second supplied array does not contain any binary object.") # extract only 1-pixel border line of objects - result_border = result ^ binary_erosion(result, structure=footprint, iterations=1) + result_border = prediction ^ binary_erosion( + prediction, structure=footprint, iterations=1 + ) reference_border = reference ^ binary_erosion( - reference, - structure=footprint, - iterations=1, + reference, structure=footprint, iterations=1 ) # compute average surface distance diff --git a/panoptica/metrics/dice.py b/panoptica/metrics/dice.py index 20c6c6e..55e3b3a 100644 --- a/panoptica/metrics/dice.py +++ b/panoptica/metrics/dice.py @@ -36,6 +36,7 @@ def _compute_instance_volumetric_dice( def _compute_dice_coefficient( reference: np.ndarray, prediction: np.ndarray, + *args, ) -> float: """ Compute the Dice coefficient between two binary masks. diff --git a/panoptica/metrics/iou.py b/panoptica/metrics/iou.py index 7fcd76d..7cfbd81 100644 --- a/panoptica/metrics/iou.py +++ b/panoptica/metrics/iou.py @@ -21,19 +21,14 @@ def _compute_instance_iou( """ ref_instance_mask = reference_arr == ref_instance_idx pred_instance_mask = prediction_arr == pred_instance_idx - intersection = np.logical_and(ref_instance_mask, pred_instance_mask) - union = np.logical_or(ref_instance_mask, pred_instance_mask) + return _compute_iou(ref_instance_mask, pred_instance_mask) - union_sum = np.sum(union) - # Handle division by zero - if union_sum == 0: - return 0.0 - iou = np.sum(intersection) / union_sum - return iou - - -def _compute_iou(reference: np.ndarray, prediction: np.ndarray) -> float: +def _compute_iou( + reference_arr: np.ndarray, + prediction_arr: np.ndarray, + *args, +) -> float: """ Compute Intersection over Union (IoU) between two masks. @@ -45,8 +40,8 @@ def _compute_iou(reference: np.ndarray, prediction: np.ndarray) -> float: float: IoU between the two masks. A value between 0 and 1, where higher values indicate better overlap and similarity between masks. 
""" - intersection = np.logical_and(reference, prediction) - union = np.logical_or(reference, prediction) + intersection = np.logical_and(reference_arr, prediction_arr) + union = np.logical_or(reference_arr, prediction_arr) union_sum = np.sum(union) diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py new file mode 100644 index 0000000..f282907 --- /dev/null +++ b/panoptica/metrics/metrics.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass +from enum import EnumMeta +from typing import Callable + +import numpy as np + +from panoptica.metrics import ( + _average_symmetric_surface_distance, + _compute_dice_coefficient, + _compute_iou, +) +from panoptica.utils.constants import Enum, _Enum_Compare, auto + + +@dataclass +class _MatchingMetric: + name: str + decreasing: bool + _metric_function: Callable + + def __call__( + self, + reference_arr: np.ndarray, + prediction_arr: np.ndarray, + ref_instance_idx: int | None = None, + pred_instance_idx: int | None = None, + *args, + **kwargs, + ): + if ref_instance_idx is not None and pred_instance_idx is not None: + reference_arr = reference_arr.copy() == ref_instance_idx + prediction_arr = prediction_arr.copy() == pred_instance_idx + return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) + + def __eq__(self, __value: object) -> bool: + if isinstance(__value, _MatchingMetric): + return self.name == __value.name + elif isinstance(__value, str): + return self.name == __value + else: + return False + + def __str__(self) -> str: + return f"{type(self).__name__}.{self.name}" + + def __repr__(self) -> str: + return str(self) + + @property + def increasing(self): + return not self.decreasing + + def score_beats_threshold( + self, matching_score: float, matching_threshold: float + ) -> bool: + return (self.increasing and matching_score >= matching_threshold) or ( + self.decreasing and matching_score <= matching_threshold + ) + + +# class _EnumMeta(EnumMeta): +# def __getattribute__(cls, name) -> MatchingMetric: +# value = super().__getattribute__(name) +# if isinstance(value, cls): +# value = value.value +# return value + + +# Important metrics that must be calculated in the evaluator, can be set for thresholding in matching and evaluation +# TODO make abstract class for metric, make enum with references to these classes for referenciation and user exposure +class Metrics: + # TODO make this with meta above, and then it can function without the double name, right? 
+ DSC = _MatchingMetric("DSC", False, _compute_dice_coefficient) + IOU = _MatchingMetric("IOU", False, _compute_iou) + ASSD = _MatchingMetric("ASSD", True, _average_symmetric_surface_distance) + # These are all lists of values + + +class ListMetric(_Enum_Compare): + DSC = Metrics.DSC.name + IOU = Metrics.IOU.name + ASSD = Metrics.ASSD.name + + def __hash__(self) -> int: + return abs(hash(self.value)) % (10**8) + + +# Metrics that are derived from list metrics and can be calculated later +# TODO map result properties to this enum +class EvalMetric(_Enum_Compare): + TP = auto() + FP = auto() + FN = auto() + RQ = auto() + DQ_DSC = auto() + PQ_DSC = auto() + ASSD = auto() + PQ_ASSD = auto() + + +MetricDict = dict[ListMetric | EvalMetric | str, float | list[float]] + + +list_of_applicable_std_metrics: list[EvalMetric] = [ + EvalMetric.RQ, + EvalMetric.DQ_DSC, + EvalMetric.PQ_ASSD, + EvalMetric.ASSD, + EvalMetric.PQ_ASSD, +] + + +if __name__ == "__main__": + print(Metrics.DSC) + # print(MatchingMetric.DSC.name) + + print(Metrics.DSC == Metrics.DSC) + print(Metrics.DSC == "DSC") + print(Metrics.DSC.name == "DSC") + # + print(Metrics.DSC == Metrics.IOU) + print(Metrics.DSC == "IOU") diff --git a/panoptica/evaluator.py b/panoptica/panoptic_evaluator.py similarity index 71% rename from panoptica/evaluator.py rename to panoptica/panoptic_evaluator.py index 0325898..ae904d4 100644 --- a/panoptica/evaluator.py +++ b/panoptica/panoptic_evaluator.py @@ -1,11 +1,15 @@ +from abc import ABC, abstractmethod +from time import perf_counter from typing import Type from panoptica.instance_approximator import InstanceApproximator from panoptica.instance_evaluator import evaluate_matched_instance from panoptica.instance_matcher import InstanceMatchingAlgorithm -from panoptica.result import PanopticaResult +from panoptica.metrics import Metrics, _MatchingMetric +from panoptica.panoptic_result import PanopticaResult from panoptica.timing import measure_time -from panoptica.utils.datatypes import ( +from panoptica.utils import EdgeCaseHandler +from panoptica.utils.processing_pair import ( MatchedInstancePair, SemanticPair, UnmatchedInstancePair, @@ -22,9 +26,12 @@ def __init__( | Type[MatchedInstancePair] = MatchedInstancePair, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, + edge_case_handler: EdgeCaseHandler | None = None, + eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD], + decision_metric: _MatchingMetric | None = None, + decision_threshold: float | None = None, log_times: bool = False, verbose: bool = False, - iou_threshold: float = 0.5, ) -> None: """Creates a Panoptic_Evaluator, that saves some parameters to be used for all subsequent evaluations @@ -35,9 +42,21 @@ def __init__( iou_threshold (float, optional): Iou Threshold for evaluation. Defaults to 0.5. 
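Hypothetical end-to-end configuration matching the reworked constructor above and the updated examples (the input arrays are placeholders): the single iou_threshold argument is replaced by eval_metrics plus an optional decision metric and threshold, and an EdgeCaseHandler can be injected.

from panoptica import (
    ConnectedComponentsInstanceApproximator,
    NaiveThresholdMatching,
    Panoptic_Evaluator,
    SemanticPair,
)
from panoptica.metrics import Metrics

evaluator = Panoptic_Evaluator(
    expected_input=SemanticPair,
    instance_approximator=ConnectedComponentsInstanceApproximator(),
    instance_matcher=NaiveThresholdMatching(),
    eval_metrics=[Metrics.DSC, Metrics.IOU, Metrics.ASSD],
    decision_metric=Metrics.IOU,
    decision_threshold=0.5,
)
# sample = SemanticPair(pred_masks, ref_masks)  # binary prediction / reference arrays
# result, debug_data = evaluator.evaluate(sample)
# print(result)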
""" self.__expected_input = expected_input + # self.__instance_approximator = instance_approximator self.__instance_matcher = instance_matcher - self.__iou_threshold = iou_threshold + self.__eval_metrics = eval_metrics + self.__decision_metric = decision_metric + self.__decision_threshold = decision_threshold + + self.__edge_case_handler = ( + edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + ) + if self.__decision_metric is not None: + assert ( + self.__decision_threshold is not None + ), "decision metric set but no decision threshold for it" + # self.__log_times = log_times self.__verbose = verbose @@ -55,9 +74,12 @@ def evaluate( ), f"input not of expected type {self.__expected_input}" return panoptic_evaluate( processing_pair=processing_pair, + edge_case_handler=self.__edge_case_handler, instance_approximator=self.__instance_approximator, instance_matcher=self.__instance_matcher, - iou_threshold=self.__iou_threshold, + eval_metrics=self.__eval_metrics, + decision_metric=self.__decision_metric, + decision_threshold=self.__decision_threshold, log_times=self.__log_times, verbose=self.__verbose, ) @@ -70,9 +92,12 @@ def panoptic_evaluate( | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, + eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD], + decision_metric: _MatchingMetric | None = None, + decision_threshold: float | None = None, + edge_case_handler: EdgeCaseHandler | None = None, log_times: bool = False, verbose: bool = False, - iou_threshold: float = 0.5, **kwargs, ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]: """ @@ -103,6 +128,9 @@ def panoptic_evaluate( (PanopticaResult(...), {'UnmatchedInstanceMap': _ProcessingPair(...), 'MatchedInstanceMap': _ProcessingPair(...)}) """ print("Panoptic: Start Evaluation") + if edge_case_handler is None: + # use default edgecase handler + edge_case_handler = EdgeCaseHandler() debug_data: dict[str, _ProcessingPair] = {} # First Phase: Instance Approximation if isinstance(processing_pair, PanopticaResult): @@ -118,30 +146,49 @@ def panoptic_evaluate( ), "Got SemanticPair but not InstanceApproximator" print("-- Got SemanticPair, will approximate instances") processing_pair = instance_approximator.approximate_instances(processing_pair) + start = perf_counter() + processing_pair = instance_approximator.approximate_instances(processing_pair) + if log_times: + print(f"-- Approximation took {perf_counter() - start} seconds") debug_data["UnmatchedInstanceMap"] = processing_pair.copy() # Second Phase: Instance Matching if isinstance(processing_pair, UnmatchedInstancePair): - processing_pair = _handle_zero_instances_cases(processing_pair) + processing_pair = _handle_zero_instances_cases( + processing_pair, edge_case_handler=edge_case_handler + ) if isinstance(processing_pair, UnmatchedInstancePair): print("-- Got UnmatchedInstancePair, will match instances") assert ( instance_matcher is not None ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" - processing_pair = instance_matcher.match_instances(processing_pair) + start = perf_counter() + processing_pair = instance_matcher.match_instances( + processing_pair, + ) + if log_times: + print(f"-- Matching took {perf_counter() - start} seconds") debug_data["MatchedInstanceMap"] = processing_pair.copy() # Third Phase: Instance Evaluation if isinstance(processing_pair, MatchedInstancePair): - processing_pair = _handle_zero_instances_cases(processing_pair) 
+ processing_pair = _handle_zero_instances_cases( + processing_pair, edge_case_handler=edge_case_handler + ) if isinstance(processing_pair, MatchedInstancePair): print("-- Got MatchedInstancePair, will evaluate instances") processing_pair = evaluate_matched_instance( - processing_pair, iou_threshold=iou_threshold + processing_pair, + eval_metrics=eval_metrics, + decision_metric=decision_metric, + decision_threshold=decision_threshold, + edge_case_handler=edge_case_handler, ) + if log_times: + print(f"-- Instance Evaluation took {perf_counter() - start} seconds") if isinstance(processing_pair, PanopticaResult): return processing_pair, debug_data @@ -151,6 +198,7 @@ def panoptic_evaluate( def _handle_zero_instances_cases( processing_pair: UnmatchedInstancePair | MatchedInstancePair, + edge_case_handler: EdgeCaseHandler, ) -> UnmatchedInstancePair | MatchedInstancePair | PanopticaResult: """ Handle edge cases when comparing reference and prediction masks. @@ -164,6 +212,7 @@ def _handle_zero_instances_cases( """ n_reference_instance = processing_pair.n_reference_instance n_prediction_instance = processing_pair.n_prediction_instance + # Handle cases where either the reference or the prediction is empty if n_prediction_instance == 0 and n_reference_instance == 0: # Both references and predictions are empty, perfect match @@ -171,9 +220,8 @@ def _handle_zero_instances_cases( num_ref_instances=0, num_pred_instances=0, tp=0, - dice_list=[], - iou_list=[], - assd_list=[], + list_metrics={}, + edge_case_handler=edge_case_handler, ) if n_reference_instance == 0: # All references are missing, only false positives @@ -181,9 +229,8 @@ def _handle_zero_instances_cases( num_ref_instances=0, num_pred_instances=n_prediction_instance, tp=0, - dice_list=[], - iou_list=[], - assd_list=[], + list_metrics={}, + edge_case_handler=edge_case_handler, ) if n_prediction_instance == 0: # All predictions are missing, only false negatives @@ -191,8 +238,7 @@ def _handle_zero_instances_cases( num_ref_instances=n_reference_instance, num_pred_instances=0, tp=0, - dice_list=[], - iou_list=[], - assd_list=[], + list_metrics={}, + edge_case_handler=edge_case_handler, ) return processing_pair diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py new file mode 100644 index 0000000..f2b38cc --- /dev/null +++ b/panoptica/panoptic_result.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +from typing import Any, List + +import numpy as np + +from panoptica.metrics import EvalMetric, ListMetric, MetricDict, _MatchingMetric +from panoptica.utils import EdgeCaseHandler + + +class PanopticaResult: + """ + Represents the result of the Panoptic Quality (PQ) computation. + + Attributes: + num_ref_instances (int): Number of reference instances. + num_pred_instances (int): Number of predicted instances. + tp (int): Number of correctly matched instances (True Positives). + fp (int): Number of extra predicted instances (False Positives). + """ + + def __init__( + self, + num_ref_instances: int, + num_pred_instances: int, + tp: int, + list_metrics: dict[_MatchingMetric | str, list[float]], + edge_case_handler: EdgeCaseHandler, + ): + """ + Initialize a PanopticaResult object. + + Args: + num_ref_instances (int): Number of reference instances. + num_pred_instances (int): Number of predicted instances. + tp (int): Number of correctly matched instances (True Positives). 
+ list_metrics: dict[MatchingMetric | str, list[float]]: TBD + edge_case_handler: EdgeCaseHandler: TBD + """ + self._tp = tp + self.edge_case_handler = edge_case_handler + self.metric_dict: MetricDict = {} + for k, v in list_metrics.items(): + if isinstance(k, _MatchingMetric): + k = k.name + self.metric_dict[k] = v + + # for k in ListMetric: + # if k.name not in self.metric_dict: + # self.metric_dict[k.name] = [] + self._num_ref_instances = num_ref_instances + self._num_pred_instances = num_pred_instances + + # TODO instead of all the properties, make a generic function inputting metric and std or not, + # and returns it if contained in dictionary, + # otherwise calls function to calculates, saves it and return + + def __str__(self): + text = ( + f"Number of instances in prediction: {self.num_pred_instances}\n" + f"Number of instances in reference: {self.num_ref_instances}\n" + f"True Positives (tp): {self.tp}\n" + f"False Positives (fp): {self.fp}\n" + f"False Negatives (fn): {self.fn}\n" + f"Recognition Quality / F1 Score (RQ): {self.rq}\n" + ) + + if ListMetric.IOU.name in self.metric_dict: + text += f"Segmentation Quality (SQ): {self.sq} ± {self.sq_sd}\n" + text += f"Panoptic Quality (PQ): {self.pq}\n" + + if ListMetric.DSC.name in self.metric_dict: + text += f"DSC-based Segmentation Quality (DQ_DSC): {self.sq_dsc} ± {self.sq_dsc_sd}\n" + text += f"DSC-based Panoptic Quality (PQ_DSC): {self.pq_dsc}\n" + + if ListMetric.ASSD.name in self.metric_dict: + text += f"Average symmetric surface distance (ASSD): {self.sq_assd} ± {self.sq_assd_sd}\n" + text += f"ASSD-based Panoptic Quality (PQ_ASSD): {self.pq_assd}" + return text + + def to_dict(self): + eval_dict = { + "num_pred_instances": self.num_pred_instances, + "num_ref_instances": self.num_ref_instances, + "tp": self.tp, + "fp": self.fp, + "fn": self.fn, + "rq": self.rq, + } + + if ListMetric.IOU.name in self.metric_dict: + eval_dict["sq"] = self.sq + eval_dict["sq_sd"] = self.sq_sd + eval_dict["pq"] = self.pq + + if ListMetric.DSC.name in self.metric_dict: + eval_dict["sq_dsc"] = self.sq_dsc + eval_dict["sq_dsc_sd"] = self.sq_dsc_sd + eval_dict["pq_dsc"] = self.pq_dsc + + if ListMetric.ASSD.name in self.metric_dict: + eval_dict["sq_assd"] = self.sq_assd + eval_dict["sq_assd_sd"] = self.sq_assd_sd + eval_dict["pq_assd"] = self.pq_assd + return eval_dict + + @property + def num_ref_instances(self) -> int: + """ + Get the number of reference instances. + + Returns: + int: Number of reference instances. + """ + return self._num_ref_instances + + @property + def num_pred_instances(self) -> int: + """ + Get the number of predicted instances. + + Returns: + int: Number of predicted instances. + """ + return self._num_pred_instances + + @property + def tp(self) -> int: + """ + Calculate the number of True Positives (TP). + + Returns: + int: Number of True Positives. + """ + return self._tp + + @property + def fp(self) -> int: + """ + Calculate the number of False Positives (FP). + + Returns: + int: Number of False Positives. + """ + return self.num_pred_instances - self.tp + + @property + def fn(self) -> int: + """ + Calculate the number of False Negatives (FN). + + Returns: + int: Number of False Negatives. + """ + return self.num_ref_instances - self.tp + + @property + def rq(self) -> float: + """ + Calculate the Recognition Quality (RQ) based on TP, FP, and FN. + + Returns: + float: Recognition Quality (RQ). 
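A quick numeric check of the RQ formula implemented just below (RQ is the instance-level F1 score):

tp, fp, fn = 6, 2, 2
rq = tp / (tp + 0.5 * fp + 0.5 * fn)
print(rq)  # 6 / 8 = 0.75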
+ """ + if self.tp == 0: + return ( + 0.0 if self.num_pred_instances + self.num_ref_instances > 0 else np.nan + ) + return self.tp / (self.tp + 0.5 * self.fp + 0.5 * self.fn) + + @property + def sq(self) -> float: + """ + Calculate the Segmentation Quality (SQ) based on IoU values. + + Returns: + float: Segmentation Quality (SQ). + """ + is_edge_case, result = self.edge_case_handler.handle_zero_tp( + metric=ListMetric.IOU, + tp=self.tp, + num_pred_instances=self.num_pred_instances, + num_ref_instances=self.num_ref_instances, + ) + if is_edge_case: + return result + if ListMetric.IOU.name not in self.metric_dict: + print("Requested SQ but no IOU metric evaluated") + return None + return np.sum(self.metric_dict[ListMetric.IOU.name]) / self.tp + + @property + def sq_sd(self) -> float: + """ + Calculate the standard deviation of Segmentation Quality (SQ) based on IoU values. + + Returns: + float: Standard deviation of Segmentation Quality (SQ). + """ + if ListMetric.IOU.name not in self.metric_dict: + print("Requested SQ_SD but no IOU metric evaluated") + return None + return ( + np.std(self.metric_dict[ListMetric.IOU.name]) + if len(self.metric_dict[ListMetric.IOU.name]) > 0 + else self.edge_case_handler.handle_empty_list_std() + ) + + @property + def pq(self) -> float: + """ + Calculate the Panoptic Quality (PQ) based on SQ and RQ. + + Returns: + float: Panoptic Quality (PQ). + """ + sq = self.sq + rq = self.rq + if sq is None or rq is None: + return None + else: + return sq * rq + + @property + def sq_dsc(self) -> float: + """ + Calculate the average Dice coefficient for matched instances. Analogue to segmentation quality but based on DSC. + + Returns: + float: Average Dice coefficient. + """ + is_edge_case, result = self.edge_case_handler.handle_zero_tp( + metric=ListMetric.DSC, + tp=self.tp, + num_pred_instances=self.num_pred_instances, + num_ref_instances=self.num_ref_instances, + ) + if is_edge_case: + return result + if ListMetric.DSC.name not in self.metric_dict: + print("Requested DSC but no DSC metric evaluated") + return None + return np.sum(self.metric_dict[ListMetric.DSC.name]) / self.tp + + @property + def sq_dsc_sd(self) -> float: + """ + Calculate the standard deviation of average Dice coefficient for matched instances. Analogue to segmentation quality but based on DSC. + + Returns: + float: Standard deviation of Average Dice coefficient. + """ + if ListMetric.DSC.name not in self.metric_dict: + print("Requested DSC_SD but no DSC metric evaluated") + return None + return ( + np.std(self.metric_dict[ListMetric.DSC.name]) + if len(self.metric_dict[ListMetric.DSC.name]) > 0 + else self.edge_case_handler.handle_empty_list_std() + ) + + @property + def pq_dsc(self) -> float: + """ + Calculate the Panoptic Quality (PQ) based on DSC-based SQ and RQ. + + Returns: + float: Panoptic Quality (PQ). + """ + sq = self.sq_dsc + rq = self.rq + if sq is None or rq is None: + return None + else: + return sq * rq + + @property + def sq_assd(self) -> float: + """ + Calculate the average average symmetric surface distance (ASSD) for matched instances. Analogue to segmentation quality but based on ASSD. + + Returns: + float: average symmetric surface distance. 
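For reference, the SQ/RQ/PQ relation used by these properties, with the numbers from the simple unit test at the end of this diff (one matched instance with IoU 0.75, no false positives or negatives):

tp, fp, fn = 1, 0, 0
iou_list = [0.75]
sq = sum(iou_list) / tp               # 0.75
rq = tp / (tp + 0.5 * fp + 0.5 * fn)  # 1.0
pq = sq * rq                          # 0.75
print(sq, rq, pq)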
(ASSD) + """ + is_edge_case, result = self.edge_case_handler.handle_zero_tp( + metric=ListMetric.ASSD, + tp=self.tp, + num_pred_instances=self.num_pred_instances, + num_ref_instances=self.num_ref_instances, + ) + if is_edge_case: + return result + if ListMetric.ASSD.name not in self.metric_dict: + print("Requested ASSD but no ASSD metric evaluated") + return None + return np.sum(self.metric_dict[ListMetric.ASSD.name]) / self.tp + + @property + def sq_assd_sd(self) -> float: + """ + Calculate the standard deviation of average symmetric surface distance (ASSD) for matched instances. Analogue to segmentation quality but based on ASSD. + Returns: + float: Standard deviation of average symmetric surface distance (ASSD). + """ + if ListMetric.ASSD.name not in self.metric_dict: + print("Requested ASSD_SD but no ASSD metric evaluated") + return None + return ( + np.std(self.metric_dict[ListMetric.ASSD.name]) + if len(self.metric_dict[ListMetric.ASSD.name]) > 0 + else self.edge_case_handler.handle_empty_list_std() + ) + + @property + def pq_assd(self) -> float: + """ + Calculate the Panoptic Quality (PQ) based on ASSD-based SQ and RQ. + + Returns: + float: Panoptic Quality (PQ). + """ + return self.sq_assd * self.rq + + +# TODO make general getter that takes metric enum and std or not +# splits up into lists or not +# use below structure +def getter(value: int): + return value + + +class Test(object): + def __init__(self) -> None: + self.x: int + self.y: int + + # x = property(fget=getter(value=45)) + + def __getattribute__(self, __name: str) -> Any: + attr = None + try: + attr = object.__getattribute__(self, __name) + except AttributeError as e: + pass + if attr is None: + value = getter(5) + setattr(self, __name, value) + return value + else: + return attr + + # def __getattribute__(self, name): + # if some_predicate(name): + # # ... + # else: + # # Default behaviour + # return object.__getattribute__(self, name) + + +if __name__ == "__main__": + c = Test() + + print(c.x) + + c.x = 4 + + print(c.x) diff --git a/panoptica/result.py b/panoptica/result.py deleted file mode 100644 index 8e7164c..0000000 --- a/panoptica/result.py +++ /dev/null @@ -1,245 +0,0 @@ -from __future__ import annotations - -from typing import List - -import numpy as np - - -class PanopticaResult: - """ - Represents the result of the Panoptic Quality (PQ) computation. - - Attributes: - num_ref_instances (int): Number of reference instances. - num_pred_instances (int): Number of predicted instances. - tp (int): Number of correctly matched instances (True Positives). - fp (int): Number of extra predicted instances (False Positives). - """ - - def __init__( - self, - num_ref_instances: int, - num_pred_instances: int, - tp: int, - dice_list: List[float], - iou_list: List[float], - assd_list: List[float], - ): - """ - Initialize a PanopticaResult object. - - Args: - num_ref_instances (int): Number of reference instances. - num_pred_instances (int): Number of predicted instances. - tp (int): Number of correctly matched instances (True Positives). - dice_list (List[float]): List of Dice coefficients for matched instances. - iou_list (List[float]): List of IoU values for matched instances. 
- """ - self._tp = tp - self._dice_list = dice_list - self._iou_list = iou_list - self._num_ref_instances = num_ref_instances - self._num_pred_instances = num_pred_instances - self._assd_list = assd_list - - def __str__(self): - return ( - f"Number of instances in prediction: {self.num_pred_instances}\n" - f"Number of instances in reference: {self.num_ref_instances}\n" - f"True Positives (tp): {self.tp}\n" - f"False Positives (fp): {self.fp}\n" - f"False Negatives (fn): {self.fn}\n" - f"Recognition Quality / F1 Score (RQ): {self.rq}\n" - f"Segmentation Quality (SQ): {self.sq} ± {self.sq_sd}\n" - f"Panoptic Quality (PQ): {self.pq}\n" - f"DSC-based Segmentation Quality (DQ_DSC): {self.sq_dsc} ± {self.sq_dsc_sd}\n" - f"DSC-based Panoptic Quality (PQ_DSC): {self.pq_dsc}\n" - f"Average symmetric surface distance (ASSD): {self.sq_assd} ± {self.sq_assd_sd}\n" - f"ASSD-based Panoptic Quality (PQ_ASSD): {self.pq_assd}" - ) - - def to_dict(self): - return { - "num_pred_instances": self.num_pred_instances, - "num_ref_instances": self.num_ref_instances, - "tp": self.tp, - "fp": self.fp, - "fn": self.fn, - "rq": self.rq, - "sq": self.sq, - "sq_sd": self.sq_sd, - "pq": self.pq, - "sq_dsc": self.sq_dsc, - "sq_dsc_sd": self.sq_dsc_sd, - "pq_dsc": self.pq_dsc, - "sq_assd": self.sq_assd, - "sq_assd_sd": self.sq_assd_sd, - "pq_assd": self.pq_assd, - } - - @property - def num_ref_instances(self) -> int: - """ - Get the number of reference instances. - - Returns: - int: Number of reference instances. - """ - return self._num_ref_instances - - @property - def num_pred_instances(self) -> int: - """ - Get the number of predicted instances. - - Returns: - int: Number of predicted instances. - """ - return self._num_pred_instances - - @property - def tp(self) -> int: - """ - Calculate the number of True Positives (TP). - - Returns: - int: Number of True Positives. - """ - return self._tp - - @property - def fp(self) -> int: - """ - Calculate the number of False Positives (FP). - - Returns: - int: Number of False Positives. - """ - return self.num_pred_instances - self.tp - - @property - def fn(self) -> int: - """ - Calculate the number of False Negatives (FN). - - Returns: - int: Number of False Negatives. - """ - return self.num_ref_instances - self.tp - - @property - def rq(self) -> float: - """ - Calculate the Recognition Quality (RQ) based on TP, FP, and FN. - - Returns: - float: Recognition Quality (RQ). - """ - if self.tp == 0: - return ( - 0.0 if self.num_pred_instances + self.num_ref_instances > 0 else np.nan - ) - return self.tp / (self.tp + 0.5 * self.fp + 0.5 * self.fn) - - @property - def sq(self) -> float: - """ - Calculate the Segmentation Quality (SQ) based on IoU values. - - Returns: - float: Segmentation Quality (SQ). - """ - if self.tp == 0: - return ( - 0.0 if self.num_pred_instances + self.num_ref_instances > 0 else np.nan - ) - return np.sum(self._iou_list) / self.tp - - @property - def sq_sd(self) -> float: - """ - Calculate the standard deviation of Segmentation Quality (SQ) based on IoU values. - - Returns: - float: Standard deviation of Segmentation Quality (SQ). - """ - return np.std(self._iou_list) if len(self._iou_list) > 0 else np.nan - - @property - def pq(self) -> float: - """ - Calculate the Panoptic Quality (PQ) based on SQ and RQ. - - Returns: - float: Panoptic Quality (PQ). - """ - return self.sq * self.rq - - @property - def sq_dsc(self) -> float: - """ - Calculate the average Dice coefficient for matched instances. Analogue to segmentation quality but based on DSC. 
- - Returns: - float: Average Dice coefficient. - """ - if self.tp == 0: - return ( - 0.0 if self.num_pred_instances + self.num_ref_instances > 0 else np.nan - ) - return np.sum(self._dice_list) / self.tp - - @property - def sq_dsc_sd(self) -> float: - """ - Calculate the standard deviation of average Dice coefficient for matched instances. Analogue to segmentation quality but based on DSC. - - Returns: - float: Standard deviation of Average Dice coefficient. - """ - return np.std(self._dice_list) if len(self._dice_list) > 0 else np.nan - - @property - def pq_dsc(self) -> float: - """ - Calculate the Panoptic Quality (PQ) based on DSC-based SQ and RQ. - - Returns: - float: Panoptic Quality (PQ). - """ - return self.sq_dsc * self.rq - - @property - def sq_assd(self) -> float: - """ - Calculate the average average symmetric surface distance (ASSD) for matched instances. Analogue to segmentation quality but based on ASSD. - - Returns: - float: average symmetric surface distance. (ASSD) - """ - if self.tp == 0: - return ( - np.nan - if self.num_pred_instances + self.num_ref_instances == 0 - else np.inf - ) - return np.sum(self._assd_list) / self.tp - - @property - def sq_assd_sd(self) -> float: - """ - Calculate the standard deviation of average symmetric surface distance (ASSD) for matched instances. Analogue to segmentation quality but based on ASSD. - Returns: - float: Standard deviation of average symmetric surface distance (ASSD). - """ - return np.std(self._assd_list) if len(self._assd_list) > 0 else np.nan - - @property - def pq_assd(self) -> float: - """ - Calculate the Panoptic Quality (PQ) based on ASSD-based SQ and RQ. - - Returns: - float: Panoptic Quality (PQ). - """ - return self.sq_assd * self.rq diff --git a/panoptica/utils/__init__.py b/panoptica/utils/__init__.py index ca5e196..b5b9927 100644 --- a/panoptica/utils/__init__.py +++ b/panoptica/utils/__init__.py @@ -2,11 +2,16 @@ _count_unique_without_zeros, _unique_without_zeros, ) -from panoptica.utils.datatypes import ( +from panoptica.utils.processing_pair import ( SemanticPair, UnmatchedInstancePair, MatchedInstancePair, InstanceLabelMap, ) +from panoptica.utils.edge_case_handling import ( + EdgeCaseHandler, + EdgeCaseResult, + EdgeCaseZeroTP, +) # from utils.constants import diff --git a/panoptica/utils/constants.py b/panoptica/utils/constants.py index e76fd8b..d4a1faa 100644 --- a/panoptica/utils/constants.py +++ b/panoptica/utils/constants.py @@ -16,12 +16,6 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) - def __hash__(self) -> int: - return self.value - - -from enum import Enum, auto - class CCABackend(_Enum_Compare): """ @@ -41,4 +35,5 @@ class CCABackend(_Enum_Compare): if __name__ == "__main__": - print(CCABackend.cc3d) + print(CCABackend.cc3d == "cc3d") + print("cc3d" == CCABackend.cc3d) diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py new file mode 100644 index 0000000..c2f881a --- /dev/null +++ b/panoptica/utils/edge_case_handling.py @@ -0,0 +1,158 @@ +from typing import Any + +import numpy as np + +from panoptica.metrics import ListMetric, Metrics +from panoptica.utils.constants import _Enum_Compare, auto + + +class EdgeCaseResult(_Enum_Compare): + INF = np.inf + NAN = np.nan + ZERO = 0.0 + ONE = 1.0 + NONE = None + + +class EdgeCaseZeroTP(_Enum_Compare): + NO_INSTANCES = auto() + EMPTY_PRED = auto() + EMPTY_REF = auto() + NORMAL = auto() + + def __hash__(self) -> int: + return self.value + + +class MetricZeroTPEdgeCaseHandling(object): + def 
__init__( + self, + default_result: EdgeCaseResult, + no_instances_result: EdgeCaseResult | None = None, + empty_prediction_result: EdgeCaseResult | None = None, + empty_reference_result: EdgeCaseResult | None = None, + normal: EdgeCaseResult | None = None, + ) -> None: + self.edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} + self.edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( + empty_prediction_result + if empty_prediction_result is not None + else default_result + ) + self.edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( + empty_reference_result + if empty_reference_result is not None + else default_result + ) + self.edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( + no_instances_result if no_instances_result is not None else default_result + ) + self.edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( + normal if normal is not None else default_result + ) + + def __call__( + self, tp: int, num_pred_instances, num_ref_instances + ) -> tuple[bool, float | None]: + if tp != 0: + return False, EdgeCaseResult.NONE.value + # + elif num_pred_instances + num_ref_instances == 0: + return True, self.edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES].value + elif num_ref_instances == 0: + return True, self.edgecase_dict[EdgeCaseZeroTP.EMPTY_REF].value + elif num_pred_instances == 0: + return True, self.edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED].value + elif num_pred_instances > 0 and num_ref_instances > 0: + return True, self.edgecase_dict[EdgeCaseZeroTP.NORMAL].value + + raise NotImplementedError( + f"MetricZeroTPEdgeCaseHandling: couldn't handle case, got tp {tp}, n_pred_instances {num_pred_instances}, n_ref_instances {num_ref_instances}" + ) + + def __str__(self) -> str: + txt = "" + for k, v in self.edgecase_dict.items(): + if v is not None: + txt += str(k) + ": " + str(v) + "\n" + return txt + + +class EdgeCaseHandler: + def __init__( + self, + listmetric_zeroTP_handling: dict[ListMetric, MetricZeroTPEdgeCaseHandling] = { + ListMetric.DSC: MetricZeroTPEdgeCaseHandling( + no_instances_result=EdgeCaseResult.NAN, + default_result=EdgeCaseResult.ZERO, + ), + ListMetric.IOU: MetricZeroTPEdgeCaseHandling( + no_instances_result=EdgeCaseResult.NAN, + empty_prediction_result=EdgeCaseResult.ZERO, + default_result=EdgeCaseResult.ZERO, + ), + ListMetric.ASSD: MetricZeroTPEdgeCaseHandling( + no_instances_result=EdgeCaseResult.NAN, + default_result=EdgeCaseResult.INF, + ), + }, + empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, + ) -> None: + self.__listmetric_zeroTP_handling: dict[ + ListMetric, MetricZeroTPEdgeCaseHandling + ] = listmetric_zeroTP_handling + self.__empty_list_std = empty_list_std + + def handle_zero_tp( + self, + metric: ListMetric, + tp: int, + num_pred_instances: int, + num_ref_instances: int, + ) -> tuple[bool, float | None]: + if metric not in self.__listmetric_zeroTP_handling: + raise NotImplementedError( + f"Metric {metric} encountered zero TP, but no edge handling available" + ) + + return self.__listmetric_zeroTP_handling[metric]( + tp=tp, + num_pred_instances=num_pred_instances, + num_ref_instances=num_ref_instances, + ) + + def get_metric_zero_tp_handle(self, metric: ListMetric): + return self.__listmetric_zeroTP_handling[metric] + + def handle_empty_list_std(self): + return self.__empty_list_std.value + + def __str__(self) -> str: + txt = f"EdgeCaseHandler:\n - Standard Deviation of Empty = {self.__empty_list_std}" + for k, v in self.__listmetric_zeroTP_handling.items(): + txt += f"\n- {k}: {str(v)}" + return str(txt) + + +if __name__ == "__main__": + handler = EdgeCaseHandler() + + print() + # 
+
+
+if __name__ == "__main__":
+    handler = EdgeCaseHandler()
+
+    print()
+    # print(handler.get_metric_zero_tp_handle(ListMetric.IOU))
+    r = handler.handle_zero_tp(
+        ListMetric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1
+    )
+    print(r)
+
+    iou_test = MetricZeroTPEdgeCaseHandling(
+        no_instances_result=EdgeCaseResult.NAN,
+        default_result=EdgeCaseResult.ZERO,
+    )
+    # print(iou_test)
+    t = iou_test(tp=0, num_pred_instances=1, num_ref_instances=1)
+    print(t)
+
+    # iou_test = default_iou
+    # print(iou_test)
+    # t = iou_test(tp=0, num_pred_instances=1, num_ref_instances=1)
+    # print(t)
diff --git a/panoptica/utils/datatypes.py b/panoptica/utils/processing_pair.py
similarity index 93%
rename from panoptica/utils/datatypes.py
rename to panoptica/utils/processing_pair.py
index 9e39ec2..93d4886 100644
--- a/panoptica/utils/datatypes.py
+++ b/panoptica/utils/processing_pair.py
@@ -3,8 +3,8 @@
 import numpy as np
 from numpy import dtype

-from panoptica._functionals import _get_paired_crop
 from panoptica.utils import _count_unique_without_zeros, _unique_without_zeros
+from panoptica._functionals import _get_paired_crop

 uint_type: type = np.unsignedinteger
 int_type: type = np.integer
@@ -24,10 +21,7 @@ class _ProcessingPair(ABC):
     n_dim: int

     def __init__(
-        self,
-        prediction_arr: np.ndarray,
-        reference_arr: np.ndarray,
-        dtype: type | None,
+        self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None
     ) -> None:
         """Initializes a general Processing Pair

@@ -42,10 +39,10 @@ def __init__(
         self.dtype = dtype
         self.n_dim = reference_arr.ndim
         self._ref_labels: tuple[int, ...] = tuple(
-            _unique_without_zeros(reference_arr),
+            _unique_without_zeros(reference_arr)
         )  # type:ignore
         self._pred_labels: tuple[int, ...] = tuple(
-            _unique_without_zeros(prediction_arr),
+            _unique_without_zeros(prediction_arr)
         )  # type:ignore
         self.crop: tuple[slice, ...] = None
         self.is_cropped: bool = False
@@ -113,7 +110,6 @@ def copy(self):
         """
         Creates an exact copy of this object
         """
-        # TODO see linter error
         return type(self)(
             prediction_arr=self._prediction_arr,
             reference_arr=self._reference_arr,
@@ -170,9 +166,7 @@ def copy(self):


 def _check_array_integrity(
-    prediction_arr: np.ndarray,
-    reference_arr: np.ndarray,
-    dtype: type | None = None,
+    prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None
 ):
     """
     Check the integrity of two numpy arrays.
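The hunks that follow switch MatchedInstancePair from storing a count (n_matched_instances) to storing the matched labels themselves, with the count derived by a property. A minimal sketch of the resulting behaviour with toy 1-D label arrays; the arrays, the expected values, and the assumption that 1-D inputs pass the integrity checks are illustrative, not taken from the test suite (note the import already uses the renamed module path):

import numpy as np

from panoptica.utils.processing_pair import MatchedInstancePair

pred = np.array([0, 1, 1, 2, 0], dtype=np.uint16)  # prediction contains labels 1 and 2
ref = np.array([0, 1, 1, 0, 3], dtype=np.uint16)  # reference contains labels 1 and 3

pair = MatchedInstancePair(prediction_arr=pred, reference_arr=ref)
print(pair.matched_instances)  # labels present in both maps, here [1]
print(pair.n_matched_instances)  # derived by the new property, here 1
print(pair.missed_prediction_labels)  # prediction labels without a match, here [2]
print(pair.missed_reference_labels)  # reference labels without a match, here [3]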
@@ -246,7 +240,7 @@ class MatchedInstancePair(_ProcessingPairInstanced):

     missed_reference_labels: list[int]
     missed_prediction_labels: list[int]
-    n_matched_instances: int
+    matched_instances: list[int]

     def __init__(
         self,
@@ -254,7 +248,7 @@ def __init__(
         reference_arr: np.ndarray,
         missed_reference_labels: list[int] | None = None,
         missed_prediction_labels: list[int] | None = None,
-        n_matched_instances: int | None = None,
+        matched_instances: list[int] | None = None,
         n_prediction_instance: int | None = None,
         n_reference_instance: int | None = None,
     ) -> None:
@@ -265,7 +259,7 @@ def __init__(
             reference_arr (np.ndarray): Numpy array containing the reference matched instance labels
             missed_reference_labels (list[int] | None, optional): List of unmatched reference labels. Defaults to None.
             missed_prediction_labels (list[int] | None, optional): List of unmatched prediction labels. Defaults to None.
-            n_matched_instances (int | None, optional): Number of total matched instances, i.e. unique matched labels in both maps. Defaults to None.
+            matched_instances (list[int] | None, optional): List of matched instance labels, i.e. unique labels present in both maps. Defaults to None.
             n_prediction_instance (int | None, optional): Number of prediction instances. Defaults to None.
             n_reference_instance (int | None, optional): Number of reference instances. Defaults to None.
@@ -278,11 +272,9 @@ def __init__(
             n_prediction_instance,
             n_reference_instance,
         )  # type:ignore
-        if n_matched_instances is None:
-            n_matched_instances = len(
-                [i for i in self._pred_labels if i in self._ref_labels]
-            )
-        self.n_matched_instances = n_matched_instances
+        if matched_instances is None:
+            matched_instances = [i for i in self._pred_labels if i in self._ref_labels]
+        self.matched_instances = matched_instances

         if missed_reference_labels is None:
             missed_reference_labels = list(
@@ -296,6 +288,10 @@ def __init__(
             )
         self.missed_prediction_labels = missed_prediction_labels

+    @property
+    def n_matched_instances(self):
+        return len(self.matched_instances)
+
     def copy(self):
         """
         Creates an exact copy of this object
@@ -307,7 +303,7 @@ def copy(self):
             n_reference_instance=self.n_reference_instance,
             missed_reference_labels=self.missed_reference_labels,
             missed_prediction_labels=self.missed_prediction_labels,
-            n_matched_instances=self.n_matched_instances,
+            matched_instances=self.matched_instances,
         )
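The hunk below adds DSC and ASSD variants of the simple evaluation test. The two ASSD cases pin down the decision direction for distance-based matching: ASSD is an average surface distance, so the matcher presumably accepts a pair only when the distance is at or below the threshold; under that reading, the 20x10 reference block vs. the 15x10 prediction block used in the tests has an ASSD somewhere between 0.5 and 1.0, which is why matching_threshold=1.0 yields tp=1 while 0.5 yields tp=0. A small sketch, not part of the test suite, that sweeps the threshold to show where the decision flips; it only reuses calls that appear in the tests below:

import numpy as np

from panoptica import (
    ConnectedComponentsInstanceApproximator,
    NaiveThresholdMatching,
    Panoptic_Evaluator,
    SemanticPair,
)
from panoptica.metrics import Metrics

# Same toy masks as in the tests below: a 20x10 reference block vs. a 15x10 prediction block.
a = np.zeros([50, 50], dtype=np.uint16)
b = a.copy()
a[20:40, 10:20] = 1
b[20:35, 10:20] = 2

for threshold in (0.5, 0.75, 1.0):
    evaluator = Panoptic_Evaluator(
        expected_input=SemanticPair,
        instance_approximator=ConnectedComponentsInstanceApproximator(),
        instance_matcher=NaiveThresholdMatching(
            matching_metric=Metrics.ASSD,
            matching_threshold=threshold,
        ),
    )
    result, _ = evaluator.evaluate(SemanticPair(b, a))
    print(f"ASSD threshold {threshold}: tp={result.tp}")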
diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py
index 61bb8f6..27a7eb9 100644
--- a/unit_tests/test_panoptic_evaluator.py
+++ b/unit_tests/test_panoptic_evaluator.py
@@ -6,10 +6,11 @@
 import os

 import numpy as np
-from panoptica.evaluator import Panoptic_Evaluator
+from panoptica.panoptic_evaluator import Panoptic_Evaluator
 from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator
 from panoptica.instance_matcher import NaiveThresholdMatching
-from panoptica.utils.datatypes import SemanticPair
+from panoptica.metrics import Metrics
+from panoptica.utils.processing_pair import SemanticPair


 class Test_Panoptic_Evaluator(unittest.TestCase):
@@ -38,6 +39,101 @@ def test_simple_evaluation(self):
         self.assertEqual(result.sq, 0.75)
         self.assertEqual(result.pq, 0.75)

+    def test_simple_evaluation_DSC(self):
+        a = np.zeros([50, 50], dtype=np.uint16)
+        b = a.copy().astype(a.dtype)
+        a[20:40, 10:20] = 1
+        b[20:35, 10:20] = 2
+
+        sample = SemanticPair(b, a)
+
+        evaluator = Panoptic_Evaluator(
+            expected_input=SemanticPair,
+            instance_approximator=ConnectedComponentsInstanceApproximator(),
+            instance_matcher=NaiveThresholdMatching(),
+        )
+
+        result, debug_data = evaluator.evaluate(sample)
+        print(result)
+        self.assertEqual(result.tp, 1)
+        self.assertEqual(result.fp, 0)
+        self.assertEqual(result.sq, 0.75)
+        self.assertEqual(result.pq, 0.75)
+
+    def test_simple_evaluation_DSC_partial(self):
+        a = np.zeros([50, 50], dtype=np.uint16)
+        b = a.copy().astype(a.dtype)
+        a[20:40, 10:20] = 1
+        b[20:35, 10:20] = 2
+
+        sample = SemanticPair(b, a)
+
+        evaluator = Panoptic_Evaluator(
+            expected_input=SemanticPair,
+            instance_approximator=ConnectedComponentsInstanceApproximator(),
+            instance_matcher=NaiveThresholdMatching(matching_metric=Metrics.DSC),
+            eval_metrics=[Metrics.DSC],
+        )
+
+        result, debug_data = evaluator.evaluate(sample)
+        print(result)
+        self.assertEqual(result.tp, 1)
+        self.assertEqual(result.fp, 0)
+        self.assertEqual(
+            result.sq, None
+        )  # must be None because no IOU has been calculated
+        self.assertEqual(result.pq, None)
+        self.assertEqual(result.rq, 1.0)
+
+    def test_simple_evaluation_ASSD(self):
+        a = np.zeros([50, 50], dtype=np.uint16)
+        b = a.copy().astype(a.dtype)
+        a[20:40, 10:20] = 1
+        b[20:35, 10:20] = 2
+
+        sample = SemanticPair(b, a)
+
+        evaluator = Panoptic_Evaluator(
+            expected_input=SemanticPair,
+            instance_approximator=ConnectedComponentsInstanceApproximator(),
+            instance_matcher=NaiveThresholdMatching(
+                matching_metric=Metrics.ASSD,
+                matching_threshold=1.0,
+            ),
+        )
+
+        result, debug_data = evaluator.evaluate(sample)
+        print(result)
+        self.assertEqual(result.tp, 1)
+        self.assertEqual(result.fp, 0)
+        self.assertEqual(result.sq, 0.75)
+        self.assertEqual(result.pq, 0.75)
+
+    def test_simple_evaluation_ASSD_negative(self):
+        a = np.zeros([50, 50], dtype=np.uint16)
+        b = a.copy().astype(a.dtype)
+        a[20:40, 10:20] = 1
+        b[20:35, 10:20] = 2
+
+        sample = SemanticPair(b, a)
+
+        evaluator = Panoptic_Evaluator(
+            expected_input=SemanticPair,
+            instance_approximator=ConnectedComponentsInstanceApproximator(),
+            instance_matcher=NaiveThresholdMatching(
+                matching_metric=Metrics.ASSD,
+                matching_threshold=0.5,
+            ),
+        )
+
+        result, debug_data = evaluator.evaluate(sample)
+        print(result)
+        self.assertEqual(result.tp, 0)
+        self.assertEqual(result.fp, 1)
+        self.assertEqual(result.sq, 0.0)
+        self.assertEqual(result.pq, 0.0)
+        self.assertEqual(result.sq_assd, np.inf)
+
     def test_pred_empty(self):
         a = np.zeros([50, 50], np.uint16)
         b = a.copy()