diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py
index 241407e..f10cb34 100644
--- a/examples/example_spine_instance.py
+++ b/examples/example_spine_instance.py
@@ -3,8 +3,9 @@
 from auxiliary.nifti.io import read_nifti
 from auxiliary.turbopath import turbopath
 
-from panoptica import MatchedInstancePair, Panoptic_Evaluator
+from panoptica import MatchedInstancePair, Panoptica_Evaluator
 from panoptica.metrics import Metric
+from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups
 
 directory = turbopath(__file__).parent
 
@@ -14,10 +15,17 @@
 
 sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks)
 
-
-evaluator = Panoptic_Evaluator(
+evaluator = Panoptica_Evaluator(
     expected_input=MatchedInstancePair,
     eval_metrics=[Metric.DSC, Metric.IOU],
+    segmentation_class_groups=SegmentationClassGroups(
+        {
+            "vertebra": LabelGroup([i for i in range(1, 10)]),
+            "ivd": LabelGroup([i for i in range(101, 109)]),
+            "sacrum": (26, True),
+            "endplate": LabelGroup([i for i in range(201, 209)]),
+        }
+    ),
     decision_metric=Metric.DSC,
     decision_threshold=0.5,
 )
@@ -25,7 +33,10 @@
 
 with cProfile.Profile() as pr:
     if __name__ == "__main__":
-        result, debug_data = evaluator.evaluate(sample, verbose=True)
-        print(result)
+        results = evaluator.evaluate(sample, verbose=False)
+        for groupname, (result, debug) in results.items():
+            print()
+            print("### Group", groupname)
+            print(result)
 
     pr.dump_stats(directory + "/instance_example.log")
diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py
index 57e6a04..7385701 100644
--- a/examples/example_spine_semantic.py
+++ b/examples/example_spine_semantic.py
@@ -6,7 +6,7 @@
 from panoptica import (
     ConnectedComponentsInstanceApproximator,
     NaiveThresholdMatching,
-    Panoptic_Evaluator,
+    Panoptica_Evaluator,
     SemanticPair,
 )
 
@@ -17,8 +17,7 @@
 
 sample = SemanticPair(pred_masks, ref_masks)
 
-
-evaluator = Panoptic_Evaluator(
+evaluator = Panoptica_Evaluator(
     expected_input=SemanticPair,
     instance_approximator=ConnectedComponentsInstanceApproximator(),
     instance_matcher=NaiveThresholdMatching(),
@@ -27,7 +26,7 @@
 
 with cProfile.Profile() as pr:
     if __name__ == "__main__":
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
 
     pr.dump_stats(directory + "/semantic_example.log")
diff --git a/panoptica/__init__.py b/panoptica/__init__.py
index ada7cd0..9a4d080 100644
--- a/panoptica/__init__.py
+++ b/panoptica/__init__.py
@@ -3,8 +3,8 @@
     CCABackend,
 )
 from panoptica.instance_matcher import NaiveThresholdMatching
-from panoptica.panoptic_evaluator import Panoptic_Evaluator
-from panoptica.panoptic_result import PanopticaResult
+from panoptica.panoptica_evaluator import Panoptica_Evaluator
+from panoptica.panoptica_result import PanopticaResult
 from panoptica.utils.processing_pair import (
     SemanticPair,
     UnmatchedInstancePair,
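With this change, Panoptica_Evaluator.evaluate always returns a dict keyed by group name instead of a bare (result, debug_data) tuple; when no segmentation_class_groups are configured, the single entry lives under the key "ungrouped", which is why the semantic example above indexes into the result. A minimal sketch of consuming the new return shape (local variable names are illustrative only):

    results = evaluator.evaluate(sample)
    # results: dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]
    for group_name, (result, debug_data) in results.items():
        print(group_name, result.pq)  # group_name is e.g. "ungrouped" or "vertebra"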
diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py
index ebcf108..eee9a49 100644
--- a/panoptica/_functionals.py
+++ b/panoptica/_functionals.py
@@ -79,87 +79,6 @@ def _calc_matching_metric_of_overlapping_labels(
     return mm_pairs
 
 
-def _calc_iou_of_overlapping_labels(
-    prediction_arr: np.ndarray,
-    reference_arr: np.ndarray,
-    ref_labels: tuple[int, ...],
-    **kwargs,
-) -> list[tuple[float, tuple[int, int]]]:
-    """Calculates the IOU for all overlapping labels (fast!)
-
-    Args:
-        prediction_arr (np.ndarray): Numpy array containing the prediction labels.
-        reference_arr (np.ndarray): Numpy array containing the reference labels.
-        ref_labels (list[int]): List of unique reference labels.
-        pred_labels (list[int]): List of unique prediction labels.
-
-    Returns:
-        list[tuple[float, tuple[int, int]]]: List of pairs in style: (iou, (ref_label, pred_label))
-    """
-    instance_pairs = [
-        (reference_arr, prediction_arr, i[0], i[1])
-        for i in _calc_overlapping_labels(
-            prediction_arr=prediction_arr,
-            reference_arr=reference_arr,
-            ref_labels=ref_labels,
-        )
-    ]
-    with Pool() as pool:
-        iou_values = pool.starmap(_compute_instance_iou, instance_pairs)
-
-    iou_pairs = [
-        (i, (instance_pairs[idx][2], instance_pairs[idx][3]))
-        for idx, i in enumerate(iou_values)
-    ]
-    iou_pairs = sorted(iou_pairs, key=lambda x: x[0], reverse=True)
-
-    return iou_pairs
-
-
-def _calc_iou_matrix(
-    prediction_arr: np.ndarray,
-    reference_arr: np.ndarray,
-    ref_labels: tuple[int, ...],
-    pred_labels: tuple[int, ...],
-):
-    """
-    Calculate the Intersection over Union (IoU) matrix between reference and prediction arrays.
-
-    Args:
-        prediction_arr (np.ndarray): Numpy array containing the prediction labels.
-        reference_arr (np.ndarray): Numpy array containing the reference labels.
-        ref_labels (list[int]): List of unique reference labels.
-        pred_labels (list[int]): List of unique prediction labels.
-
-    Returns:
-        np.ndarray: IoU matrix where each element represents the IoU between a reference and prediction instance.
-
-    Example:
-    >>> _calc_iou_matrix(np.array([1, 2, 3]), np.array([4, 5, 6]), [1, 2, 3], [4, 5, 6])
-    array([[0. , 0. , 0. ],
-           [0. , 0. , 0. ],
-           [0. , 0. , 0. ]])
-    """
-    num_ref_instances = len(ref_labels)
-    num_pred_instances = len(pred_labels)
-
-    # Create a pool of worker processes to parallelize the computation
-    with Pool() as pool:
-        # # Generate all possible pairs of instance indices for IoU computation
-        instance_pairs = [
-            (reference_arr, prediction_arr, ref_idx, pred_idx)
-            for ref_idx in ref_labels
-            for pred_idx in pred_labels
-        ]
-
-        # Calculate IoU for all instance pairs in parallel using starmap
-        iou_values = pool.starmap(_compute_instance_iou, instance_pairs)
-
-    # Reshape the resulting IoU values into a matrix
-    iou_matrix = np.array(iou_values).reshape((num_ref_instances, num_pred_instances))
-    return iou_matrix
-
-
 def _map_labels(
     arr: np.ndarray,
     label_map: dict[np.integer, np.integer],
diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py
index a379bf3..383df0e 100644
--- a/panoptica/instance_evaluator.py
+++ b/panoptica/instance_evaluator.py
@@ -3,7 +3,7 @@
 import numpy as np
 
 from panoptica.metrics import Metric
-from panoptica.panoptic_result import PanopticaResult
+from panoptica.panoptica_result import PanopticaResult
 from panoptica.utils import EdgeCaseHandler
 from panoptica.utils.processing_pair import MatchedInstancePair
 
diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py
index b17630c..2d6a4c0 100644
--- a/panoptica/metrics/metrics.py
+++ b/panoptica/metrics/metrics.py
@@ -15,7 +15,7 @@
 from panoptica.utils.constants import _Enum_Compare, auto
 
 if TYPE_CHECKING:
-    from panoptic_result import PanopticaResult
+    from panoptica.panoptica_result import PanopticaResult
 
 
 @dataclass
diff --git a/panoptica/metrics/relative_volume_difference.py b/panoptica/metrics/relative_volume_difference.py
index bbb131b..4d952b2 100644
--- a/panoptica/metrics/relative_volume_difference.py
+++ b/panoptica/metrics/relative_volume_difference.py
@@ -66,5 +66,5 @@ def _compute_relative_volume_difference(
         return 0.0
 
     # Calculate Dice coefficient
-    rvd = (prediction_mask - reference_mask) / reference_mask
+    rvd = float(prediction_mask - reference_mask) / reference_mask
    return rvd
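A note on the float cast introduced in _compute_relative_volume_difference above: if the two volumes arrive as numpy integer scalars rather than Python ints, the subtraction keeps the integer dtype, and unsigned dtypes silently wrap around on negative differences. A small illustration of the pitfall (assumes numpy scalar inputs; not part of this diff):

    import numpy as np

    pred, ref = np.uint32(80), np.uint32(100)
    print(pred - ref)                # 4294967276 (wraps around; may warn at runtime)
    print(float(pred) - float(ref))  # -20.0 (casting each operand first avoids the wrap)

Since the cast here happens after the subtraction, it guarantees float division but would not by itself prevent unsigned wrap-around; casting the operands would be the stricter fix. The "# Calculate Dice coefficient" comment above the changed line looks like a copy-paste leftover; the expression computes the relative volume difference.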
diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptica_evaluator.py
similarity index 74%
rename from panoptica/panoptic_evaluator.py
rename to panoptica/panoptica_evaluator.py
index 13d2e5d..ca59b98 100644
--- a/panoptica/panoptic_evaluator.py
+++ b/panoptica/panoptica_evaluator.py
@@ -5,8 +5,8 @@
 from panoptica.instance_evaluator import evaluate_matched_instance
 from panoptica.instance_matcher import InstanceMatchingAlgorithm
 from panoptica.metrics import Metric, _Metric
-from panoptica.panoptic_result import PanopticaResult
-from panoptica.timing import measure_time
+from panoptica.panoptica_result import PanopticaResult
+from panoptica.utils.timing import measure_time
 from panoptica.utils import EdgeCaseHandler
 from panoptica.utils.citation_reminder import citation_reminder
 from panoptica.utils.processing_pair import (
@@ -15,18 +15,21 @@
     UnmatchedInstancePair,
     _ProcessingPair,
 )
+from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup
 
 
-class Panoptic_Evaluator:
+class Panoptica_Evaluator:
     def __init__(
         self,
+        # TODO let users give prediction and reference arr instead of the processing pair, so let this create the processing pair itself
        expected_input: (
             Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair]
         ) = MatchedInstancePair,
         instance_approximator: InstanceApproximator | None = None,
         instance_matcher: InstanceMatchingAlgorithm | None = None,
         edge_case_handler: EdgeCaseHandler | None = None,
+        segmentation_class_groups: SegmentationClassGroups | None = None,
         eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD],
         decision_metric: Metric | None = None,
         decision_threshold: float | None = None,
@@ -49,6 +52,8 @@ def __init__(
         self.__decision_metric = decision_metric
         self.__decision_threshold = decision_threshold
 
+        self.__segmentation_class_groups = segmentation_class_groups
+
         self.__edge_case_handler = (
             edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
         )
@@ -69,24 +74,69 @@ def evaluate(
         processing_pair: (
             SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult
         ),
         result_all: bool = True,
         verbose: bool | None = None,
-    ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
+    ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]:
         assert (
             type(processing_pair) == self.__expected_input
         ), f"input not of expected type {self.__expected_input}"
-        return panoptic_evaluate(
-            processing_pair=processing_pair,
-            edge_case_handler=self.__edge_case_handler,
-            instance_approximator=self.__instance_approximator,
-            instance_matcher=self.__instance_matcher,
-            eval_metrics=self.__eval_metrics,
-            decision_metric=self.__decision_metric,
-            decision_threshold=self.__decision_threshold,
-            result_all=result_all,
-            log_times=self.__log_times,
-            verbose=True if verbose is None else verbose,
-            verbose_calc=self.__verbose if verbose is None else verbose,
+
+        if self.__segmentation_class_groups is None:
+            return {
+                "ungrouped": panoptic_evaluate(
+                    processing_pair=processing_pair,
+                    edge_case_handler=self.__edge_case_handler,
+                    instance_approximator=self.__instance_approximator,
+                    instance_matcher=self.__instance_matcher,
+                    eval_metrics=self.__eval_metrics,
+                    decision_metric=self.__decision_metric,
+                    decision_threshold=self.__decision_threshold,
+                    result_all=result_all,
+                    log_times=self.__log_times,
+                    verbose=True if verbose is None else verbose,
+                    verbose_calc=self.__verbose if verbose is None else verbose,
+                )
+            }
+
+        self.__segmentation_class_groups.has_defined_labels_for(
+            processing_pair.prediction_arr, raise_error=True
+        )
+        self.__segmentation_class_groups.has_defined_labels_for(
+            processing_pair.reference_arr, raise_error=True
         )
 
+        result_grouped = {}
+        for group_name, label_group in self.__segmentation_class_groups.items():
+            assert isinstance(label_group, LabelGroup)
+
+            prediction_arr_grouped = label_group(processing_pair.prediction_arr)
+            reference_arr_grouped = label_group(processing_pair.reference_arr)
+
+            single_instance_mode = label_group.single_instance
+            processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped)  # type: ignore
+            decision_threshold = self.__decision_threshold
+            if single_instance_mode and not isinstance(
+                processing_pair, MatchedInstancePair
+            ):
+                processing_pair_grouped = MatchedInstancePair(
+                    prediction_arr=processing_pair_grouped.prediction_arr,
+                    reference_arr=processing_pair_grouped.reference_arr,
+                )
+                decision_threshold = 0.0
+
+            result_grouped[group_name] = panoptic_evaluate(
+                processing_pair=processing_pair_grouped,
+                edge_case_handler=self.__edge_case_handler,
+                instance_approximator=self.__instance_approximator,
+                instance_matcher=self.__instance_matcher,
+                eval_metrics=self.__eval_metrics,
+                decision_metric=self.__decision_metric,
+                decision_threshold=decision_threshold,
+                result_all=result_all,
+                log_times=self.__log_times,
+                verbose=True if verbose is None else verbose,
+                verbose_calc=self.__verbose if verbose is None else verbose,
+            )
+        return result_grouped
+
 
 def panoptic_evaluate(
     processing_pair: (
diff --git a/panoptica/panoptic_result.py b/panoptica/panoptica_result.py
similarity index 100%
rename from panoptica/panoptic_result.py
rename to panoptica/panoptica_result.py
diff --git a/panoptica/utils/__init__.py b/panoptica/utils/__init__.py
index 4a4392f..d4dfe79 100644
--- a/panoptica/utils/__init__.py
+++ b/panoptica/utils/__init__.py
@@ -15,3 +15,7 @@
 )
 
 # from utils.constants import
+from panoptica.utils.segmentation_class import (
+    SegmentationClassGroups,
+    LabelGroup,
+)
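For reference while reading the grouped evaluation above: segmentation_class_groups accepts either LabelGroup objects or (labels, single_instance) tuples (see the new segmentation_class.py below), so the two spellings in this sketch are equivalent:

    from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups

    SegmentationClassGroups({"sacrum": (26, True)})
    SegmentationClassGroups({"sacrum": LabelGroup(26, single_instance=True)})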
+ """ + if isinstance(value_labels, int): + value_labels = [value_labels] + assert ( + len(value_labels) >= 1 + ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" + self.__value_labels = value_labels + assert np.all( + [v > 0 for v in self.__value_labels] + ), f"Given value labels are not >0, got {value_labels}" + self.__single_instance = single_instance + if self.__single_instance: + assert ( + len(value_labels) == 1 + ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" + + @property + def value_labels(self) -> list[int]: + return self.__value_labels + + @property + def single_instance(self) -> bool: + return self.__single_instance + + def __call__( + self, + array: np.ndarray, + set_to_binary: bool = False, + ) -> np.ndarray: + """Extracts the labels of this class + + Args: + array (np.ndarray): Array to extract the segmentation group labels from + set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. + + Returns: + np.ndarray: Array containing only the labels of this segmentation group + """ + array = array.copy() + array[np.isin(array, self.value_labels, invert=True)] = 0 + if set_to_binary: + array[array != 0] = 1 + return array + + def __str__(self) -> str: + return f"LabelGroup {self.value_labels}, single_instance={self.single_instance}" + + def __repr__(self) -> str: + return str(self) + + +class SegmentationClassGroups: + def __init__( + self, + groups: list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]], + ) -> None: + self.__group_dictionary: dict[str, LabelGroup] = {} + self.__labels: list[int] = [] + # maps name of group to the group itself + + if isinstance(groups, list): + self.__group_dictionary = { + f"group_{idx}": g for idx, g in enumerate(groups) + } + else: + # transform dict into list of LabelGroups + for i, g in groups.items(): + name_lower = str(i).lower() + if isinstance(g, LabelGroup): + self.__group_dictionary[name_lower] = LabelGroup( + g.value_labels, g.single_instance + ) + else: + self.__group_dictionary[name_lower] = LabelGroup(g[0], g[1]) + + # needs to check that each label is accounted for exactly ONCE + labels = [ + value_label + for lg in self.__group_dictionary.values() + for value_label in lg.value_labels + ] + duplicates = list_duplicates(labels) + if len(duplicates) > 0: + raise AssertionError( + f"The same label was assigned to two different labelgroups, got {str(self)}" + ) + + self.__labels = labels + + def has_defined_labels_for( + self, arr: np.ndarray | list[int], raise_error: bool = False + ): + if isinstance(arr, list): + arr_labels = arr + else: + arr_labels = [i for i in np.unique(arr) if i != 0] + for al in arr_labels: + if al not in self.__labels: + if raise_error: + raise AssertionError( + f"Input array has labels undefined in the SegmentationClassGroups, got label {al} the groups are defined as {str(self)}" + ) + return False + return True + + def __str__(self) -> str: + text = "SegmentationClassGroups = " + for i, lg in self.__group_dictionary.items(): + text += f"\n - {i} : {str(lg)}" + return text + + def __contains__(self, item): + return item in self.__group_dictionary + + def __getitem__(self, key): + return self.__group_dictionary[key] + + def __iter__(self): + yield from self.__group_dictionary + + def keys(self) -> list[str]: + return list(self.__group_dictionary.keys()) + + def items(self): + for k in self: + yield k, self[k] + + +def list_duplicates(seq): + seen = set() + seen_add = 
diff --git a/panoptica/timing.py b/panoptica/utils/timing.py
similarity index 100%
rename from panoptica/timing.py
rename to panoptica/utils/timing.py
diff --git a/unit_tests/test_labelgroup.py b/unit_tests/test_labelgroup.py
new file mode 100644
index 0000000..4aee9f3
--- /dev/null
+++ b/unit_tests/test_labelgroup.py
@@ -0,0 +1,91 @@
+# Call 'python -m unittest' on this folder
+# coverage run -m unittest
+# coverage report
+# coverage html
+import os
+import unittest
+import numpy as np
+
+from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups
+
+
+class Test_DefinitionOfSegmentationLabels(unittest.TestCase):
+    def setUp(self) -> None:
+        os.environ["PANOPTICA_CITATION_REMINDER"] = "False"
+        return super().setUp()
+
+    def test_labelgroup(self):
+        group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False)
+
+        print(group1)
+        arr = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+        group1_arr = group1(arr, True)
+
+        print(group1_arr)
+        self.assertEqual(group1_arr.sum(), 5)
+
+        group1_arr_ind = np.argwhere(group1_arr).flatten()
+        print(group1_arr_ind)
+        group1_labels = np.asarray(group1.value_labels)
+        print(group1_labels)
+        self.assertTrue(np.all(group1_arr_ind == group1_labels))
+
+    def test_labelgroup_notpresent(self):
+        group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False)
+
+        print(group1)
+        arr = np.array([0, 6, 7, 8, 0, 15, 6, 7, 8, 9, 10])
+        group1_arr = group1(arr, True)
+
+        print(group1_arr)
+        self.assertEqual(group1_arr.sum(), 0)
+
+        group1_arr_ind = np.argwhere(group1_arr).flatten()
+        self.assertEqual(len(group1_arr_ind), 0)
+
+    def test_wrong_labelgroup_definitions(self):
+
+        with self.assertRaises(AssertionError):
+            group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=True)
+
+        with self.assertRaises(AssertionError):
+            group1 = LabelGroup([], single_instance=False)
+
+        with self.assertRaises(AssertionError):
+            group1 = LabelGroup([1, 0, -1, 5], single_instance=False)
+
+    def test_segmentationclassgroup_easy(self):
+        group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False)
+        classgroups = SegmentationClassGroups(
+            groups={
+                "vertebra": group1,
+                "ivds": LabelGroup([100, 101, 102]),
+            }
+        )
+
+        print(classgroups)
+
+        self.assertTrue(classgroups.has_defined_labels_for([1, 2, 3]))
+
+        self.assertTrue(classgroups.has_defined_labels_for([1, 100, 3]))
+
+        self.assertFalse(classgroups.has_defined_labels_for([1, 99, 3]))
+
+        self.assertTrue("ivds" in classgroups)
+
+        for i in classgroups:
+            self.assertTrue(i in ["vertebra", "ivds"])
+
+        for i, lg in classgroups.items():
+            print(i, lg)
+            self.assertTrue(isinstance(i, str))
+            self.assertTrue(isinstance(lg, LabelGroup))
+
+    def test_segmentationclassgroup_declarations(self):
+        classgroups = SegmentationClassGroups(
+            groups=[LabelGroup(i) for i in range(1, 5)]
+        )
+
+        keys = classgroups.keys()
+        for i in range(1, 5):
+            self.assertTrue(f"group_{i-1}" in keys, f"not {i} in {keys}")
diff --git a/unit_tests/test_metrics.py b/unit_tests/test_metrics.py
index 729b321..2800187 100644
--- a/unit_tests/test_metrics.py
+++ b/unit_tests/test_metrics.py
@@ -8,7 +8,10 @@
 import numpy as np
 
 from panoptica.metrics import Metric
-from panoptica.panoptic_result import MetricCouldNotBeComputedException, PanopticaResult
+from panoptica.panoptica_result import (
+    MetricCouldNotBeComputedException,
+    PanopticaResult,
+)
 from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult
 
diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py
index 00245c0..a2cf755 100644
--- a/unit_tests/test_panoptic_evaluator.py
+++ b/unit_tests/test_panoptic_evaluator.py
@@ -10,9 +10,10 @@
 from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator
 from panoptica.instance_matcher import MaximizeMergeMatching, NaiveThresholdMatching
 from panoptica.metrics import Metric
-from panoptica.panoptic_evaluator import Panoptic_Evaluator
-from panoptica.panoptic_result import MetricCouldNotBeComputedException
+from panoptica.panoptica_evaluator import Panoptica_Evaluator
+from panoptica.panoptica_result import MetricCouldNotBeComputedException
 from panoptica.utils.processing_pair import SemanticPair
+from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup
 
 
 class Test_Panoptic_Evaluator(unittest.TestCase):
@@ -28,13 +29,13 @@ def test_simple_evaluation(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -49,13 +50,13 @@ def test_simple_evaluation_DSC(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -70,14 +71,14 @@ def test_simple_evaluation_DSC_partial(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC),
             eval_metrics=[Metric.DSC],
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -96,7 +97,7 @@ def test_simple_evaluation_ASSD(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(
@@ -105,7 +106,7 @@ def test_simple_evaluation_ASSD(self):
             ),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -120,7 +121,7 @@ def test_simple_evaluation_ASSD_negative(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(
@@ -129,7 +130,7 @@ def test_simple_evaluation_ASSD_negative(self):
             ),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 1)
@@ -145,13 +146,13 @@ def test_pred_empty(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 0)
@@ -168,13 +169,13 @@ def test_ref_empty(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 1)
@@ -191,13 +192,13 @@ def test_both_empty(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=NaiveThresholdMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 0)
         self.assertEqual(result.fp, 0)
@@ -229,13 +230,13 @@ def test_dtype_evaluation(self):
             else:
                 sample = SemanticPair(b, a)
 
-                evaluator = Panoptic_Evaluator(
+                evaluator = Panoptica_Evaluator(
                     expected_input=SemanticPair,
                     instance_approximator=ConnectedComponentsInstanceApproximator(),
                     instance_matcher=NaiveThresholdMatching(),
                 )
 
-                result, debug_data = evaluator.evaluate(sample)
+                result, debug_data = evaluator.evaluate(sample)["ungrouped"]
                 print(result)
                 self.assertEqual(result.tp, 1)
                 self.assertEqual(result.fp, 0)
@@ -250,13 +251,13 @@ def test_simple_evaluation_maximize_matcher(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=MaximizeMergeMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -272,13 +273,13 @@ def test_simple_evaluation_maximize_matcher_overlaptwo(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=MaximizeMergeMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 0)
@@ -296,13 +297,13 @@ def test_simple_evaluation_maximize_matcher_overlap(self):
 
         sample = SemanticPair(b, a)
 
-        evaluator = Panoptic_Evaluator(
+        evaluator = Panoptica_Evaluator(
             expected_input=SemanticPair,
             instance_approximator=ConnectedComponentsInstanceApproximator(),
             instance_matcher=MaximizeMergeMatching(),
         )
 
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
         self.assertEqual(result.tp, 1)
         self.assertEqual(result.fp, 1)
@@ -310,3 +311,48 @@ def test_simple_evaluation_maximize_matcher_overlap(self):
         self.assertAlmostEqual(result.pq, 0.56666666)
         self.assertAlmostEqual(result.rq, 0.66666666)
         self.assertAlmostEqual(result.sq_dsc, 0.9189189189189)
+
+    def test_single_instance_mode(self):
+        a = np.zeros([50, 50], dtype=np.uint16)
+        b = a.copy().astype(a.dtype)
+        a[20:40, 10:20] = 5
+        b[20:35, 10:20] = 5
+
+        sample = SemanticPair(b, a)
+
+        evaluator = Panoptica_Evaluator(
+            expected_input=SemanticPair,
+            instance_approximator=ConnectedComponentsInstanceApproximator(),
+            instance_matcher=NaiveThresholdMatching(),
+            segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}),
+        )
+
+        result, debug_data = evaluator.evaluate(sample)["organ"]
+        print(result)
+        self.assertEqual(result.tp, 1)
+        self.assertEqual(result.fp, 0)
+        self.assertEqual(result.sq, 0.75)
+        self.assertEqual(result.pq, 0.75)
+
+    def test_single_instance_mode_nooverlap(self):
+        a = np.zeros([50, 50], dtype=np.uint16)
+        b = a.copy().astype(a.dtype)
+        a[20:40, 10:20] = 5
+        b[5:15, 30:50] = 5
+
+        sample = SemanticPair(b, a)
+
+        evaluator = Panoptica_Evaluator(
+            expected_input=SemanticPair,
+            instance_approximator=ConnectedComponentsInstanceApproximator(),
+            instance_matcher=NaiveThresholdMatching(),
+            segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}),
+        )
+
+        result, debug_data = evaluator.evaluate(sample)["organ"]
+        print(result)
+        self.assertEqual(result.tp, 1)
+        self.assertEqual(result.fp, 0)
+        self.assertEqual(result.sq, 0.0)
+        self.assertEqual(result.pq, 0.0)
+        self.assertEqual(result.global_bin_dsc, 0.0)
diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py
index 96c5c64..e88b2b2 100644
--- a/unit_tests/test_panoptic_result.py
+++ b/unit_tests/test_panoptic_result.py
@@ -8,7 +8,10 @@
 import numpy as np
 
 from panoptica.metrics import Metric
-from panoptica.panoptic_result import MetricCouldNotBeComputedException, PanopticaResult
+from panoptica.panoptica_result import (
+    MetricCouldNotBeComputedException,
+    PanopticaResult,
+)
 from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult
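Reading the two single-instance tests above: with single_instance=True, the evaluator wraps the masked arrays in a MatchedInstancePair and drops the decision threshold to 0.0, so the lone prediction/reference pair is matched by construction. That is why even the no-overlap case expects a true positive with zero quality scores (a sketch of the resulting assertions, values taken from the test):

    result, _ = evaluator.evaluate(sample)["organ"]
    assert result.tp == 1     # matched by construction under single-instance mode
    assert result.sq == 0.0   # the forced match has zero IoU
    assert result.pq == 0.0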