diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py index 2855add..56c61e4 100644 --- a/examples/example_spine_instance_config.py +++ b/examples/example_spine_instance_config.py @@ -10,7 +10,9 @@ reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") -evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance") +evaluator = Panoptica_Evaluator.load_from_config_name( + "panoptica_evaluator_unmatched_instance" +) with cProfile.Profile() as pr: diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 2793592..a2e5a32 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -25,7 +25,9 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"] + result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[ + "ungrouped" + ] print(result) pr.dump_stats(directory + "/semantic_example.log") diff --git a/panoptica/instance_approximator.py b/panoptica/instance_approximator.py index 957b934..f8e061f 100644 --- a/panoptica/instance_approximator.py +++ b/panoptica/instance_approximator.py @@ -58,7 +58,9 @@ def _approximate_instances( pass def _yaml_repr(cls, node) -> dict: - raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}") + raise NotImplementedError( + f"Tried to get yaml representation of abstract class {cls.__name__}" + ) return {} def approximate_instances( @@ -146,7 +148,9 @@ def _approximate_instances( """ cca_backend = self.cca_backend if cca_backend is None: - cca_backend = CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy + cca_backend = ( + CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy + ) assert cca_backend is not None empty_prediction = len(semantic_pair._pred_labels) == 0 diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 6c18249..5bedd9e 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -81,7 +81,9 @@ def match_instances( return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap) def _yaml_repr(cls, node) -> dict: - raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}") + raise NotImplementedError( + f"Tried to get yaml representation of abstract class {cls.__name__}" + ) return {} @@ -199,13 +201,20 @@ def _match_instances( unmatched_instance_pair.prediction_arr, unmatched_instance_pair.reference_arr, ) - mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric) + mm_pairs = _calc_matching_metric_of_overlapping_labels( + pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric + ) # Loop through matched instances to compute PQ components for matching_score, (ref_label, pred_label) in mm_pairs: - if labelmap.contains_or(pred_label, ref_label) and not self._allow_many_to_one: + if ( + labelmap.contains_or(pred_label, ref_label) + and not self._allow_many_to_one + ): continue # -> doesnt make speed difference - if self._matching_metric.score_beats_threshold(matching_score, self._matching_threshold): + if self._matching_metric.score_beats_threshold( + matching_score, self._matching_threshold + ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) # map label ref_idx to pred_idx @@ -296,7 +305,9 @@ def _match_instances( if new_score > score_ref[ref_label]: labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = new_score - elif self._matching_metric.score_beats_threshold(matching_score, self._matching_threshold): + elif self._matching_metric.score_beats_threshold( + matching_score, self._matching_threshold + ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = matching_score diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index 7423592..fd5617c 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -54,9 +54,13 @@ def __init__( self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + self.__edge_case_handler = ( + edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + ) if self.__decision_metric is not None: - assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" + assert ( + self.__decision_threshold is not None + ), "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -86,7 +90,9 @@ def evaluate( verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: processing_pair = self.__expected_input(prediction_arr, reference_arr) - assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" + assert isinstance( + processing_pair, self.__expected_input.value + ), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { @@ -105,8 +111,12 @@ def evaluate( ) } - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.prediction_arr, raise_error=True + ) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.reference_arr, raise_error=True + ) result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -118,7 +128,9 @@ def evaluate( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): + if single_instance_mode and not isinstance( + processing_pair, MatchedInstancePair + ): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -142,7 +154,9 @@ def evaluate( def panoptic_evaluate( - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + processing_pair: ( + SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult + ), instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -198,7 +212,9 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" + assert ( + instance_approximator is not None + ), "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -218,7 +234,9 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert ( + instance_matcher is not None + ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index 3afc165..9545cae 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -71,13 +71,17 @@ def save_from_object(cls, obj: object, file: str | Path): @classmethod def load(cls, file: str | Path, registered_class=None): data = _load_yaml(file, registered_class) - assert isinstance(data, dict), f"The config at {file} is registered to a class. Use load_as_object() instead" + assert isinstance( + data, dict + ), f"The config at {file} is registered to a class. Use load_as_object() instead" return Configuration(data, registered_class=registered_class) @classmethod def load_as_object(cls, file: str | Path, registered_class=None): data = _load_yaml(file, registered_class) - assert not isinstance(data, dict), f"The config at {file} is not registered to a class. Use load() instead" + assert not isinstance( + data, dict + ), f"The config at {file} is not registered to a class. Use load() instead" return data def save(self, out_file: str | Path): @@ -148,7 +152,9 @@ def _register_permanently(cls): @classmethod def load_from_config(cls, path: str | Path): obj = _load_from_config(cls, path) - assert isinstance(obj, cls), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" + assert isinstance( + obj, cls + ), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" return obj @classmethod @@ -163,7 +169,9 @@ def save_to_config(self, path: str | Path): @classmethod def to_yaml(cls, representer, node): # cls._register_permanently() - assert hasattr(cls, "_yaml_repr"), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" + assert hasattr( + cls, "_yaml_repr" + ), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node)) @classmethod diff --git a/panoptica/utils/constants.py b/panoptica/utils/constants.py index 2889ad6..ecb7f63 100644 --- a/panoptica/utils/constants.py +++ b/panoptica/utils/constants.py @@ -1,5 +1,10 @@ from enum import Enum, auto -from panoptica.utils.config import _register_class_to_yaml, _load_from_config, _load_from_config_name, _save_to_config +from panoptica.utils.config import ( + _register_class_to_yaml, + _load_from_config, + _load_from_config_name, + _save_to_config, +) from pathlib import Path import numpy as np diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index dfe2ca2..1a865f0 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -41,12 +41,26 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( + empty_prediction_result + if empty_prediction_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( + empty_reference_result + if empty_reference_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( + no_instances_result if no_instances_result is not None else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( + normal if normal is not None else default_result + ) - def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: + def __call__( + self, tp: int, num_pred_instances, num_ref_instances + ) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -117,7 +131,9 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[ + Metric, MetricZeroTPEdgeCaseHandling + ] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -130,7 +146,9 @@ def handle_zero_tp( if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") + raise NotImplementedError( + f"Metric {metric} encountered zero TP, but no edge handling available" + ) return self.__listmetric_zeroTP_handling[metric]( tp=tp, @@ -167,7 +185,9 @@ def _yaml_repr(cls, node) -> dict: print() # print(handler.get_metric_zero_tp_handle(ListMetric.IOU)) - r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1) + r = handler.handle_zero_tp( + Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1 + ) print(r) iou_test = MetricZeroTPEdgeCaseHandling( diff --git a/panoptica/utils/filepath.py b/panoptica/utils/filepath.py index f668b75..a0d8c25 100644 --- a/panoptica/utils/filepath.py +++ b/panoptica/utils/filepath.py @@ -4,7 +4,9 @@ from pathlib import Path -def search_path(basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False) -> list[Path]: +def search_path( + basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False +) -> list[Path]: """Searches from basepath with query Args: basepath: ground path to look into @@ -16,7 +18,9 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres All found paths """ basepath = str(basepath) - assert os.path.exists(basepath), f"basepath for search_path() doesnt exist, got {basepath}" + assert os.path.exists( + basepath + ), f"basepath for search_path() doesnt exist, got {basepath}" if not basepath.endswith("/"): basepath += "/" print(f"search_path: in {basepath}{query}") if verbose else None @@ -28,9 +32,16 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres # Find config path def config_by_name(name: str) -> Path: - directory = Path(__file__.replace("////", "/").replace("\\\\", "/").replace("//", "/").replace("\\", "/")).parent.parent + directory = Path( + __file__.replace("////", "/") + .replace("\\\\", "/") + .replace("//", "/") + .replace("\\", "/") + ).parent.parent if not name.endswith(".yaml"): name += ".yaml" p = search_path(directory, query=f"**/{name}", suppress=True) - assert len(p) == 1, f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" + assert ( + len(p) == 1 + ), f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" return p[0] diff --git a/panoptica/utils/instancelabelmap.py b/panoptica/utils/instancelabelmap.py index 7480875..16fd33c 100644 --- a/panoptica/utils/instancelabelmap.py +++ b/panoptica/utils/instancelabelmap.py @@ -13,7 +13,9 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): if not isinstance(pred_labels, list): pred_labels = [pred_labels] assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" - assert np.all([isinstance(r, int) for r in pred_labels]), "add_labelmap_entry: got no int as pred_label" + assert np.all( + [isinstance(r, int) for r in pred_labels] + ), "add_labelmap_entry: got no int as pred_label" for p in pred_labels: if p in self.labelmap and self.labelmap[p] != ref_label: raise Exception( @@ -30,12 +32,16 @@ def contains_pred(self, pred_label: int): def contains_ref(self, ref_label: int): return ref_label in self.labelmap.values() - def contains_and(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + def contains_and( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in and ref_in - def contains_or(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + def contains_or( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in or ref_in @@ -47,7 +53,9 @@ def __str__(self) -> str: return str( list( [ - str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + " -> " + str(v) + str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + + " -> " + + str(v) for v in set(self.labelmap.values()) ] ) diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py index 6757500..430e2dc 100644 --- a/panoptica/utils/label_group.py +++ b/panoptica/utils/label_group.py @@ -20,12 +20,18 @@ def __init__( """ if isinstance(value_labels, int): value_labels = [value_labels] - assert len(value_labels) >= 1, f"You tried to define a LabelGroup without any specified labels, got {value_labels}" + assert ( + len(value_labels) >= 1 + ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" self.__value_labels = value_labels - assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}" + assert np.all( + [v > 0 for v in self.__value_labels] + ), f"Given value labels are not >0, got {value_labels}" self.__single_instance = single_instance if self.__single_instance: - assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" + assert ( + len(value_labels) == 1 + ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" LabelGroup._register_permanently() @@ -65,7 +71,10 @@ def __repr__(self) -> str: @classmethod def _yaml_repr(cls, node): - return {"value_labels": node.value_labels, "single_instance": node.single_instance} + return { + "value_labels": node.value_labels, + "single_instance": node.single_instance, + } # @classmethod # def to_yaml(cls, representer, node): diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 938a8bc..c64b1e9 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -320,5 +320,7 @@ class InputType(_Enum_Compare): UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + def __call__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray + ) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 24e7568..550b6aa 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -15,24 +15,36 @@ def __init__( # maps name of group to the group itself if isinstance(groups, list): - self.__group_dictionary = {f"group_{idx}": g for idx, g in enumerate(groups)} + self.__group_dictionary = { + f"group_{idx}": g for idx, g in enumerate(groups) + } elif isinstance(groups, dict): # transform dict into list of LabelGroups for i, g in groups.items(): name_lower = str(i).lower() if isinstance(g, LabelGroup): - self.__group_dictionary[name_lower] = LabelGroup(g.value_labels, g.single_instance) + self.__group_dictionary[name_lower] = LabelGroup( + g.value_labels, g.single_instance + ) else: self.__group_dictionary[name_lower] = LabelGroup(g[0], g[1]) # needs to check that each label is accounted for exactly ONCE - labels = [value_label for lg in self.__group_dictionary.values() for value_label in lg.value_labels] + labels = [ + value_label + for lg in self.__group_dictionary.values() + for value_label in lg.value_labels + ] duplicates = list_duplicates(labels) if len(duplicates) > 0: - raise AssertionError(f"The same label was assigned to two different labelgroups, got {str(self)}") + raise AssertionError( + f"The same label was assigned to two different labelgroups, got {str(self)}" + ) self.__labels = labels - def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): + def has_defined_labels_for( + self, arr: np.ndarray | list[int], raise_error: bool = False + ): if isinstance(arr, list): arr_labels = arr else: diff --git a/unit_tests/test_config.py b/unit_tests/test_config.py index 5d516db..59c0b46 100644 --- a/unit_tests/test_config.py +++ b/unit_tests/test_config.py @@ -14,7 +14,12 @@ ) from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup from panoptica.utils.constants import CCABackend -from panoptica.utils.edge_case_handling import EdgeCaseResult, EdgeCaseZeroTP, MetricZeroTPEdgeCaseHandling, EdgeCaseHandler +from panoptica.utils.edge_case_handling import ( + EdgeCaseResult, + EdgeCaseZeroTP, + MetricZeroTPEdgeCaseHandling, + EdgeCaseHandler, +) from panoptica import ConnectedComponentsInstanceApproximator, NaiveThresholdMatching from pathlib import Path import numpy as np @@ -90,7 +95,9 @@ def test_InstanceApproximator_config(self): print(t) print() t.save_to_config(test_file) - d: ConnectedComponentsInstanceApproximator = ConnectedComponentsInstanceApproximator.load_from_config(test_file) + d: ConnectedComponentsInstanceApproximator = ( + ConnectedComponentsInstanceApproximator.load_from_config(test_file) + ) os.remove(test_file) self.assertEqual(d.cca_backend, t.cca_backend) @@ -99,11 +106,17 @@ def test_NaiveThresholdMatching_config(self): for mm in [Metric.DSC, Metric.IOU, Metric.ASSD]: for mt in [0.1, 0.4, 0.5, 0.8, 1.0]: for amto in [False, True]: - t = NaiveThresholdMatching(matching_metric=mm, matching_threshold=mt, allow_many_to_one=amto) + t = NaiveThresholdMatching( + matching_metric=mm, + matching_threshold=mt, + allow_many_to_one=amto, + ) print(t) print() t.save_to_config(test_file) - d: NaiveThresholdMatching = NaiveThresholdMatching.load_from_config(test_file) + d: NaiveThresholdMatching = NaiveThresholdMatching.load_from_config( + test_file + ) os.remove(test_file) self.assertEqual(d._allow_many_to_one, t._allow_many_to_one) @@ -118,7 +131,9 @@ def test_MetricZeroTPEdgeCaseHandling_config(self): print(t) print() t.save_to_config(test_file) - d: MetricZeroTPEdgeCaseHandling = MetricZeroTPEdgeCaseHandling.load_from_config(test_file) + d: MetricZeroTPEdgeCaseHandling = ( + MetricZeroTPEdgeCaseHandling.load_from_config(test_file) + ) os.remove(test_file) for k, v in t._edgecase_dict.items(): diff --git a/unit_tests/test_labelgroup.py b/unit_tests/test_labelgroup.py index 292629a..4aee9f3 100644 --- a/unit_tests/test_labelgroup.py +++ b/unit_tests/test_labelgroup.py @@ -82,7 +82,9 @@ def test_segmentationclassgroup_easy(self): self.assertTrue(isinstance(lg, LabelGroup)) def test_segmentationclassgroup_decarations(self): - classgroups = SegmentationClassGroups(groups=[LabelGroup(i) for i in range(1, 5)]) + classgroups = SegmentationClassGroups( + groups=[LabelGroup(i) for i in range(1, 5)] + ) keys = classgroups.keys() for i in range(1, 5):