From 560b3d05d0eb37153ff953e006402356fb82a746 Mon Sep 17 00:00:00 2001 From: iback Date: Tue, 23 Apr 2024 11:31:11 +0000 Subject: [PATCH 1/4] added init of segmentation classes --- panoptica/utils/segmentation_class.py | 30 +++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 panoptica/utils/segmentation_class.py diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py new file mode 100644 index 0000000..6392520 --- /dev/null +++ b/panoptica/utils/segmentation_class.py @@ -0,0 +1,30 @@ +import numpy as np + + +class ClassGroup: + def __init__( + self, + value_labels: list[int] | int, + single_instance: bool = False, + ) -> None: + """Defines a group of labels that semantically belong to each other + + Args: + value_labels (list[int]): Actually labels in the prediction and reference mask in this group. Defines the labels that can be matched to each other + single_instance (bool, optional): If true, will not use the matching_threshold as there is only one instance (large organ, ...). Defaults to False. + """ + if isinstance(value_labels, int): + value_labels = [value_labels] + self._value_labels = value_labels + assert np.all([v > 0 for v in self._value_labels]), f"Given value labels are not >0, got {value_labels}" + self._single_instance = single_instance + if self._single_instance: + assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" + + @property + def value_labels(self): + return self._value_labels + + @property + def single_instance(self): + return self._single_instance From a5786872a98182c40c3dd9a7a7d52627494be5ba Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 26 Apr 2024 14:59:16 +0000 Subject: [PATCH 2/4] first working concept for example_spine_instance, still needs refinement. 
Add the examples as unittests, make definition of groups easier and also result handling, integrate into semantic example as well --- examples/example_spine_instance.py | 19 +++- examples/example_spine_semantic.py | 1 - panoptica/panoptic_evaluator.py | 94 ++++++++++------- panoptica/panoptic_result.py | 22 +--- panoptica/utils/__init__.py | 4 + panoptica/utils/segmentation_class.py | 141 ++++++++++++++++++++++++-- 6 files changed, 215 insertions(+), 66 deletions(-) diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index 241407e..da8978a 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -5,6 +5,7 @@ from panoptica import MatchedInstancePair, Panoptic_Evaluator from panoptica.metrics import Metric +from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups directory = turbopath(__file__).parent @@ -14,10 +15,21 @@ sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks) +import numpy as np + +print(np.unique(pred_masks)) evaluator = Panoptic_Evaluator( expected_input=MatchedInstancePair, eval_metrics=[Metric.DSC, Metric.IOU], + segmentation_class_groups=SegmentationClassGroups( + { + "vertebra": LabelGroup([i for i in range(1, 10)]), + "ivd": LabelGroup([i for i in range(101, 109)]), + "sacrum": (26, True), + "endplate": LabelGroup([i for i in range(201, 209)]), + } + ), decision_metric=Metric.DSC, decision_threshold=0.5, ) @@ -25,7 +37,10 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(sample, verbose=True) - print(result) + results = evaluator.evaluate(sample, verbose=False) + for groupname, (result, debug) in results.items(): + print() + print("### Group", groupname) + print(result) pr.dump_stats(directory + "/instance_example.log") diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 57e6a04..c96500c 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -17,7 +17,6 @@ sample = SemanticPair(pred_masks, ref_masks) - evaluator = Panoptic_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py index 13d2e5d..339a29e 100644 --- a/panoptica/panoptic_evaluator.py +++ b/panoptica/panoptic_evaluator.py @@ -15,18 +15,19 @@ UnmatchedInstancePair, _ProcessingPair, ) +from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup class Panoptic_Evaluator: def __init__( self, - expected_input: ( - Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] - ) = MatchedInstancePair, + # TODO let users give prediction and reference arr instead of the processing pair, so let this create the processing pair itself + expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, + segmentation_class_groups: SegmentationClassGroups | None = None, eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD], decision_metric: Metric | None = None, decision_threshold: float | None = None, @@ -49,13 +50,11 @@ def __init__( self.__decision_metric = decision_metric self.__decision_threshold = decision_threshold - self.__edge_case_handler = ( - edge_case_handler if 
edge_case_handler is not None else EdgeCaseHandler() - ) + self.__segmentation_class_groups = segmentation_class_groups + + self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() if self.__decision_metric is not None: - assert ( - self.__decision_threshold is not None - ), "decision metric set but no decision threshold for it" + assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -64,34 +63,59 @@ def __init__( @measure_time def evaluate( self, - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, result_all: bool = True, verbose: bool | None = None, ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]: - assert ( - type(processing_pair) == self.__expected_input - ), f"input not of expected type {self.__expected_input}" - return panoptic_evaluate( - processing_pair=processing_pair, - edge_case_handler=self.__edge_case_handler, - instance_approximator=self.__instance_approximator, - instance_matcher=self.__instance_matcher, - eval_metrics=self.__eval_metrics, - decision_metric=self.__decision_metric, - decision_threshold=self.__decision_threshold, - result_all=result_all, - log_times=self.__log_times, - verbose=True if verbose is None else verbose, - verbose_calc=self.__verbose if verbose is None else verbose, - ) + assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}" + + if self.__segmentation_class_groups is None: + return { + "ungrouped": panoptic_evaluate( + processing_pair=processing_pair, + edge_case_handler=self.__edge_case_handler, + instance_approximator=self.__instance_approximator, + instance_matcher=self.__instance_matcher, + eval_metrics=self.__eval_metrics, + decision_metric=self.__decision_metric, + decision_threshold=self.__decision_threshold, + result_all=result_all, + log_times=self.__log_times, + verbose=True if verbose is None else verbose, + verbose_calc=self.__verbose if verbose is None else verbose, + ) + } + + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) + + result_grouped = {} + for group_name in self.__segmentation_class_groups: + label_group = self.__segmentation_class_groups[group_name] + assert isinstance(label_group, LabelGroup) + + prediction_arr_grouped = label_group(processing_pair.prediction_arr) + reference_arr_grouped = label_group(processing_pair.reference_arr) + + processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) + result_grouped[group_name] = panoptic_evaluate( + processing_pair=processing_pair_grouped, + edge_case_handler=self.__edge_case_handler, + instance_approximator=self.__instance_approximator, + instance_matcher=self.__instance_matcher, + eval_metrics=self.__eval_metrics, + decision_metric=self.__decision_metric, + decision_threshold=self.__decision_threshold, + result_all=result_all, + log_times=self.__log_times, + verbose=True if verbose is None else verbose, + verbose_calc=self.__verbose if verbose is None else verbose, + ) + return result_grouped def panoptic_evaluate( - processing_pair: ( - SemanticPair | UnmatchedInstancePair | 
MatchedInstancePair | PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -147,9 +171,7 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert ( - instance_approximator is not None - ), "Got SemanticPair but not InstanceApproximator" + assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -169,9 +191,7 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert ( - instance_matcher is not None - ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py index c037e85..5dfc617 100644 --- a/panoptica/panoptic_result.py +++ b/panoptica/panoptic_result.py @@ -270,9 +270,7 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[k] = Evaluation_List_Metric( - k, empty_list_std, v, is_edge_case, edge_case_result - ) + self._list_metrics[k] = Evaluation_List_Metric(k, empty_list_std, v, is_edge_case, edge_case_result) def _add_metric( self, @@ -341,19 +339,13 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: - return { - k: getattr(self, v.id) - for k, v in self._evaluation_metrics.items() - if (v._error == False and v._was_calculated) - } + return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException( - f"{metric} could not be found, have you set it in eval_metrics during evaluation?" 
- ) + raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -369,9 +361,7 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException( - f"could not find metric with name {metric_name}" - ) + raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") def __getattribute__(self, __name: str) -> Any: attr = None @@ -384,9 +374,7 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException( - f"Requested metric {__name} that could not be computed" - ) + raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) diff --git a/panoptica/utils/__init__.py b/panoptica/utils/__init__.py index 4a4392f..d4dfe79 100644 --- a/panoptica/utils/__init__.py +++ b/panoptica/utils/__init__.py @@ -15,3 +15,7 @@ ) # from utils.constants import +from panoptica.utils.segmentation_class import ( + SegmentationClassGroups, + LabelGroup, +) diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 6392520..bf25a5e 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -1,7 +1,13 @@ import numpy as np -class ClassGroup: +# TODO also support LabelMergedGroup which takes multi labels and convert them into one before the evaluation +# Useful for BraTs with hierarchical labels (then define one generic Group class and then two more specific subgroups, one for hierarchical, the other for the current one) + + +class LabelGroup: + """Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other""" + def __init__( self, value_labels: list[int] | int, @@ -15,16 +21,133 @@ def __init__( """ if isinstance(value_labels, int): value_labels = [value_labels] - self._value_labels = value_labels - assert np.all([v > 0 for v in self._value_labels]), f"Given value labels are not >0, got {value_labels}" - self._single_instance = single_instance - if self._single_instance: + self.__value_labels = value_labels + assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}" + self.__single_instance = single_instance + if self.__single_instance: assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" @property - def value_labels(self): - return self._value_labels + def value_labels(self) -> list[int]: + return self.__value_labels @property - def single_instance(self): - return self._single_instance + def single_instance(self) -> bool: + return self.__single_instance + + def __call__( + self, + array: np.ndarray, + set_to_binary: bool = False, + ) -> np.ndarray: + """Extracts the labels of this class + + Args: + array (np.ndarray): Array to extract the segmentation group labels from + set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. 
+ + Returns: + np.ndarray: Array containing only the labels of this segmentation group + """ + array = array.copy() + array[np.isin(array, self.value_labels, invert=True)] = 0 + if set_to_binary: + array[array != 0] = 1 + return array + + def __str__(self) -> str: + return f"LabelGroup {self.value_labels}, single_instance={self.single_instance}" + + def __repr__(self) -> str: + return str(self) + + +class SegmentationClassGroups: + def __init__( + self, + groups: list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]], + ) -> None: + self.__group_dictionary: dict[str, LabelGroup] = {} + self.__labels: list[int] = [] + # maps name of group to the group itself + + if isinstance(groups, list): + self.__group_dictionary = {f"group_{idx}": g for idx, g in enumerate(groups)} + else: + # transform dict into list of LabelGroups + for i, g in groups.items(): + name_lower = str(i).lower() + if isinstance(g, LabelGroup): + self.__group_dictionary[name_lower] = LabelGroup(g.value_labels, g.single_instance) + else: + self.__group_dictionary[name_lower] = LabelGroup(g[0], g[1]) + + # needs to check that each label is accounted for exactly ONCE + labels = [value_label for lg in self.__group_dictionary.values() for value_label in lg.value_labels] + duplicates = list_duplicates(labels) + if len(duplicates) > 0: + raise AssertionError(f"The same label was assigned to two different labelgroups, got {str(self)}") + + self.__labels = labels + + def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): + if isinstance(arr, list): + arr_labels = arr + else: + arr_labels = [i for i in np.unique(arr) if i != 0] + for al in arr_labels: + if al not in self.__labels: + if raise_error: + raise AssertionError( + f"Input array has labels undefined in the SegmentationClassGroups, got label {al} the groups are defined as {str(self)}" + ) + return False + return True + + def __str__(self) -> str: + text = "SegmentationClassGroups = " + for i, lg in self.__group_dictionary.items(): + text += f"\n - {i} : {str(lg)}" + return text + + def __contains__(self, item): + return item in self.__group_dictionary + + def __getitem__(self, key): + return self.__group_dictionary[key] + + def __iter__(self): + yield from self.__group_dictionary + + +def list_duplicates(seq): + seen = set() + seen_add = seen.add + # adds all elements it doesn't know yet to seen and all other to seen_twice + seen_twice = set(x for x in seq if x in seen or seen_add(x)) + # turn the set into a list (as requested) + return list(seen_twice) + + +if __name__ == "__main__": + group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False) + + print(group1) + print(group1.value_labels) + + arr = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + group1_arr = group1(arr, True) + print(group1_arr) + + classgroups = SegmentationClassGroups( + groups={ + "vertebra": group1, + "ivds": LabelGroup([100, 101, 102]), + } + ) + print(classgroups) + + print(classgroups.has_defined_labels_for([1, 2, 3])) + + for i in classgroups: + print(i) From b900eb5feefeff9224fd0ee441a94e8502bd858a Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 10 Jun 2024 15:07:17 +0000 Subject: [PATCH 3/4] added unittests for definition of segmentation labels. Tweaked some things. Added single_instance mode which disables matching and sets decision threshold to zero as it assumes there is only one instance of this class. renamed files consistently to panoptica. 
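For illustration, a minimal sketch of how a single-instance group is meant to be used with the renamed evaluator (the group name "organ" and label value 5 are only examples, mirroring the new unit tests):

    import numpy as np
    from panoptica import (
        ConnectedComponentsInstanceApproximator,
        NaiveThresholdMatching,
        Panoptica_Evaluator,
        SemanticPair,
    )
    from panoptica.utils.segmentation_class import SegmentationClassGroups

    ref = np.zeros((50, 50), dtype=np.uint16)
    pred = ref.copy()
    ref[20:40, 10:20] = 5   # reference organ
    pred[20:35, 10:20] = 5  # slightly smaller prediction

    evaluator = Panoptica_Evaluator(
        expected_input=SemanticPair,
        instance_approximator=ConnectedComponentsInstanceApproximator(),
        instance_matcher=NaiveThresholdMatching(),
        # (5, True) is shorthand for a LabelGroup with single_instance=True:
        # matching is skipped and the decision threshold is forced to 0 for this group
        segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}),
    )

    result, debug_data = evaluator.evaluate(SemanticPair(pred, ref))["organ"]
    print(result.tp, result.sq)  # 1 and 0.75 in the corresponding unit test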
--- examples/example_spine_instance.py | 8 +- examples/example_spine_semantic.py | 6 +- panoptica/__init__.py | 4 +- panoptica/_functionals.py | 81 --------------- panoptica/instance_evaluator.py | 2 +- panoptica/metrics/metrics.py | 2 +- .../metrics/relative_volume_difference.py | 2 +- ...ic_evaluator.py => panoptica_evaluator.py} | 24 +++-- ...panoptic_result.py => panoptica_result.py} | 0 panoptica/utils/segmentation_class.py | 12 ++- panoptica/{ => utils}/timing.py | 0 unit_tests/test_labelgroup.py | 89 +++++++++++++++++ unit_tests/test_metrics.py | 2 +- unit_tests/test_panoptic_evaluator.py | 98 ++++++++++++++----- unit_tests/test_panoptic_result.py | 2 +- 15 files changed, 197 insertions(+), 135 deletions(-) rename panoptica/{panoptic_evaluator.py => panoptica_evaluator.py} (93%) rename panoptica/{panoptic_result.py => panoptica_result.py} (100%) rename panoptica/{ => utils}/timing.py (100%) create mode 100644 unit_tests/test_labelgroup.py diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index da8978a..f10cb34 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -3,7 +3,7 @@ from auxiliary.nifti.io import read_nifti from auxiliary.turbopath import turbopath -from panoptica import MatchedInstancePair, Panoptic_Evaluator +from panoptica import MatchedInstancePair, Panoptica_Evaluator from panoptica.metrics import Metric from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups @@ -15,11 +15,7 @@ sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks) -import numpy as np - -print(np.unique(pred_masks)) - -evaluator = Panoptic_Evaluator( +evaluator = Panoptica_Evaluator( expected_input=MatchedInstancePair, eval_metrics=[Metric.DSC, Metric.IOU], segmentation_class_groups=SegmentationClassGroups( diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index c96500c..7385701 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -6,7 +6,7 @@ from panoptica import ( ConnectedComponentsInstanceApproximator, NaiveThresholdMatching, - Panoptic_Evaluator, + Panoptica_Evaluator, SemanticPair, ) @@ -17,7 +17,7 @@ sample = SemanticPair(pred_masks, ref_masks) -evaluator = Panoptic_Evaluator( +evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), @@ -26,7 +26,7 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) pr.dump_stats(directory + "/semantic_example.log") diff --git a/panoptica/__init__.py b/panoptica/__init__.py index ada7cd0..9a4d080 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -3,8 +3,8 @@ CCABackend, ) from panoptica.instance_matcher import NaiveThresholdMatching -from panoptica.panoptic_evaluator import Panoptic_Evaluator -from panoptica.panoptic_result import PanopticaResult +from panoptica.panoptica_evaluator import Panoptica_Evaluator +from panoptica.panoptica_result import PanopticaResult from panoptica.utils.processing_pair import ( SemanticPair, UnmatchedInstancePair, diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index ebcf108..eee9a49 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -79,87 +79,6 @@ def _calc_matching_metric_of_overlapping_labels( return mm_pairs -def 
_calc_iou_of_overlapping_labels( - prediction_arr: np.ndarray, - reference_arr: np.ndarray, - ref_labels: tuple[int, ...], - **kwargs, -) -> list[tuple[float, tuple[int, int]]]: - """Calculates the IOU for all overlapping labels (fast!) - - Args: - prediction_arr (np.ndarray): Numpy array containing the prediction labels. - reference_arr (np.ndarray): Numpy array containing the reference labels. - ref_labels (list[int]): List of unique reference labels. - pred_labels (list[int]): List of unique prediction labels. - - Returns: - list[tuple[float, tuple[int, int]]]: List of pairs in style: (iou, (ref_label, pred_label)) - """ - instance_pairs = [ - (reference_arr, prediction_arr, i[0], i[1]) - for i in _calc_overlapping_labels( - prediction_arr=prediction_arr, - reference_arr=reference_arr, - ref_labels=ref_labels, - ) - ] - with Pool() as pool: - iou_values = pool.starmap(_compute_instance_iou, instance_pairs) - - iou_pairs = [ - (i, (instance_pairs[idx][2], instance_pairs[idx][3])) - for idx, i in enumerate(iou_values) - ] - iou_pairs = sorted(iou_pairs, key=lambda x: x[0], reverse=True) - - return iou_pairs - - -def _calc_iou_matrix( - prediction_arr: np.ndarray, - reference_arr: np.ndarray, - ref_labels: tuple[int, ...], - pred_labels: tuple[int, ...], -): - """ - Calculate the Intersection over Union (IoU) matrix between reference and prediction arrays. - - Args: - prediction_arr (np.ndarray): Numpy array containing the prediction labels. - reference_arr (np.ndarray): Numpy array containing the reference labels. - ref_labels (list[int]): List of unique reference labels. - pred_labels (list[int]): List of unique prediction labels. - - Returns: - np.ndarray: IoU matrix where each element represents the IoU between a reference and prediction instance. - - Example: - >>> _calc_iou_matrix(np.array([1, 2, 3]), np.array([4, 5, 6]), [1, 2, 3], [4, 5, 6]) - array([[0. , 0. , 0. ], - [0. , 0. , 0. ], - [0. , 0. , 0. 
]]) - """ - num_ref_instances = len(ref_labels) - num_pred_instances = len(pred_labels) - - # Create a pool of worker processes to parallelize the computation - with Pool() as pool: - # # Generate all possible pairs of instance indices for IoU computation - instance_pairs = [ - (reference_arr, prediction_arr, ref_idx, pred_idx) - for ref_idx in ref_labels - for pred_idx in pred_labels - ] - - # Calculate IoU for all instance pairs in parallel using starmap - iou_values = pool.starmap(_compute_instance_iou, instance_pairs) - - # Reshape the resulting IoU values into a matrix - iou_matrix = np.array(iou_values).reshape((num_ref_instances, num_pred_instances)) - return iou_matrix - - def _map_labels( arr: np.ndarray, label_map: dict[np.integer, np.integer], diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py index a379bf3..383df0e 100644 --- a/panoptica/instance_evaluator.py +++ b/panoptica/instance_evaluator.py @@ -3,7 +3,7 @@ import numpy as np from panoptica.metrics import Metric -from panoptica.panoptic_result import PanopticaResult +from panoptica.panoptica_result import PanopticaResult from panoptica.utils import EdgeCaseHandler from panoptica.utils.processing_pair import MatchedInstancePair diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index b17630c..2d6a4c0 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -15,7 +15,7 @@ from panoptica.utils.constants import _Enum_Compare, auto if TYPE_CHECKING: - from panoptic_result import PanopticaResult + from panoptica.panoptica_result import PanopticaResult @dataclass diff --git a/panoptica/metrics/relative_volume_difference.py b/panoptica/metrics/relative_volume_difference.py index bbb131b..4d952b2 100644 --- a/panoptica/metrics/relative_volume_difference.py +++ b/panoptica/metrics/relative_volume_difference.py @@ -66,5 +66,5 @@ def _compute_relative_volume_difference( return 0.0 # Calculate Dice coefficient - rvd = (prediction_mask - reference_mask) / reference_mask + rvd = float(prediction_mask - reference_mask) / reference_mask return rvd diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptica_evaluator.py similarity index 93% rename from panoptica/panoptic_evaluator.py rename to panoptica/panoptica_evaluator.py index 339a29e..452e509 100644 --- a/panoptica/panoptic_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -5,8 +5,8 @@ from panoptica.instance_evaluator import evaluate_matched_instance from panoptica.instance_matcher import InstanceMatchingAlgorithm from panoptica.metrics import Metric, _Metric -from panoptica.panoptic_result import PanopticaResult -from panoptica.timing import measure_time +from panoptica.panoptica_result import PanopticaResult +from panoptica.utils.timing import measure_time from panoptica.utils import EdgeCaseHandler from panoptica.utils.citation_reminder import citation_reminder from panoptica.utils.processing_pair import ( @@ -18,7 +18,7 @@ from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup -class Panoptic_Evaluator: +class Panoptica_Evaluator: def __init__( self, @@ -66,7 +66,7 @@ def evaluate( processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, result_all: bool = True, verbose: bool | None = None, - ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]: + ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}" if 
self.__segmentation_class_groups is None: @@ -90,14 +90,22 @@ def evaluate( self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) result_grouped = {} - for group_name in self.__segmentation_class_groups: - label_group = self.__segmentation_class_groups[group_name] + for group_name, label_group in self.__segmentation_class_groups.items(): assert isinstance(label_group, LabelGroup) prediction_arr_grouped = label_group(processing_pair.prediction_arr) reference_arr_grouped = label_group(processing_pair.reference_arr) - processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) + single_instance_mode = label_group.single_instance + processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore + decision_threshold = self.__decision_threshold + if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): + processing_pair_grouped = MatchedInstancePair( + prediction_arr=processing_pair_grouped.prediction_arr, + reference_arr=processing_pair_grouped.reference_arr, + ) + decision_threshold = 0.0 + result_grouped[group_name] = panoptic_evaluate( processing_pair=processing_pair_grouped, edge_case_handler=self.__edge_case_handler, @@ -105,7 +113,7 @@ def evaluate( instance_matcher=self.__instance_matcher, eval_metrics=self.__eval_metrics, decision_metric=self.__decision_metric, - decision_threshold=self.__decision_threshold, + decision_threshold=decision_threshold, result_all=result_all, log_times=self.__log_times, verbose=True if verbose is None else verbose, diff --git a/panoptica/panoptic_result.py b/panoptica/panoptica_result.py similarity index 100% rename from panoptica/panoptic_result.py rename to panoptica/panoptica_result.py diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index bf25a5e..a9d4f47 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -1,10 +1,6 @@ import numpy as np -# TODO also support LabelMergedGroup which takes multi labels and convert them into one before the evaluation -# Useful for BraTs with hierarchical labels (then define one generic Group class and then two more specific subgroups, one for hierarchical, the other for the current one) - - class LabelGroup: """Defines a group of labels that semantically belong to each other. 
Only labels within a group will be matched with each other""" @@ -21,6 +17,7 @@ def __init__( """ if isinstance(value_labels, int): value_labels = [value_labels] + assert len(value_labels) >= 1, f"You tried to define a LabelGroup without any specified labels, got {value_labels}" self.__value_labels = value_labels assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}" self.__single_instance = single_instance @@ -119,6 +116,13 @@ def __getitem__(self, key): def __iter__(self): yield from self.__group_dictionary + def keys(self) -> list[str]: + return list(self.__group_dictionary.keys()) + + def items(self): + for k in self: + yield k, self[k] + def list_duplicates(seq): seen = set() diff --git a/panoptica/timing.py b/panoptica/utils/timing.py similarity index 100% rename from panoptica/timing.py rename to panoptica/utils/timing.py diff --git a/unit_tests/test_labelgroup.py b/unit_tests/test_labelgroup.py new file mode 100644 index 0000000..292629a --- /dev/null +++ b/unit_tests/test_labelgroup.py @@ -0,0 +1,89 @@ +# Call 'python -m unittest' on this folder +# coverage run -m unittest +# coverage report +# coverage html +import os +import unittest +import numpy as np + +from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups + + +class Test_DefinitionOfSegmentationLabels(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_labelgroup(self): + group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False) + + print(group1) + arr = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + group1_arr = group1(arr, True) + + print(group1_arr) + self.assertEqual(group1_arr.sum(), 5) + + group1_arr_ind = np.argwhere(group1_arr).flatten() + print(group1_arr_ind) + group1_labels = np.asarray(group1.value_labels) + print(group1_labels) + self.assertTrue(np.all(group1_arr_ind == group1_labels)) + + def test_labelgroup_notpresent(self): + group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False) + + print(group1) + arr = np.array([0, 6, 7, 8, 0, 15, 6, 7, 8, 9, 10]) + group1_arr = group1(arr, True) + + print(group1_arr) + self.assertEqual(group1_arr.sum(), 0) + + group1_arr_ind = np.argwhere(group1_arr).flatten() + self.assertEqual(len(group1_arr_ind), 0) + + def test_wrong_labelgroup_definitions(self): + + with self.assertRaises(AssertionError): + group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=True) + + with self.assertRaises(AssertionError): + group1 = LabelGroup([], single_instance=False) + + with self.assertRaises(AssertionError): + group1 = LabelGroup([1, 0, -1, 5], single_instance=False) + + def test_segmentationclassgroup_easy(self): + group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False) + classgroups = SegmentationClassGroups( + groups={ + "vertebra": group1, + "ivds": LabelGroup([100, 101, 102]), + } + ) + + print(classgroups) + + self.assertTrue(classgroups.has_defined_labels_for([1, 2, 3])) + + self.assertTrue(classgroups.has_defined_labels_for([1, 100, 3])) + + self.assertFalse(classgroups.has_defined_labels_for([1, 99, 3])) + + self.assertTrue("ivds" in classgroups) + + for i in classgroups: + self.assertTrue(i in ["vertebra", "ivds"]) + + for i, lg in classgroups.items(): + print(i, lg) + self.assertTrue(isinstance(i, str)) + self.assertTrue(isinstance(lg, LabelGroup)) + + def test_segmentationclassgroup_decarations(self): + classgroups = SegmentationClassGroups(groups=[LabelGroup(i) for i in range(1, 5)]) + + keys = 
classgroups.keys() + for i in range(1, 5): + self.assertTrue(f"group_{i-1}" in keys, f"not {i} in {keys}") diff --git a/unit_tests/test_metrics.py b/unit_tests/test_metrics.py index 729b321..621407b 100644 --- a/unit_tests/test_metrics.py +++ b/unit_tests/test_metrics.py @@ -8,7 +8,7 @@ import numpy as np from panoptica.metrics import Metric -from panoptica.panoptic_result import MetricCouldNotBeComputedException, PanopticaResult +from panoptica.panoptica_result import MetricCouldNotBeComputedException, PanopticaResult from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index 00245c0..a2cf755 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -10,9 +10,10 @@ from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator from panoptica.instance_matcher import MaximizeMergeMatching, NaiveThresholdMatching from panoptica.metrics import Metric -from panoptica.panoptic_evaluator import Panoptic_Evaluator -from panoptica.panoptic_result import MetricCouldNotBeComputedException +from panoptica.panoptica_evaluator import Panoptica_Evaluator +from panoptica.panoptica_result import MetricCouldNotBeComputedException from panoptica.utils.processing_pair import SemanticPair +from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup class Test_Panoptic_Evaluator(unittest.TestCase): @@ -28,13 +29,13 @@ def test_simple_evaluation(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -49,13 +50,13 @@ def test_simple_evaluation_DSC(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -70,14 +71,14 @@ def test_simple_evaluation_DSC_partial(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC), eval_metrics=[Metric.DSC], ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -96,7 +97,7 @@ def test_simple_evaluation_ASSD(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching( @@ -105,7 +106,7 @@ def test_simple_evaluation_ASSD(self): ), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ 
-120,7 +121,7 @@ def test_simple_evaluation_ASSD_negative(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching( @@ -129,7 +130,7 @@ def test_simple_evaluation_ASSD_negative(self): ), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -145,13 +146,13 @@ def test_pred_empty(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 0) @@ -168,13 +169,13 @@ def test_ref_empty(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -191,13 +192,13 @@ def test_both_empty(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 0) @@ -229,13 +230,13 @@ def test_dtype_evaluation(self): else: sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -250,13 +251,13 @@ def test_simple_evaluation_maximize_matcher(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -272,13 +273,13 @@ def test_simple_evaluation_maximize_matcher_overlaptwo(self): sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -296,13 +297,13 @@ def test_simple_evaluation_maximize_matcher_overlap(self): 
sample = SemanticPair(b, a) - evaluator = Panoptic_Evaluator( + evaluator = Panoptica_Evaluator( expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample) + result, debug_data = evaluator.evaluate(sample)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 1) @@ -310,3 +311,48 @@ def test_simple_evaluation_maximize_matcher_overlap(self): self.assertAlmostEqual(result.pq, 0.56666666) self.assertAlmostEqual(result.rq, 0.66666666) self.assertAlmostEqual(result.sq_dsc, 0.9189189189189) + + def test_single_instance_mode(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 5 + b[20:35, 10:20] = 5 + + sample = SemanticPair(b, a) + + evaluator = Panoptica_Evaluator( + expected_input=SemanticPair, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=NaiveThresholdMatching(), + segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}), + ) + + result, debug_data = evaluator.evaluate(sample)["organ"] + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.sq, 0.75) + self.assertEqual(result.pq, 0.75) + + def test_single_instance_mode_nooverlap(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 5 + b[5:15, 30:50] = 5 + + sample = SemanticPair(b, a) + + evaluator = Panoptica_Evaluator( + expected_input=SemanticPair, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=NaiveThresholdMatching(), + segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}), + ) + + result, debug_data = evaluator.evaluate(sample)["organ"] + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.sq, 0.0) + self.assertEqual(result.pq, 0.0) + self.assertEqual(result.global_bin_dsc, 0.0) diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py index 96c5c64..28d3554 100644 --- a/unit_tests/test_panoptic_result.py +++ b/unit_tests/test_panoptic_result.py @@ -8,7 +8,7 @@ import numpy as np from panoptica.metrics import Metric -from panoptica.panoptic_result import MetricCouldNotBeComputedException, PanopticaResult +from panoptica.panoptica_result import MetricCouldNotBeComputedException, PanopticaResult from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult From ea3fd6ff775b4b5a03de6d7b2a0d55650c295329 Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 15:08:49 +0000 Subject: [PATCH 4/4] Autoformat with black --- panoptica/panoptica_evaluator.py | 44 ++++++++++++++++++++------- panoptica/panoptica_result.py | 22 +++++++++++--- panoptica/utils/segmentation_class.py | 34 ++++++++++++++++----- unit_tests/test_labelgroup.py | 4 ++- unit_tests/test_metrics.py | 5 ++- unit_tests/test_panoptic_result.py | 5 ++- 6 files changed, 87 insertions(+), 27 deletions(-) diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index 452e509..ca59b98 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -23,7 +23,9 @@ class Panoptica_Evaluator: def __init__( self, # TODO let users give prediction and reference arr instead of the processing pair, so let this create the processing pair itself - expected_input: 
Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair, + expected_input: ( + Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] + ) = MatchedInstancePair, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, @@ -52,9 +54,13 @@ def __init__( self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + self.__edge_case_handler = ( + edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + ) if self.__decision_metric is not None: - assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" + assert ( + self.__decision_threshold is not None + ), "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -63,11 +69,15 @@ def __init__( @measure_time def evaluate( self, - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + processing_pair: ( + SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult + ), result_all: bool = True, verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: - assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}" + assert ( + type(processing_pair) == self.__expected_input + ), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { @@ -86,8 +96,12 @@ def evaluate( ) } - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.prediction_arr, raise_error=True + ) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.reference_arr, raise_error=True + ) result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -99,7 +113,9 @@ def evaluate( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): + if single_instance_mode and not isinstance( + processing_pair, MatchedInstancePair + ): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -123,7 +139,9 @@ def evaluate( def panoptic_evaluate( - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + processing_pair: ( + SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult + ), instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -179,7 +197,9 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" + assert ( + instance_approximator is 
not None + ), "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -199,7 +219,9 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert ( + instance_matcher is not None + ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index 5dfc617..c037e85 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -270,7 +270,9 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[k] = Evaluation_List_Metric(k, empty_list_std, v, is_edge_case, edge_case_result) + self._list_metrics[k] = Evaluation_List_Metric( + k, empty_list_std, v, is_edge_case, edge_case_result + ) def _add_metric( self, @@ -339,13 +341,19 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: - return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} + return { + k: getattr(self, v.id) + for k, v in self._evaluation_metrics.items() + if (v._error == False and v._was_calculated) + } def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") + raise MetricCouldNotBeComputedException( + f"{metric} could not be found, have you set it in eval_metrics during evaluation?" 
+ ) def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -361,7 +369,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") + raise MetricCouldNotBeComputedException( + f"could not find metric with name {metric_name}" + ) def __getattribute__(self, __name: str) -> Any: attr = None @@ -374,7 +384,9 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") + raise MetricCouldNotBeComputedException( + f"Requested metric {__name} that could not be computed" + ) elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index a9d4f47..4535ad3 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -17,12 +17,18 @@ def __init__( """ if isinstance(value_labels, int): value_labels = [value_labels] - assert len(value_labels) >= 1, f"You tried to define a LabelGroup without any specified labels, got {value_labels}" + assert ( + len(value_labels) >= 1 + ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" self.__value_labels = value_labels - assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}" + assert np.all( + [v > 0 for v in self.__value_labels] + ), f"Given value labels are not >0, got {value_labels}" self.__single_instance = single_instance if self.__single_instance: - assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" + assert ( + len(value_labels) == 1 + ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" @property def value_labels(self) -> list[int]: @@ -69,25 +75,37 @@ def __init__( # maps name of group to the group itself if isinstance(groups, list): - self.__group_dictionary = {f"group_{idx}": g for idx, g in enumerate(groups)} + self.__group_dictionary = { + f"group_{idx}": g for idx, g in enumerate(groups) + } else: # transform dict into list of LabelGroups for i, g in groups.items(): name_lower = str(i).lower() if isinstance(g, LabelGroup): - self.__group_dictionary[name_lower] = LabelGroup(g.value_labels, g.single_instance) + self.__group_dictionary[name_lower] = LabelGroup( + g.value_labels, g.single_instance + ) else: self.__group_dictionary[name_lower] = LabelGroup(g[0], g[1]) # needs to check that each label is accounted for exactly ONCE - labels = [value_label for lg in self.__group_dictionary.values() for value_label in lg.value_labels] + labels = [ + value_label + for lg in self.__group_dictionary.values() + for value_label in lg.value_labels + ] duplicates = list_duplicates(labels) if len(duplicates) > 0: - raise AssertionError(f"The same label was assigned to two different labelgroups, got {str(self)}") + raise AssertionError( + f"The same label was assigned to two different labelgroups, got {str(self)}" + ) self.__labels = labels - def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): + def has_defined_labels_for( + self, arr: 
np.ndarray | list[int], raise_error: bool = False + ): if isinstance(arr, list): arr_labels = arr else: diff --git a/unit_tests/test_labelgroup.py b/unit_tests/test_labelgroup.py index 292629a..4aee9f3 100644 --- a/unit_tests/test_labelgroup.py +++ b/unit_tests/test_labelgroup.py @@ -82,7 +82,9 @@ def test_segmentationclassgroup_easy(self): self.assertTrue(isinstance(lg, LabelGroup)) def test_segmentationclassgroup_decarations(self): - classgroups = SegmentationClassGroups(groups=[LabelGroup(i) for i in range(1, 5)]) + classgroups = SegmentationClassGroups( + groups=[LabelGroup(i) for i in range(1, 5)] + ) keys = classgroups.keys() for i in range(1, 5): diff --git a/unit_tests/test_metrics.py b/unit_tests/test_metrics.py index 621407b..2800187 100644 --- a/unit_tests/test_metrics.py +++ b/unit_tests/test_metrics.py @@ -8,7 +8,10 @@ import numpy as np from panoptica.metrics import Metric -from panoptica.panoptica_result import MetricCouldNotBeComputedException, PanopticaResult +from panoptica.panoptica_result import ( + MetricCouldNotBeComputedException, + PanopticaResult, +) from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py index 28d3554..e88b2b2 100644 --- a/unit_tests/test_panoptic_result.py +++ b/unit_tests/test_panoptic_result.py @@ -8,7 +8,10 @@ import numpy as np from panoptica.metrics import Metric -from panoptica.panoptica_result import MetricCouldNotBeComputedException, PanopticaResult +from panoptica.panoptica_result import ( + MetricCouldNotBeComputedException, + PanopticaResult, +) from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult
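As a closing illustration of the label-grouping primitives introduced above (label values and group names are only examples, following the new unit_tests/test_labelgroup.py):

    import numpy as np
    from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups

    group = LabelGroup([1, 2, 3, 4, 5])                  # labels that may be matched against each other
    arr = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    print(group(arr))                                    # labels outside the group are zeroed out
    print(group(arr, set_to_binary=True))                # optionally binarize the extracted mask

    groups = SegmentationClassGroups(
        {
            "vertebra": group,
            "ivd": LabelGroup([101, 102, 103]),
            "sacrum": (26, True),                        # shorthand for a single-instance LabelGroup
        }
    )
    print(groups.has_defined_labels_for([1, 26, 103]))   # True
    print(groups.has_defined_labels_for([1, 99]))        # False, label 99 is not in any group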