From bed8b2e4acad0638d6009e5aee7d7a7b827df18f Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 1 Aug 2024 12:07:05 +0000 Subject: [PATCH 01/13] refined segmentation class definition, started to go into config files --- examples/example_spine_instance.py | 1 + panoptica/base_configs/arya.yaml | 1 + panoptica/base_configs/config_edgecase.yaml | 0 panoptica/base_configs/test.yaml | 4 + panoptica/base_configs/test_out.yaml | 4 + panoptica/panoptica_evaluator.py | 46 +++-------- panoptica/utils/config.py | 83 +++++++++++++++++++ panoptica/utils/label_group.py | 59 ++++++++++++++ panoptica/utils/segmentation_class.py | 90 ++------------------- pyproject.toml | 1 + unit_tests/test_panoptic_evaluator.py | 2 +- unit_tests/test_panoptic_result.py | 2 +- 12 files changed, 175 insertions(+), 118 deletions(-) create mode 100644 panoptica/base_configs/arya.yaml create mode 100644 panoptica/base_configs/config_edgecase.yaml create mode 100644 panoptica/base_configs/test.yaml create mode 100644 panoptica/base_configs/test_out.yaml create mode 100644 panoptica/utils/config.py create mode 100644 panoptica/utils/label_group.py diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index f10cb34..2e5aaca 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -28,6 +28,7 @@ ), decision_metric=Metric.DSC, decision_threshold=0.5, + log_times=True, ) diff --git a/panoptica/base_configs/arya.yaml b/panoptica/base_configs/arya.yaml new file mode 100644 index 0000000..bbd5e8c --- /dev/null +++ b/panoptica/base_configs/arya.yaml @@ -0,0 +1 @@ +!Person {age: 18, name: arya} diff --git a/panoptica/base_configs/config_edgecase.yaml b/panoptica/base_configs/config_edgecase.yaml new file mode 100644 index 0000000..e69de29 diff --git a/panoptica/base_configs/test.yaml b/panoptica/base_configs/test.yaml new file mode 100644 index 0000000..ea2ab77 --- /dev/null +++ b/panoptica/base_configs/test.yaml @@ -0,0 +1,4 @@ +YAML: + - T: 2 + - E: 3 + - 3: "TEST" \ No newline at end of file diff --git a/panoptica/base_configs/test_out.yaml b/panoptica/base_configs/test_out.yaml new file mode 100644 index 0000000..d503b15 --- /dev/null +++ b/panoptica/base_configs/test_out.yaml @@ -0,0 +1,4 @@ +YAML: +- {T: 2} +- {E: 3} +- {3: TEST} diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index ca59b98..810bd9b 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -23,9 +23,7 @@ class Panoptica_Evaluator: def __init__( self, # TODO let users give prediction and reference arr instead of the processing pair, so let this create the processing pair itself - expected_input: ( - Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] - ) = MatchedInstancePair, + expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, @@ -36,7 +34,7 @@ def __init__( log_times: bool = False, verbose: bool = False, ) -> None: - """Creates a Panoptic_Evaluator, that saves some parameters to be used for all subsequent evaluations + """Creates a Panoptica_Evaluator, that saves some parameters to be used for all subsequent evaluations Args: expected_input (type, optional): Expected DataPair Input. Defaults to type(MatchedInstancePair). @@ -54,13 +52,9 @@ def __init__( self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = ( - edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() - ) + self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() if self.__decision_metric is not None: - assert ( - self.__decision_threshold is not None - ), "decision metric set but no decision threshold for it" + assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -69,15 +63,11 @@ def __init__( @measure_time def evaluate( self, - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, result_all: bool = True, verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: - assert ( - type(processing_pair) == self.__expected_input - ), f"input not of expected type {self.__expected_input}" + assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { @@ -96,12 +86,8 @@ def evaluate( ) } - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.prediction_arr, raise_error=True - ) - self.__segmentation_class_groups.has_defined_labels_for( - processing_pair.reference_arr, raise_error=True - ) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -113,9 +99,7 @@ def evaluate( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance( - processing_pair, MatchedInstancePair - ): + if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -139,9 +123,7 @@ def evaluate( def panoptic_evaluate( - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -197,9 +179,7 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert ( - instance_approximator is not None - ), "Got SemanticPair but not InstanceApproximator" + assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -219,9 +199,7 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert ( - instance_matcher is not None - ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py new file mode 100644 index 0000000..49f32ff --- /dev/null +++ b/panoptica/utils/config.py @@ -0,0 +1,83 @@ +from ruamel.yaml import YAML +from pathlib import Path + +#################### + + +def load_yaml(file: str | Path): + if isinstance(file, str): + file = Path(file) + yaml = YAML(typ="safe") + data = yaml.load(file) + assert isinstance(data, dict) or isinstance(data, object) + return data + + +def save_yaml(data_dict: dict | object, out_file: str | Path, registered_class=None): + if isinstance(out_file, str): + out_file = Path(out_file) + + yaml = YAML(typ="safe") + if registered_class is not None: + yaml.register_class(registered_class) + if isinstance(data_dict, object): + yaml.dump(data_dict, out_file) + else: + yaml.dump([registered_class(*data_dict)], out_file) + else: + yaml.dump(data_dict, out_file) + + +class Person: + name: str + age: int + + def __init__(self, name, age) -> None: + self.name = name + self.age = age + + +#################### + +# TODO split into general config and object configuration (latter saves as object yaml and loads directly as object?) + + +class Configuration: + _data_dict: dict + _registered_class = None + + def __init__(self, data_dict: dict, registered_class=None) -> None: + self._data_dict = data_dict + if registered_class is not None: + self.register_to_class(registered_class) + + def register_to_class(self, cls): + self._registered_class = cls + + @classmethod + def save_from_object(cls, obj: object, file: str | Path): + save_yaml(obj, file, registered_class=type(obj)) + return Configuration.load(file, registered_class=type(obj)) + + @classmethod + def load(cls, file: str | Path, registered_class=None): + data = load_yaml(file) + return Configuration(data, registered_class=registered_class) + + def save(self, out_file: str | Path): + save_yaml(self._data_dict, out_file) + + def cls_object_from_this(self): + assert self._registered_class is not None + self._registered_class(*self._data_dict) + + +if __name__ == "__main__": + c = Configuration.load("/DATA/NAS/ongoing_projects/hendrik/panoptica/repo/panoptica/base_configs/test.yaml") + + c.save("/DATA/NAS/ongoing_projects/hendrik/panoptica/repo/panoptica/base_configs/test_out.yaml") + + arya = Person("arya", 18) + + c = Configuration.save_from_object(arya, "/DATA/NAS/ongoing_projects/hendrik/panoptica/repo/panoptica/base_configs/arya.yaml") + print(c._data_dict.name) diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py new file mode 100644 index 0000000..338a111 --- /dev/null +++ b/panoptica/utils/label_group.py @@ -0,0 +1,59 @@ +import numpy as np + + +class LabelGroup: + """Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other""" + + def __init__( + self, + value_labels: list[int] | int, + single_instance: bool = False, + ) -> None: + """Defines a group of labels that semantically belong to each other + + Args: + value_labels (list[int]): Actually labels in the prediction and reference mask in this group. Defines the labels that can be matched to each other + single_instance (bool, optional): If true, will not use the matching_threshold as there is only one instance (large organ, ...). Defaults to False. + """ + if isinstance(value_labels, int): + value_labels = [value_labels] + assert len(value_labels) >= 1, f"You tried to define a LabelGroup without any specified labels, got {value_labels}" + self.__value_labels = value_labels + assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}" + self.__single_instance = single_instance + if self.__single_instance: + assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" + + @property + def value_labels(self) -> list[int]: + return self.__value_labels + + @property + def single_instance(self) -> bool: + return self.__single_instance + + def __call__( + self, + array: np.ndarray, + set_to_binary: bool = False, + ) -> np.ndarray: + """Extracts the labels of this class + + Args: + array (np.ndarray): Array to extract the segmentation group labels from + set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. + + Returns: + np.ndarray: Array containing only the labels of this segmentation group + """ + array = array.copy() + array[np.isin(array, self.value_labels, invert=True)] = 0 + if set_to_binary: + array[array != 0] = 1 + return array + + def __str__(self) -> str: + return f"LabelGroup {self.value_labels}, single_instance={self.single_instance}" + + def __repr__(self) -> str: + return str(self) diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 4535ad3..70abde9 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -1,71 +1,9 @@ import numpy as np - - -class LabelGroup: - """Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other""" - - def __init__( - self, - value_labels: list[int] | int, - single_instance: bool = False, - ) -> None: - """Defines a group of labels that semantically belong to each other - - Args: - value_labels (list[int]): Actually labels in the prediction and reference mask in this group. Defines the labels that can be matched to each other - single_instance (bool, optional): If true, will not use the matching_threshold as there is only one instance (large organ, ...). Defaults to False. - """ - if isinstance(value_labels, int): - value_labels = [value_labels] - assert ( - len(value_labels) >= 1 - ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" - self.__value_labels = value_labels - assert np.all( - [v > 0 for v in self.__value_labels] - ), f"Given value labels are not >0, got {value_labels}" - self.__single_instance = single_instance - if self.__single_instance: - assert ( - len(value_labels) == 1 - ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" - - @property - def value_labels(self) -> list[int]: - return self.__value_labels - - @property - def single_instance(self) -> bool: - return self.__single_instance - - def __call__( - self, - array: np.ndarray, - set_to_binary: bool = False, - ) -> np.ndarray: - """Extracts the labels of this class - - Args: - array (np.ndarray): Array to extract the segmentation group labels from - set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. - - Returns: - np.ndarray: Array containing only the labels of this segmentation group - """ - array = array.copy() - array[np.isin(array, self.value_labels, invert=True)] = 0 - if set_to_binary: - array[array != 0] = 1 - return array - - def __str__(self) -> str: - return f"LabelGroup {self.value_labels}, single_instance={self.single_instance}" - - def __repr__(self) -> str: - return str(self) +from panoptica.utils.label_group import LabelGroup class SegmentationClassGroups: + # def __init__( self, groups: list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]], @@ -75,37 +13,25 @@ def __init__( # maps name of group to the group itself if isinstance(groups, list): - self.__group_dictionary = { - f"group_{idx}": g for idx, g in enumerate(groups) - } - else: + self.__group_dictionary = {f"group_{idx}": g for idx, g in enumerate(groups)} + elif isinstance(groups, dict): # transform dict into list of LabelGroups for i, g in groups.items(): name_lower = str(i).lower() if isinstance(g, LabelGroup): - self.__group_dictionary[name_lower] = LabelGroup( - g.value_labels, g.single_instance - ) + self.__group_dictionary[name_lower] = LabelGroup(g.value_labels, g.single_instance) else: self.__group_dictionary[name_lower] = LabelGroup(g[0], g[1]) # needs to check that each label is accounted for exactly ONCE - labels = [ - value_label - for lg in self.__group_dictionary.values() - for value_label in lg.value_labels - ] + labels = [value_label for lg in self.__group_dictionary.values() for value_label in lg.value_labels] duplicates = list_duplicates(labels) if len(duplicates) > 0: - raise AssertionError( - f"The same label was assigned to two different labelgroups, got {str(self)}" - ) + raise AssertionError(f"The same label was assigned to two different labelgroups, got {str(self)}") self.__labels = labels - def has_defined_labels_for( - self, arr: np.ndarray | list[int], raise_error: bool = False - ): + def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): if isinstance(arr, list): arr_labels = arr else: diff --git a/pyproject.toml b/pyproject.toml index 2a30f1f..80544da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ connected-components-3d = "^3.12.3" scipy = "^1.7.0" rich = "^13.6.0" scikit-image = "^0.22.0" +ruamel = "0.18.6" [tool.poetry.dev-dependencies] pytest = "^6.2.5" diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index a2cf755..276e195 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -16,7 +16,7 @@ from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup -class Test_Panoptic_Evaluator(unittest.TestCase): +class Test_Panoptica_Evaluator(unittest.TestCase): def setUp(self) -> None: os.environ["PANOPTICA_CITATION_REMINDER"] = "False" return super().setUp() diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py index e88b2b2..f5087ea 100644 --- a/unit_tests/test_panoptic_result.py +++ b/unit_tests/test_panoptic_result.py @@ -15,7 +15,7 @@ from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult -class Test_Panoptic_Evaluator(unittest.TestCase): +class Test_Panoptica_Evaluator(unittest.TestCase): def setUp(self) -> None: os.environ["PANOPTICA_CITATION_REMINDER"] = "False" return super().setUp() From 6460f060a6c7cb30ff88b425e9838e3ea64c9ed5 Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 2 Aug 2024 16:19:57 +0000 Subject: [PATCH 02/13] small metric fix --- panoptica/metrics/relative_volume_difference.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/panoptica/metrics/relative_volume_difference.py b/panoptica/metrics/relative_volume_difference.py index 4d952b2..1cde13e 100644 --- a/panoptica/metrics/relative_volume_difference.py +++ b/panoptica/metrics/relative_volume_difference.py @@ -55,16 +55,15 @@ def _compute_relative_volume_difference( prediction (np.ndarray): Prediction binary mask. Returns: - float: Relative volume Error between the two binary masks. A value between 0 and 1, where higher values - indicate better overlap and similarity between masks. + float: Relative volume Error between the two binary masks. A value of zero means perfect volume match, while >0 means oversegmentation and <0 undersegmentation. """ - reference_mask = np.sum(reference) - prediction_mask = np.sum(prediction) + reference_mask = float(np.sum(reference)) + prediction_mask = float(np.sum(prediction)) # Handle division by zero if reference_mask == 0 and prediction_mask == 0: return 0.0 # Calculate Dice coefficient - rvd = float(prediction_mask - reference_mask) / reference_mask + rvd = (prediction_mask - reference_mask) / reference_mask return rvd From cb915fafe79d4f42f01cf907e5db8e928adeca51 Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 2 Aug 2024 16:20:13 +0000 Subject: [PATCH 03/13] first version of configs done --- examples/example_spine_instance_config.py | 38 +++++ panoptica/base_configs/arya.yaml | 1 - panoptica/base_configs/config_edgecase.yaml | 0 panoptica/base_configs/test.yaml | 4 - panoptica/base_configs/test_out.yaml | 4 - ...sGroups_example_unmatchedinstancepair.yaml | 14 ++ panoptica/utils/config.py | 158 +++++++++++++----- panoptica/utils/constants.py | 32 ++++ panoptica/utils/edge_case_handling.py | 48 ++---- panoptica/utils/filepath.py | 37 ++++ panoptica/utils/label_group.py | 20 ++- panoptica/utils/segmentation_class.py | 15 +- unit_tests/test_config.py | 54 ++++++ unit_tests/test_labelgroup.py | 4 +- 14 files changed, 343 insertions(+), 86 deletions(-) create mode 100644 examples/example_spine_instance_config.py delete mode 100644 panoptica/base_configs/arya.yaml delete mode 100644 panoptica/base_configs/config_edgecase.yaml delete mode 100644 panoptica/base_configs/test.yaml delete mode 100644 panoptica/base_configs/test_out.yaml create mode 100644 panoptica/configs/SegmentationClassGroups_example_unmatchedinstancepair.yaml create mode 100644 panoptica/utils/filepath.py create mode 100644 unit_tests/test_config.py diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py new file mode 100644 index 0000000..ae233aa --- /dev/null +++ b/examples/example_spine_instance_config.py @@ -0,0 +1,38 @@ +import cProfile + +from auxiliary.nifti.io import read_nifti +from auxiliary.turbopath import turbopath + +from panoptica import UnmatchedInstancePair, Panoptica_Evaluator, NaiveThresholdMatching +from panoptica.metrics import Metric +from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups + +directory = turbopath(__file__).parent + +ref_masks = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") +pred_masks = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") + +sample = UnmatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks) + +# LabelGroup._register_permanently() + +evaluator = Panoptica_Evaluator( + expected_input=UnmatchedInstancePair, + eval_metrics=[Metric.DSC, Metric.IOU], + instance_matcher=NaiveThresholdMatching(), + segmentation_class_groups=SegmentationClassGroups.load_from_config_name("SegmentationClassGroups_example_unmatchedinstancepair"), + decision_metric=Metric.DSC, + decision_threshold=0.5, + log_times=True, +) + + +with cProfile.Profile() as pr: + if __name__ == "__main__": + results = evaluator.evaluate(sample, verbose=False) + for groupname, (result, debug) in results.items(): + print() + print("### Group", groupname) + print(result) + + pr.dump_stats(directory + "/instance_example.log") diff --git a/panoptica/base_configs/arya.yaml b/panoptica/base_configs/arya.yaml deleted file mode 100644 index bbd5e8c..0000000 --- a/panoptica/base_configs/arya.yaml +++ /dev/null @@ -1 +0,0 @@ -!Person {age: 18, name: arya} diff --git a/panoptica/base_configs/config_edgecase.yaml b/panoptica/base_configs/config_edgecase.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/panoptica/base_configs/test.yaml b/panoptica/base_configs/test.yaml deleted file mode 100644 index ea2ab77..0000000 --- a/panoptica/base_configs/test.yaml +++ /dev/null @@ -1,4 +0,0 @@ -YAML: - - T: 2 - - E: 3 - - 3: "TEST" \ No newline at end of file diff --git a/panoptica/base_configs/test_out.yaml b/panoptica/base_configs/test_out.yaml deleted file mode 100644 index d503b15..0000000 --- a/panoptica/base_configs/test_out.yaml +++ /dev/null @@ -1,4 +0,0 @@ -YAML: -- {T: 2} -- {E: 3} -- {3: TEST} diff --git a/panoptica/configs/SegmentationClassGroups_example_unmatchedinstancepair.yaml b/panoptica/configs/SegmentationClassGroups_example_unmatchedinstancepair.yaml new file mode 100644 index 0000000..9c31720 --- /dev/null +++ b/panoptica/configs/SegmentationClassGroups_example_unmatchedinstancepair.yaml @@ -0,0 +1,14 @@ +!SegmentationClassGroups +groups: + endplate: !LabelGroup + single_instance: false + value_labels: [201, 202, 203, 204, 205, 206, 207, 208, 209, 210] + ivd: !LabelGroup + single_instance: false + value_labels: [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] + sacrum: !LabelGroup + single_instance: true + value_labels: [26] + vertebra: !LabelGroup + single_instance: false + value_labels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index 49f32ff..56c5111 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -1,83 +1,163 @@ from ruamel.yaml import YAML from pathlib import Path +from panoptica.utils.filepath import config_by_name -#################### +supported_helper_classes = [] + + +def _register_helper_classes(yaml: YAML): + [yaml.register_class(s) for s in supported_helper_classes] -def load_yaml(file: str | Path): +def _load_yaml(file: str | Path, registered_class=None): if isinstance(file, str): file = Path(file) - yaml = YAML(typ="safe") - data = yaml.load(file) - assert isinstance(data, dict) or isinstance(data, object) + yaml = YAML(typ="safe") + _register_helper_classes(yaml) + if registered_class is not None: + yaml.register_class(registered_class) + yaml.default_flow_style = None + data = yaml.load(file) + assert isinstance(data, dict) or isinstance(data, object) return data -def save_yaml(data_dict: dict | object, out_file: str | Path, registered_class=None): +def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class=None): if isinstance(out_file, str): out_file = Path(out_file) - yaml = YAML(typ="safe") - if registered_class is not None: - yaml.register_class(registered_class) - if isinstance(data_dict, object): - yaml.dump(data_dict, out_file) - else: - yaml.dump([registered_class(*data_dict)], out_file) - else: - yaml.dump(data_dict, out_file) - - -class Person: - name: str - age: int - - def __init__(self, name, age) -> None: - self.name = name - self.age = age + yaml = YAML(typ="safe") + yaml.default_flow_style = None + if registered_class is not None: + yaml.register_class(registered_class) + _register_helper_classes(yaml) + # if isinstance(data_dict, object): + yaml.dump(data_dict, out_file) + # else: + # yaml.dump([registered_class(*data_dict)], out_file) + else: + yaml.dump(data_dict, out_file) #################### - -# TODO split into general config and object configuration (latter saves as object yaml and loads directly as object?) - - +# TODO Merge into SupportsConfig class Configuration: + """General Configuration class that handles yaml""" + _data_dict: dict _registered_class = None def __init__(self, data_dict: dict, registered_class=None) -> None: + assert isinstance(data_dict, dict) self._data_dict = data_dict if registered_class is not None: self.register_to_class(registered_class) def register_to_class(self, cls): + global supported_helper_classes + if cls not in supported_helper_classes: + supported_helper_classes.append(cls) self._registered_class = cls + return self @classmethod def save_from_object(cls, obj: object, file: str | Path): - save_yaml(obj, file, registered_class=type(obj)) - return Configuration.load(file, registered_class=type(obj)) + _save_yaml(obj, file, registered_class=type(obj)) + # return Configuration.load(file, registered_class=type(obj)) @classmethod def load(cls, file: str | Path, registered_class=None): - data = load_yaml(file) + data = _load_yaml(file, registered_class) + assert isinstance(data, dict), f"The config at {file} is registered to a class. Use load_as_object() instead" return Configuration(data, registered_class=registered_class) + @classmethod + def load_as_object(cls, file: str | Path, registered_class=None): + data = _load_yaml(file, registered_class) + assert not isinstance(data, dict), f"The config at {file} is not registered to a class. Use load() instead" + return data + def save(self, out_file: str | Path): - save_yaml(self._data_dict, out_file) + _save_yaml(self._data_dict, out_file) def cls_object_from_this(self): assert self._registered_class is not None - self._registered_class(*self._data_dict) + return self._registered_class(**self._data_dict) + + @property + def data_dict(self): + return self._data_dict + + @property + def cls(self): + return self._registered_class + + def __str__(self) -> str: + return f"Config({self.cls.__name__ if self.cls is not None else 'NoClass'} = {self.data_dict})" # type: ignore + +######### +# Universal Functions +######### +def _register_class_to_yaml(cls): + global supported_helper_classes + if cls not in supported_helper_classes: + supported_helper_classes.append(cls) -if __name__ == "__main__": - c = Configuration.load("/DATA/NAS/ongoing_projects/hendrik/panoptica/repo/panoptica/base_configs/test.yaml") - c.save("/DATA/NAS/ongoing_projects/hendrik/panoptica/repo/panoptica/base_configs/test_out.yaml") +def _load_from_config(cls, path: str | Path): + # cls._register_permanently() + if isinstance(path, str): + path = Path(path) + assert path.exists(), f"load_from_config: {path} does not exist" + obj = Configuration.load_as_object(path, registered_class=cls) + assert isinstance(obj, cls), f"Loaded config was not for class {cls.__name__}" + return obj - arya = Person("arya", 18) - c = Configuration.save_from_object(arya, "/DATA/NAS/ongoing_projects/hendrik/panoptica/repo/panoptica/base_configs/arya.yaml") - print(c._data_dict.name) +def _load_from_config_name(cls, name: str): + path = config_by_name(name) + assert path.exists(), f"load_from_config: {path} does not exist" + return _load_from_config(cls, path) + + +def _save_to_config(obj, path: str | Path): + if isinstance(path, str): + path = Path(path) + Configuration.save_from_object(obj, path) + + +class SupportsConfig: + """Metaclass that allows a class to save and load objects by yaml configs""" + + def __init_subclass__(cls, **kwargs): + # Registers all subclasses of this + super().__init_subclass__(**kwargs) + cls._register_permanently() + + @classmethod + def _register_permanently(cls): + _register_class_to_yaml(cls) + + @classmethod + def load_from_config(cls, path: str | Path): + return _load_from_config(cls, path) + + @classmethod + def load_from_config_name(cls, name: str): + return _load_from_config_name(cls, name) + + def save_to_config(self, path: str | Path): + _save_to_config(self, path) + + @classmethod + def to_yaml(cls, representer, node): + # cls._register_permanently() + assert hasattr(cls, "_yaml_repr"), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" + return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node)) + + @classmethod + def from_yaml(cls, constructor, node): + # cls._register_permanently() + data = constructor.construct_mapping(node, deep=True) + return cls(**data) diff --git a/panoptica/utils/constants.py b/panoptica/utils/constants.py index d4a1faa..8c2fc80 100644 --- a/panoptica/utils/constants.py +++ b/panoptica/utils/constants.py @@ -1,4 +1,6 @@ from enum import Enum, auto +from panoptica.utils.config import _register_class_to_yaml, _load_from_config, _load_from_config_name, _save_to_config +from pathlib import Path class _Enum_Compare(Enum): @@ -16,6 +18,36 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) + def __init_subclass__(cls, **kwargs): + # Registers all subclasses of this + super().__init_subclass__(**kwargs) + cls._register_permanently() + + @classmethod + def _register_permanently(cls): + _register_class_to_yaml(cls) + + @classmethod + def load_from_config(cls, path: str | Path): + return _load_from_config(cls, path) + + @classmethod + def load_from_config_name(cls, name: str): + return _load_from_config_name(cls, name) + + def save_to_config(self, path: str | Path): + _save_to_config(self, path) + + @classmethod + def to_yaml(cls, representer, node): + # cls._register_permanently() + # assert hasattr(cls, "_yaml_repr"), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" + return representer.represent_scalar("!" + cls.__name__, str(node.name)) + + @classmethod + def from_yaml(cls, constructor, node): + return cls[node.value] + class CCABackend(_Enum_Compare): """ diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index e0ecc7e..7a02fcc 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -1,8 +1,7 @@ import numpy as np -from typing import TYPE_CHECKING - from panoptica.metrics import Metric from panoptica.utils.constants import _Enum_Compare, auto +from panoptica.utils.config import SupportsConfig class EdgeCaseResult(_Enum_Compare): @@ -23,7 +22,7 @@ def __hash__(self) -> int: return self.value -class MetricZeroTPEdgeCaseHandling(object): +class MetricZeroTPEdgeCaseHandling(SupportsConfig): def __init__( self, default_result: EdgeCaseResult, @@ -33,26 +32,12 @@ def __init__( normal: EdgeCaseResult | None = None, ) -> None: self.edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self.edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( - empty_prediction_result - if empty_prediction_result is not None - else default_result - ) - self.edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( - empty_reference_result - if empty_reference_result is not None - else default_result - ) - self.edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( - no_instances_result if no_instances_result is not None else default_result - ) - self.edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( - normal if normal is not None else default_result - ) + self.edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result + self.edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result + self.edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result + self.edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result - def __call__( - self, tp: int, num_pred_instances, num_ref_instances - ) -> tuple[bool, float | None]: + def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -76,8 +61,13 @@ def __str__(self) -> str: txt += str(k) + ": " + str(v) + "\n" return txt + @classmethod + def _yaml_repr(cls, node): + # TODO + return {"value_labels": node.value_labels, "single_instance": node.single_instance} -class EdgeCaseHandler: + +class EdgeCaseHandler(SupportsConfig): def __init__( self, @@ -106,9 +96,7 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[ - Metric, MetricZeroTPEdgeCaseHandling - ] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -121,9 +109,7 @@ def handle_zero_tp( if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError( - f"Metric {metric} encountered zero TP, but no edge handling available" - ) + raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") return self.__listmetric_zeroTP_handling[metric]( tp=tp, @@ -149,9 +135,7 @@ def __str__(self) -> str: print() # print(handler.get_metric_zero_tp_handle(ListMetric.IOU)) - r = handler.handle_zero_tp( - Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1 - ) + r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1) print(r) iou_test = MetricZeroTPEdgeCaseHandling( diff --git a/panoptica/utils/filepath.py b/panoptica/utils/filepath.py new file mode 100644 index 0000000..aac9edd --- /dev/null +++ b/panoptica/utils/filepath.py @@ -0,0 +1,37 @@ +import os +import warnings +from itertools import chain +from pathlib import Path +from auxiliary.turbopath import turbopath + + +def search_path(basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False) -> list[Path]: + """Searches from basepath with query + Args: + basepath: ground path to look into + query: search query, can contain wildcards like *.npz or **/*.npz + verbose: + suppress: if true, will not throwing warnings if nothing is found + + Returns: + All found paths + """ + basepath = str(basepath) + assert os.path.exists(basepath), f"basepath for search_path() doesnt exist, got {basepath}" + if not basepath.endswith("/"): + basepath += "/" + print(f"search_path: in {basepath}{query}") if verbose else None + paths = sorted(list(chain(list(Path(f"{basepath}").glob(f"{query}"))))) + if len(paths) == 0 and not suppress: + warnings.warn(f"did not find any paths in {basepath}{query}", UserWarning) + return paths + + +# Find config path +def config_by_name(name: str) -> Path: + directory = turbopath(__file__).parent.parent + if not name.endswith(".yaml"): + name += ".yaml" + p = search_path(directory, query=f"**/{name}", suppress=True) + assert len(p) == 1, f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" + return p[0] diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py index 338a111..6757500 100644 --- a/panoptica/utils/label_group.py +++ b/panoptica/utils/label_group.py @@ -1,7 +1,10 @@ import numpy as np +from panoptica.utils.config import SupportsConfig +# -class LabelGroup: + +class LabelGroup(SupportsConfig): """Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other""" def __init__( @@ -24,6 +27,8 @@ def __init__( if self.__single_instance: assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" + LabelGroup._register_permanently() + @property def value_labels(self) -> list[int]: return self.__value_labels @@ -57,3 +62,16 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) + + @classmethod + def _yaml_repr(cls, node): + return {"value_labels": node.value_labels, "single_instance": node.single_instance} + + # @classmethod + # def to_yaml(cls, representer, node): + # return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node)) + + # @classmethod + # def from_yaml(cls, constructor, node): + # data = constructor.construct_mapping(node, deep=True) + # return cls(**data) diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 70abde9..678822d 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -1,8 +1,10 @@ import numpy as np +from pathlib import Path +from panoptica.utils.config import SupportsConfig from panoptica.utils.label_group import LabelGroup -class SegmentationClassGroups: +class SegmentationClassGroups(SupportsConfig): # def __init__( self, @@ -28,8 +30,9 @@ def __init__( duplicates = list_duplicates(labels) if len(duplicates) > 0: raise AssertionError(f"The same label was assigned to two different labelgroups, got {str(self)}") - self.__labels = labels + SegmentationClassGroups._register_permanently() + LabelGroup._register_permanently() def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): if isinstance(arr, list): @@ -63,10 +66,18 @@ def __iter__(self): def keys(self) -> list[str]: return list(self.__group_dictionary.keys()) + @property + def labels(self): + return self.__labels + def items(self): for k in self: yield k, self[k] + @classmethod + def _yaml_repr(cls, node): + return {"groups": node.__group_dictionary} + def list_duplicates(seq): seen = set() diff --git a/unit_tests/test_config.py b/unit_tests/test_config.py new file mode 100644 index 0000000..189ca24 --- /dev/null +++ b/unit_tests/test_config.py @@ -0,0 +1,54 @@ +# Call 'python -m unittest' on this folder +# coverage run -m unittest +# coverage report +# coverage html +import os +import unittest + +from panoptica.metrics import ( + Metric, + Evaluation_List_Metric, + MetricMode, + MetricCouldNotBeComputedException, +) +from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup +from panoptica.utils.constants import CCABackend +from pathlib import Path + +test_file = Path(__file__).parent.joinpath("test.yaml") + + +class Test_Datatypes(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_enum_config(self): + a = CCABackend.cc3d + a.save_to_config(test_file) + print(a) + b = CCABackend.load_from_config(test_file) + print(b) + os.remove(test_file) + + self.assertEqual(a, b) + + def test_SegmentationClassGroups_config(self): + e = { + "groups": { + "vertebra": LabelGroup([i for i in range(1, 11)], False), + "ivd": LabelGroup([i for i in range(101, 111)]), + "sacrum": LabelGroup(26, True), + "endplate": LabelGroup([i for i in range(201, 211)]), + } + } + t = SegmentationClassGroups(**e) + print(t) + print() + t.save_to_config(test_file) + d: SegmentationClassGroups = SegmentationClassGroups.load_from_config(test_file) + os.remove(test_file) + + for k, v in d.items(): + self.assertEqual(t[k].single_instance, v.single_instance) + self.assertEqual(len(t[k].value_labels), len(v.value_labels)) diff --git a/unit_tests/test_labelgroup.py b/unit_tests/test_labelgroup.py index 4aee9f3..292629a 100644 --- a/unit_tests/test_labelgroup.py +++ b/unit_tests/test_labelgroup.py @@ -82,9 +82,7 @@ def test_segmentationclassgroup_easy(self): self.assertTrue(isinstance(lg, LabelGroup)) def test_segmentationclassgroup_decarations(self): - classgroups = SegmentationClassGroups( - groups=[LabelGroup(i) for i in range(1, 5)] - ) + classgroups = SegmentationClassGroups(groups=[LabelGroup(i) for i in range(1, 5)]) keys = classgroups.keys() for i in range(1, 5): From 823c849ff587ad0e324d2eeb1e3c11da19b9fe40 Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 5 Aug 2024 11:35:14 +0000 Subject: [PATCH 04/13] configs work now with panoptica_evaluator objects as well. Hid Processing Pair behind InputType Enum, moved instance labelmap to its own file. updated examples accordingly and added many new unit tests --- examples/example_spine_instance.py | 13 ++- examples/example_spine_instance_config.py | 24 ++--- examples/example_spine_semantic.py | 11 ++- panoptica/__init__.py | 1 + ...anoptica_evaluator_unmatched_instance.yaml | 42 +++++++++ panoptica/instance_approximator.py | 19 ++-- panoptica/instance_matcher.py | 57 +++++++----- panoptica/panoptica_evaluator.py | 31 +++++-- panoptica/panoptica_result.py | 2 +- panoptica/utils/__init__.py | 2 +- panoptica/utils/config.py | 21 ++++- panoptica/utils/constants.py | 15 +++- panoptica/utils/edge_case_handling.py | 64 +++++++++---- panoptica/utils/instancelabelmap.py | 64 +++++++++++++ panoptica/utils/processing_pair.py | 86 ++---------------- panoptica/utils/segmentation_class.py | 2 - unit_tests/test_config.py | 90 +++++++++++++++++++ unit_tests/test_datatype.py | 5 ++ 18 files changed, 378 insertions(+), 171 deletions(-) create mode 100644 panoptica/configs/panoptica_evaluator_unmatched_instance.yaml create mode 100644 panoptica/utils/instancelabelmap.py diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index 2e5aaca..bcea628 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -3,20 +3,17 @@ from auxiliary.nifti.io import read_nifti from auxiliary.turbopath import turbopath -from panoptica import MatchedInstancePair, Panoptica_Evaluator +from panoptica import Panoptica_Evaluator, InputType from panoptica.metrics import Metric from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups directory = turbopath(__file__).parent -ref_masks = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") - -pred_masks = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") - -sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks) +reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") +prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") evaluator = Panoptica_Evaluator( - expected_input=MatchedInstancePair, + expected_input=InputType.MATCHED_INSTANCE, eval_metrics=[Metric.DSC, Metric.IOU], segmentation_class_groups=SegmentationClassGroups( { @@ -34,7 +31,7 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - results = evaluator.evaluate(sample, verbose=False) + results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False) for groupname, (result, debug) in results.items(): print() print("### Group", groupname) diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py index ae233aa..2855add 100644 --- a/examples/example_spine_instance_config.py +++ b/examples/example_spine_instance_config.py @@ -3,33 +3,19 @@ from auxiliary.nifti.io import read_nifti from auxiliary.turbopath import turbopath -from panoptica import UnmatchedInstancePair, Panoptica_Evaluator, NaiveThresholdMatching -from panoptica.metrics import Metric -from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups +from panoptica import Panoptica_Evaluator directory = turbopath(__file__).parent -ref_masks = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") -pred_masks = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") +reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") +prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") -sample = UnmatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks) - -# LabelGroup._register_permanently() - -evaluator = Panoptica_Evaluator( - expected_input=UnmatchedInstancePair, - eval_metrics=[Metric.DSC, Metric.IOU], - instance_matcher=NaiveThresholdMatching(), - segmentation_class_groups=SegmentationClassGroups.load_from_config_name("SegmentationClassGroups_example_unmatchedinstancepair"), - decision_metric=Metric.DSC, - decision_threshold=0.5, - log_times=True, -) +evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance") with cProfile.Profile() as pr: if __name__ == "__main__": - results = evaluator.evaluate(sample, verbose=False) + results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False) for groupname, (result, debug) in results.items(): print() print("### Group", groupname) diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 7385701..2793592 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -7,18 +7,17 @@ ConnectedComponentsInstanceApproximator, NaiveThresholdMatching, Panoptica_Evaluator, - SemanticPair, + InputType, ) directory = turbopath(__file__).parent -ref_masks = read_nifti(directory + "/spine_seg/semantic/ref.nii.gz") -pred_masks = read_nifti(directory + "/spine_seg/semantic/pred.nii.gz") +reference_mask = read_nifti(directory + "/spine_seg/semantic/ref.nii.gz") +prediction_mask = read_nifti(directory + "/spine_seg/semantic/pred.nii.gz") -sample = SemanticPair(pred_masks, ref_masks) evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), verbose=True, @@ -26,7 +25,7 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"] print(result) pr.dump_stats(directory + "/semantic_example.log") diff --git a/panoptica/__init__.py b/panoptica/__init__.py index 9a4d080..dca6768 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -9,5 +9,6 @@ SemanticPair, UnmatchedInstancePair, MatchedInstancePair, + InputType, ) from panoptica.metrics import Metric, MetricMode, MetricType diff --git a/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml new file mode 100644 index 0000000..ad8d9b0 --- /dev/null +++ b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml @@ -0,0 +1,42 @@ +!Panoptica_Evaluator +decision_metric: !Metric DSC +decision_threshold: 0.5 +edge_case_handler: !EdgeCaseHandler + empty_list_std: !EdgeCaseResult NAN + listmetric_zeroTP_handling: + !Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF, + empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult INF} + !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN, + empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult NAN} +eval_metrics: [!Metric DSC, !Metric IOU] +expected_input: !InputType UNMATCHED_INSTANCE +instance_approximator: null +instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU, + matching_threshold: 0.5} +log_times: true +segmentation_class_groups: !SegmentationClassGroups + groups: + endplate: !LabelGroup + single_instance: false + value_labels: [201, 202, 203, 204, 205, 206, 207, 208, 209, 210] + ivd: !LabelGroup + single_instance: false + value_labels: [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] + sacrum: !LabelGroup + single_instance: true + value_labels: [26] + vertebra: !LabelGroup + single_instance: false + value_labels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +verbose: false diff --git a/panoptica/instance_approximator.py b/panoptica/instance_approximator.py index e73f843..957b934 100644 --- a/panoptica/instance_approximator.py +++ b/panoptica/instance_approximator.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, ABCMeta import numpy as np @@ -10,9 +10,10 @@ SemanticPair, UnmatchedInstancePair, ) +from panoptica.utils.config import SupportsConfig -class InstanceApproximator(ABC): +class InstanceApproximator(SupportsConfig, metaclass=ABCMeta): """ Abstract base class for instance approximation algorithms in panoptic segmentation evaluation. @@ -56,6 +57,10 @@ def _approximate_instances( """ pass + def _yaml_repr(cls, node) -> dict: + raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}") + return {} + def approximate_instances( self, semantic_pair: SemanticPair, verbose: bool = False, **kwargs ) -> UnmatchedInstancePair | MatchedInstancePair: @@ -140,10 +145,8 @@ def _approximate_instances( UnmatchedInstancePair: The result of the instance approximation. """ cca_backend = self.cca_backend - if self.cca_backend is None: - cca_backend = ( - CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy - ) + if cca_backend is None: + cca_backend = CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy assert cca_backend is not None empty_prediction = len(semantic_pair._pred_labels) == 0 @@ -164,3 +167,7 @@ def _approximate_instances( n_prediction_instance=n_prediction_instance, n_reference_instance=n_reference_instance, ) + + @classmethod + def _yaml_repr(cls, node) -> dict: + return {"cca_backend": node.cca_backend} diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 3b32e63..6c18249 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import ABCMeta, abstractmethod import numpy as np @@ -8,13 +8,14 @@ ) from panoptica.metrics import Metric from panoptica.utils.processing_pair import ( - InstanceLabelMap, MatchedInstancePair, UnmatchedInstancePair, ) +from panoptica.utils.instancelabelmap import InstanceLabelMap +from panoptica.utils.config import SupportsConfig -class InstanceMatchingAlgorithm(ABC): +class InstanceMatchingAlgorithm(SupportsConfig, metaclass=ABCMeta): """ Abstract base class for instance matching algorithms in panoptic segmentation evaluation. @@ -79,6 +80,10 @@ def match_instances( # print("instance_labelmap:", instance_labelmap) return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap) + def _yaml_repr(cls, node) -> dict: + raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}") + return {} + def map_instance_labels( processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap @@ -166,9 +171,9 @@ def __init__( Raises: AssertionError: If the specified IoU threshold is not within the valid range. """ - self.allow_many_to_one = allow_many_to_one - self.matching_metric = matching_metric - self.matching_threshold = matching_threshold + self._allow_many_to_one = allow_many_to_one + self._matching_metric = matching_metric + self._matching_threshold = matching_threshold def _match_instances( self, @@ -194,25 +199,26 @@ def _match_instances( unmatched_instance_pair.prediction_arr, unmatched_instance_pair.reference_arr, ) - mm_pairs = _calc_matching_metric_of_overlapping_labels( - pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric - ) + mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric) # Loop through matched instances to compute PQ components for matching_score, (ref_label, pred_label) in mm_pairs: - if ( - labelmap.contains_or(pred_label, ref_label) - and not self.allow_many_to_one - ): + if labelmap.contains_or(pred_label, ref_label) and not self._allow_many_to_one: continue # -> doesnt make speed difference - if self.matching_metric.score_beats_threshold( - matching_score, self.matching_threshold - ): + if self._matching_metric.score_beats_threshold(matching_score, self._matching_threshold): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) # map label ref_idx to pred_idx return labelmap + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "matching_metric": node._matching_metric, + "matching_threshold": node._matching_threshold, + "allow_many_to_one": node._allow_many_to_one, + } + class MaximizeMergeMatching(InstanceMatchingAlgorithm): """ @@ -241,8 +247,8 @@ def __init__( Raises: AssertionError: If the specified IoU threshold is not within the valid range. """ - self.matching_metric = matching_metric - self.matching_threshold = matching_threshold + self._matching_metric = matching_metric + self._matching_threshold = matching_threshold def _match_instances( self, @@ -274,7 +280,7 @@ def _match_instances( prediction_arr=pred_arr, reference_arr=ref_arr, ref_labels=ref_labels, - matching_metric=self.matching_metric, + matching_metric=self._matching_metric, ) # Loop through matched instances to compute PQ components @@ -290,9 +296,7 @@ def _match_instances( if new_score > score_ref[ref_label]: labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = new_score - elif self.matching_metric.score_beats_threshold( - matching_score, self.matching_threshold - ): + elif self._matching_metric.score_beats_threshold(matching_score, self._matching_threshold): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = matching_score @@ -307,7 +311,7 @@ def new_combination_score( unmatched_instance_pair: UnmatchedInstancePair, ): pred_labels.append(new_pred_label) - score = self.matching_metric( + score = self._matching_metric( unmatched_instance_pair.reference_arr, prediction_arr=unmatched_instance_pair.prediction_arr, ref_instance_idx=ref_label, @@ -315,6 +319,13 @@ def new_combination_score( ) return score + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "matching_metric": node._matching_metric, + "matching_threshold": node._matching_threshold, + } + class MatchUntilConvergenceMatching(InstanceMatchingAlgorithm): # Match like the naive matcher (so each to their best reference) and then again and again until no overlapping labels are left diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index 810bd9b..7423592 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -14,16 +14,18 @@ SemanticPair, UnmatchedInstancePair, _ProcessingPair, + InputType, ) +import numpy as np +from panoptica.utils.config import SupportsConfig from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup -class Panoptica_Evaluator: +class Panoptica_Evaluator(SupportsConfig): def __init__( self, - # TODO let users give prediction and reference arr instead of the processing pair, so let this create the processing pair itself - expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair, + expected_input: InputType = InputType.MATCHED_INSTANCE, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, @@ -37,7 +39,7 @@ def __init__( """Creates a Panoptica_Evaluator, that saves some parameters to be used for all subsequent evaluations Args: - expected_input (type, optional): Expected DataPair Input. Defaults to type(MatchedInstancePair). + expected_input (type, optional): Expected DataPair Input Type. Defaults to InputType.MATCHED_INSTANCE (which is type(MatchedInstancePair)). instance_approximator (InstanceApproximator | None, optional): Determines which instance approximator is used if necessary. Defaults to None. instance_matcher (InstanceMatchingAlgorithm | None, optional): Determines which instance matching algorithm is used if necessary. Defaults to None. iou_threshold (float, optional): Iou Threshold for evaluation. Defaults to 0.5. @@ -59,15 +61,32 @@ def __init__( self.__log_times = log_times self.__verbose = verbose + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "expected_input": node.__expected_input, + "instance_approximator": node.__instance_approximator, + "instance_matcher": node.__instance_matcher, + "edge_case_handler": node.__edge_case_handler, + "segmentation_class_groups": node.__segmentation_class_groups, + "eval_metrics": node.__eval_metrics, + "decision_metric": node.__decision_metric, + "decision_threshold": node.__decision_threshold, + "log_times": node.__log_times, + "verbose": node.__verbose, + } + @citation_reminder @measure_time def evaluate( self, - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + prediction_arr: np.ndarray, + reference_arr: np.ndarray, result_all: bool = True, verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: - assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}" + processing_pair = self.__expected_input(prediction_arr, reference_arr) + assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index c037e85..b831e04 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -43,7 +43,7 @@ def __init__( edge_case_handler (EdgeCaseHandler): EdgeCaseHandler object that handles various forms of edge cases """ self._edge_case_handler = edge_case_handler - empty_list_std = self._edge_case_handler.handle_empty_list_std() + empty_list_std = self._edge_case_handler.handle_empty_list_std().value self._prediction_arr = prediction_arr self._reference_arr = reference_arr ###################### diff --git a/panoptica/utils/__init__.py b/panoptica/utils/__init__.py index d4dfe79..0b72f2a 100644 --- a/panoptica/utils/__init__.py +++ b/panoptica/utils/__init__.py @@ -3,11 +3,11 @@ _unique_without_zeros, ) from panoptica.utils.processing_pair import ( - InstanceLabelMap, MatchedInstancePair, SemanticPair, UnmatchedInstancePair, ) +from panoptica.utils.instancelabelmap import InstanceLabelMap from panoptica.utils.edge_case_handling import ( EdgeCaseHandler, EdgeCaseResult, diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index 56c5111..3afc165 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -1,6 +1,7 @@ from ruamel.yaml import YAML from pathlib import Path from panoptica.utils.filepath import config_by_name +from abc import ABC, abstractmethod supported_helper_classes = [] @@ -28,9 +29,11 @@ def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class= yaml = YAML(typ="safe") yaml.default_flow_style = None + yaml.representer.ignore_aliases = lambda *data: True + _register_helper_classes(yaml) if registered_class is not None: yaml.register_class(registered_class) - _register_helper_classes(yaml) + assert isinstance(data_dict, registered_class) # if isinstance(data_dict, object): yaml.dump(data_dict, out_file) # else: @@ -130,6 +133,9 @@ def _save_to_config(obj, path: str | Path): class SupportsConfig: """Metaclass that allows a class to save and load objects by yaml configs""" + def __init__(self) -> None: + raise NotImplementedError(f"Tried to instantiate abstract class {type(self)}") + def __init_subclass__(cls, **kwargs): # Registers all subclasses of this super().__init_subclass__(**kwargs) @@ -141,11 +147,15 @@ def _register_permanently(cls): @classmethod def load_from_config(cls, path: str | Path): - return _load_from_config(cls, path) + obj = _load_from_config(cls, path) + assert isinstance(obj, cls), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" + return obj @classmethod def load_from_config_name(cls, name: str): - return _load_from_config_name(cls, name) + obj = _load_from_config_name(cls, name) + assert isinstance(obj, cls) + return obj def save_to_config(self, path: str | Path): _save_to_config(self, path) @@ -161,3 +171,8 @@ def from_yaml(cls, constructor, node): # cls._register_permanently() data = constructor.construct_mapping(node, deep=True) return cls(**data) + + @classmethod + @abstractmethod + def _yaml_repr(cls, node) -> dict: + pass # return {"groups": node.__group_dictionary} diff --git a/panoptica/utils/constants.py b/panoptica/utils/constants.py index 8c2fc80..2889ad6 100644 --- a/panoptica/utils/constants.py +++ b/panoptica/utils/constants.py @@ -1,12 +1,25 @@ from enum import Enum, auto from panoptica.utils.config import _register_class_to_yaml, _load_from_config, _load_from_config_name, _save_to_config from pathlib import Path +import numpy as np class _Enum_Compare(Enum): def __eq__(self, __value: object) -> bool: if isinstance(__value, Enum): - return self.name == __value.name and self.value == __value.value + namecheck = self.name == __value.name + if not namecheck: + return False + if self.value is None: + return __value.value is None + + try: + if np.isnan(self.value): + return np.isnan(__value.value) + except Exception: + pass + + return self.value == __value.value elif isinstance(__value, str): return self.name == __value else: diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index 7a02fcc..dfe2ca2 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -23,48 +23,69 @@ def __hash__(self) -> int: class MetricZeroTPEdgeCaseHandling(SupportsConfig): + def __init__( self, - default_result: EdgeCaseResult, + default_result: EdgeCaseResult | None = None, no_instances_result: EdgeCaseResult | None = None, empty_prediction_result: EdgeCaseResult | None = None, empty_reference_result: EdgeCaseResult | None = None, normal: EdgeCaseResult | None = None, ) -> None: - self.edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self.edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result - self.edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result - self.edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result - self.edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result + assert default_result is not None or ( + no_instances_result is not None + and empty_prediction_result is not None + and empty_reference_result is not None + and normal is not None + ), "default_result is None and the rest is not fully specified" + + self._default_result = default_result + self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # elif num_pred_instances + num_ref_instances == 0: - return True, self.edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES].value + return True, self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES].value elif num_ref_instances == 0: - return True, self.edgecase_dict[EdgeCaseZeroTP.EMPTY_REF].value + return True, self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF].value elif num_pred_instances == 0: - return True, self.edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED].value + return True, self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED].value elif num_pred_instances > 0 and num_ref_instances > 0: - return True, self.edgecase_dict[EdgeCaseZeroTP.NORMAL].value + return True, self._edgecase_dict[EdgeCaseZeroTP.NORMAL].value raise NotImplementedError( f"MetricZeroTPEdgeCaseHandling: couldn't handle case, got tp {tp}, n_pred_instances {num_pred_instances}, n_ref_instances {num_ref_instances}" ) + def __eq__(self, __value: object) -> bool: + if isinstance(__value, MetricZeroTPEdgeCaseHandling): + for s, k in self._edgecase_dict.items(): + if s not in __value._edgecase_dict or k != __value._edgecase_dict[s]: + return False + return True + return False + def __str__(self) -> str: txt = "" - for k, v in self.edgecase_dict.items(): + for k, v in self._edgecase_dict.items(): if v is not None: txt += str(k) + ": " + str(v) + "\n" return txt @classmethod - def _yaml_repr(cls, node): - # TODO - return {"value_labels": node.value_labels, "single_instance": node.single_instance} + def _yaml_repr(cls, node) -> dict: + return { + "no_instances_result": node._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES], + "empty_prediction_result": node._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED], + "empty_reference_result": node._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF], + "normal": node._edgecase_dict[EdgeCaseZeroTP.NORMAL], + } class EdgeCaseHandler(SupportsConfig): @@ -117,11 +138,15 @@ def handle_zero_tp( num_ref_instances=num_ref_instances, ) + @property + def listmetric_zeroTP_handling(self): + return self.__listmetric_zeroTP_handling + def get_metric_zero_tp_handle(self, metric: Metric): return self.__listmetric_zeroTP_handling[metric] - def handle_empty_list_std(self) -> float | None: - return self.__empty_list_std.value + def handle_empty_list_std(self) -> EdgeCaseResult | None: + return self.__empty_list_std def __str__(self) -> str: txt = f"EdgeCaseHandler:\n - Standard Deviation of Empty = {self.__empty_list_std}" @@ -129,6 +154,13 @@ def __str__(self) -> str: txt += f"\n- {k}: {str(v)}" return str(txt) + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "listmetric_zeroTP_handling": node.__listmetric_zeroTP_handling, + "empty_list_std": node.__empty_list_std, + } + if __name__ == "__main__": handler = EdgeCaseHandler() diff --git a/panoptica/utils/instancelabelmap.py b/panoptica/utils/instancelabelmap.py new file mode 100644 index 0000000..7480875 --- /dev/null +++ b/panoptica/utils/instancelabelmap.py @@ -0,0 +1,64 @@ +import numpy as np + + +# Many-to-One Mapping +class InstanceLabelMap(object): + # Mapping ((prediction_label, ...), (reference_label, ...)) + labelmap: dict[int, int] + + def __init__(self) -> None: + self.labelmap = {} + + def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): + if not isinstance(pred_labels, list): + pred_labels = [pred_labels] + assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" + assert np.all([isinstance(r, int) for r in pred_labels]), "add_labelmap_entry: got no int as pred_label" + for p in pred_labels: + if p in self.labelmap and self.labelmap[p] != ref_label: + raise Exception( + f"You are mapping a prediction label to a reference label that was already assigned differently, got {self.__str__} and you tried {pred_labels}, {ref_label}" + ) + self.labelmap[p] = ref_label + + def get_pred_labels_matched_to_ref(self, ref_label: int): + return [k for k, v in self.labelmap.items() if v == ref_label] + + def contains_pred(self, pred_label: int): + return pred_label in self.labelmap + + def contains_ref(self, ref_label: int): + return ref_label in self.labelmap.values() + + def contains_and(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + pred_in = True if pred_label is None else pred_label in self.labelmap + ref_in = True if ref_label is None else ref_label in self.labelmap.values() + return pred_in and ref_in + + def contains_or(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + pred_in = True if pred_label is None else pred_label in self.labelmap + ref_in = True if ref_label is None else ref_label in self.labelmap.values() + return pred_in or ref_in + + def get_one_to_one_dictionary(self): + return self.labelmap + + def __str__(self) -> str: + return str( + list( + [ + str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + " -> " + str(v) + for v in set(self.labelmap.values()) + ] + ) + ) + + def __repr__(self) -> str: + return str(self) + + # Make all variables read-only! + def __setattr__(self, attr, value): + if hasattr(self, attr): + raise Exception("Attempting to alter read-only value") + + self.__dict__[attr] = value diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 5ed1a7c..938a8bc 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -4,6 +4,7 @@ from panoptica._functionals import _get_paired_crop from panoptica.utils import _count_unique_without_zeros, _unique_without_zeros +from panoptica.utils.constants import _Enum_Compare uint_type: type = np.unsignedinteger int_type: type = np.integer @@ -314,83 +315,10 @@ def copy(self): ) -# Many-to-One Mapping -class InstanceLabelMap(object): - # Mapping ((prediction_label, ...), (reference_label, ...)) - labelmap: dict[int, int] - - def __init__(self) -> None: - self.labelmap = {} - - def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): - if not isinstance(pred_labels, list): - pred_labels = [pred_labels] - assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" - assert np.all( - [isinstance(r, int) for r in pred_labels] - ), "add_labelmap_entry: got no int as pred_label" - for p in pred_labels: - if p in self.labelmap and self.labelmap[p] != ref_label: - raise Exception( - f"You are mapping a prediction label to a reference label that was already assigned differently, got {self.__str__} and you tried {pred_labels}, {ref_label}" - ) - self.labelmap[p] = ref_label - - def get_pred_labels_matched_to_ref(self, ref_label: int): - return [k for k, v in self.labelmap.items() if v == ref_label] - - def contains_pred(self, pred_label: int): - return pred_label in self.labelmap - - def contains_ref(self, ref_label: int): - return ref_label in self.labelmap.values() - - def contains_and( - self, pred_label: int | None = None, ref_label: int | None = None - ) -> bool: - pred_in = True if pred_label is None else pred_label in self.labelmap - ref_in = True if ref_label is None else ref_label in self.labelmap.values() - return pred_in and ref_in - - def contains_or( - self, pred_label: int | None = None, ref_label: int | None = None - ) -> bool: - pred_in = True if pred_label is None else pred_label in self.labelmap - ref_in = True if ref_label is None else ref_label in self.labelmap.values() - return pred_in or ref_in - - def get_one_to_one_dictionary(self): - return self.labelmap - - def __str__(self) -> str: - return str( - list( - [ - str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) - + " -> " - + str(v) - for v in set(self.labelmap.values()) - ] - ) - ) - - def __repr__(self) -> str: - return str(self) - - # Make all variables read-only! - def __setattr__(self, attr, value): - if hasattr(self, attr): - raise Exception("Attempting to alter read-only value") - - self.__dict__[attr] = value - - -if __name__ == "__main__": - n = np.zeros([50, 50], dtype=np.int32) - a = SemanticPair(n, n) - print(a) - # print(a.prediction_arr) +class InputType(_Enum_Compare): + SEMANTIC = SemanticPair + UNMATCHED_INSTANCE = UnmatchedInstancePair + MATCHED_INSTANCE = MatchedInstancePair - map = InstanceLabelMap() - map.labelmap = {2: 3, 3: 3, 4: 6} - print(map) + def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + return self.value(prediction_arr, reference_arr) diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 678822d..24e7568 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -31,8 +31,6 @@ def __init__( if len(duplicates) > 0: raise AssertionError(f"The same label was assigned to two different labelgroups, got {str(self)}") self.__labels = labels - SegmentationClassGroups._register_permanently() - LabelGroup._register_permanently() def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): if isinstance(arr, list): diff --git a/unit_tests/test_config.py b/unit_tests/test_config.py index 189ca24..5d516db 100644 --- a/unit_tests/test_config.py +++ b/unit_tests/test_config.py @@ -7,13 +7,18 @@ from panoptica.metrics import ( Metric, + _Metric, Evaluation_List_Metric, MetricMode, MetricCouldNotBeComputedException, ) from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup from panoptica.utils.constants import CCABackend +from panoptica.utils.edge_case_handling import EdgeCaseResult, EdgeCaseZeroTP, MetricZeroTPEdgeCaseHandling, EdgeCaseHandler +from panoptica import ConnectedComponentsInstanceApproximator, NaiveThresholdMatching from pathlib import Path +import numpy as np +import random test_file = Path(__file__).parent.joinpath("test.yaml") @@ -33,6 +38,32 @@ def test_enum_config(self): self.assertEqual(a, b) + def test_enum_config_all(self): + for enum in [CCABackend, EdgeCaseZeroTP, EdgeCaseResult, Metric]: + for a in enum: + a.save_to_config(test_file) + print(a) + b = enum.load_from_config(test_file) + print(b) + os.remove(test_file) + + self.assertEqual(a, b) + self.assertEqual(a.name, b.name) + self.assertEqual(str(a), str(b)) + + # check for value equality + if a.value is None: + self.assertTrue(b.value is None) + else: + # if it is _Metric object, just check name + if not isinstance(a.value, _Metric): + if np.isnan(a.value): + self.assertTrue(np.isnan(b.value)) + else: + self.assertEqual(a.value, b.value) + else: + self.assertEqual(a.value.name, b.value.name) + def test_SegmentationClassGroups_config(self): e = { "groups": { @@ -52,3 +83,62 @@ def test_SegmentationClassGroups_config(self): for k, v in d.items(): self.assertEqual(t[k].single_instance, v.single_instance) self.assertEqual(len(t[k].value_labels), len(v.value_labels)) + + def test_InstanceApproximator_config(self): + for backend in [None, CCABackend.cc3d, CCABackend.scipy]: + t = ConnectedComponentsInstanceApproximator(cca_backend=backend) + print(t) + print() + t.save_to_config(test_file) + d: ConnectedComponentsInstanceApproximator = ConnectedComponentsInstanceApproximator.load_from_config(test_file) + os.remove(test_file) + + self.assertEqual(d.cca_backend, t.cca_backend) + + def test_NaiveThresholdMatching_config(self): + for mm in [Metric.DSC, Metric.IOU, Metric.ASSD]: + for mt in [0.1, 0.4, 0.5, 0.8, 1.0]: + for amto in [False, True]: + t = NaiveThresholdMatching(matching_metric=mm, matching_threshold=mt, allow_many_to_one=amto) + print(t) + print() + t.save_to_config(test_file) + d: NaiveThresholdMatching = NaiveThresholdMatching.load_from_config(test_file) + os.remove(test_file) + + self.assertEqual(d._allow_many_to_one, t._allow_many_to_one) + self.assertEqual(d._matching_metric, t._matching_metric) + self.assertEqual(d._matching_threshold, t._matching_threshold) + + def test_MetricZeroTPEdgeCaseHandling_config(self): + for iter in range(10): + args = [random.choice(list(EdgeCaseResult)) for i in range(5)] + + t = MetricZeroTPEdgeCaseHandling(*args) + print(t) + print() + t.save_to_config(test_file) + d: MetricZeroTPEdgeCaseHandling = MetricZeroTPEdgeCaseHandling.load_from_config(test_file) + os.remove(test_file) + + for k, v in t._edgecase_dict.items(): + self.assertEqual(v, d._edgecase_dict[k]) + # self.assertEqual(d.cca_backend, t.cca_backend) + + def test_EdgeCaseHandler_config(self): + t = EdgeCaseHandler() + print(t) + print() + t.save_to_config(test_file) + d: EdgeCaseHandler = EdgeCaseHandler.load_from_config(test_file) + # os.remove(test_file) + + self.assertEqual(t.handle_empty_list_std(), d.handle_empty_list_std()) + for k, v in t.listmetric_zeroTP_handling.items(): + # v is dict[Metric, MetricZeroTPEdgeCaseHandling] + v2 = d.listmetric_zeroTP_handling[k] + + print(v) + print(v2) + + self.assertEqual(v, v2) diff --git a/unit_tests/test_datatype.py b/unit_tests/test_datatype.py index 4d6b287..c892aae 100644 --- a/unit_tests/test_datatype.py +++ b/unit_tests/test_datatype.py @@ -11,6 +11,7 @@ MetricMode, MetricCouldNotBeComputedException, ) +from panoptica.utils.edge_case_handling import EdgeCaseResult class Test_Datatypes(unittest.TestCase): @@ -29,6 +30,10 @@ def test_metrics_enum(self): self.assertNotEqual(Metric.DSC, Metric.IOU) self.assertNotEqual(Metric.DSC, "IOU") + def test_EdgeCaseResult_enum(self): + for e in EdgeCaseResult: + self.assertEqual(e, e) + def test_matching_metric(self): dsc_metric = Metric.DSC From 146d753505a485b919e298818de76d71508447be Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 5 Aug 2024 11:38:57 +0000 Subject: [PATCH 05/13] updated unittests to new API --- unit_tests/test_panoptic_evaluator.py | 84 +++++++++------------------ 1 file changed, 29 insertions(+), 55 deletions(-) diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index 276e195..5fe9c0e 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -7,6 +7,7 @@ import numpy as np +from panoptica import InputType from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator from panoptica.instance_matcher import MaximizeMergeMatching, NaiveThresholdMatching from panoptica.metrics import Metric @@ -27,15 +28,13 @@ def test_simple_evaluation(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -48,15 +47,13 @@ def test_simple_evaluation_DSC(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -69,16 +66,14 @@ def test_simple_evaluation_DSC_partial(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC), eval_metrics=[Metric.DSC], ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -95,10 +90,8 @@ def test_simple_evaluation_ASSD(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching( matching_metric=Metric.ASSD, @@ -106,7 +99,7 @@ def test_simple_evaluation_ASSD(self): ), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -119,10 +112,8 @@ def test_simple_evaluation_ASSD_negative(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching( matching_metric=Metric.ASSD, @@ -130,7 +121,7 @@ def test_simple_evaluation_ASSD_negative(self): ), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -144,15 +135,13 @@ def test_pred_empty(self): a[20:40, 10:20] = 1 # b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 0) @@ -167,15 +156,13 @@ def test_ref_empty(self): # a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -190,15 +177,13 @@ def test_both_empty(self): # a[20:40, 10:20] = 1 # b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 0) @@ -228,15 +213,14 @@ def test_dtype_evaluation(self): if da != db: self.assertRaises(AssertionError, SemanticPair, b, a) else: - sample = SemanticPair(b, a) evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -249,15 +233,13 @@ def test_simple_evaluation_maximize_matcher(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -271,15 +253,13 @@ def test_simple_evaluation_maximize_matcher_overlaptwo(self): b[20:35, 10:20] = 2 b[36:38, 10:20] = 3 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -295,15 +275,13 @@ def test_simple_evaluation_maximize_matcher_overlap(self): # match the two above to 1 and the 4 to nothing (FP) b[39:47, 10:20] = 4 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 1) @@ -318,16 +296,14 @@ def test_single_instance_mode(self): a[20:40, 10:20] = 5 b[20:35, 10:20] = 5 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}), ) - result, debug_data = evaluator.evaluate(sample)["organ"] + result, debug_data = evaluator.evaluate(b, a)["organ"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -340,16 +316,14 @@ def test_single_instance_mode_nooverlap(self): a[20:40, 10:20] = 5 b[5:15, 30:50] = 5 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}), ) - result, debug_data = evaluator.evaluate(sample)["organ"] + result, debug_data = evaluator.evaluate(b, a)["organ"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) From 7a090bb154c207626aeafc9715b81a8754f84238 Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 5 Aug 2024 11:40:31 +0000 Subject: [PATCH 06/13] updated pyproject with ruamel.yaml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 80544da..b0f197c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ connected-components-3d = "^3.12.3" scipy = "^1.7.0" rich = "^13.6.0" scikit-image = "^0.22.0" -ruamel = "0.18.6" +ruamel.yaml = "0.18.6" [tool.poetry.dev-dependencies] pytest = "^6.2.5" From d2432f51fac96bae4bd88e7eb1ca79bf8d14c4e3 Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 5 Aug 2024 11:43:04 +0000 Subject: [PATCH 07/13] updated pyproject with ruamel.yaml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b0f197c..40fc840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ connected-components-3d = "^3.12.3" scipy = "^1.7.0" rich = "^13.6.0" scikit-image = "^0.22.0" -ruamel.yaml = "0.18.6" +"ruamel.yaml" = "^0.18.6" [tool.poetry.dev-dependencies] pytest = "^6.2.5" From d25fabc2290b2cba1f4f0636ee888fd518001e40 Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 5 Aug 2024 11:45:49 +0000 Subject: [PATCH 08/13] removed auxiliary from filepath.py --- panoptica/utils/filepath.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/panoptica/utils/filepath.py b/panoptica/utils/filepath.py index aac9edd..f668b75 100644 --- a/panoptica/utils/filepath.py +++ b/panoptica/utils/filepath.py @@ -2,7 +2,6 @@ import warnings from itertools import chain from pathlib import Path -from auxiliary.turbopath import turbopath def search_path(basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False) -> list[Path]: @@ -29,7 +28,7 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres # Find config path def config_by_name(name: str) -> Path: - directory = turbopath(__file__).parent.parent + directory = Path(__file__.replace("////", "/").replace("\\\\", "/").replace("//", "/").replace("\\", "/")).parent.parent if not name.endswith(".yaml"): name += ".yaml" p = search_path(directory, query=f"**/{name}", suppress=True) From 3e89dd8f4b94f51159c96a878cf2924afd930c03 Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:48:17 +0000 Subject: [PATCH 09/13] Autoformat with black --- examples/example_spine_instance_config.py | 4 ++- examples/example_spine_semantic.py | 4 ++- panoptica/instance_approximator.py | 8 +++-- panoptica/instance_matcher.py | 21 +++++++++---- panoptica/panoptica_evaluator.py | 36 +++++++++++++++++------ panoptica/utils/config.py | 16 +++++++--- panoptica/utils/constants.py | 7 ++++- panoptica/utils/edge_case_handling.py | 36 ++++++++++++++++++----- panoptica/utils/filepath.py | 19 +++++++++--- panoptica/utils/instancelabelmap.py | 16 +++++++--- panoptica/utils/label_group.py | 17 ++++++++--- panoptica/utils/processing_pair.py | 4 ++- panoptica/utils/segmentation_class.py | 22 ++++++++++---- unit_tests/test_config.py | 25 ++++++++++++---- unit_tests/test_labelgroup.py | 4 ++- 15 files changed, 184 insertions(+), 55 deletions(-) diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py index 2855add..56c61e4 100644 --- a/examples/example_spine_instance_config.py +++ b/examples/example_spine_instance_config.py @@ -10,7 +10,9 @@ reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") -evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance") +evaluator = Panoptica_Evaluator.load_from_config_name( + "panoptica_evaluator_unmatched_instance" +) with cProfile.Profile() as pr: diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 2793592..a2e5a32 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -25,7 +25,9 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"] + result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[ + "ungrouped" + ] print(result) pr.dump_stats(directory + "/semantic_example.log") diff --git a/panoptica/instance_approximator.py b/panoptica/instance_approximator.py index 957b934..f8e061f 100644 --- a/panoptica/instance_approximator.py +++ b/panoptica/instance_approximator.py @@ -58,7 +58,9 @@ def _approximate_instances( pass def _yaml_repr(cls, node) -> dict: - raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}") + raise NotImplementedError( + f"Tried to get yaml representation of abstract class {cls.__name__}" + ) return {} def approximate_instances( @@ -146,7 +148,9 @@ def _approximate_instances( """ cca_backend = self.cca_backend if cca_backend is None: - cca_backend = CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy + cca_backend = ( + CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy + ) assert cca_backend is not None empty_prediction = len(semantic_pair._pred_labels) == 0 diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 6c18249..5bedd9e 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -81,7 +81,9 @@ def match_instances( return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap) def _yaml_repr(cls, node) -> dict: - raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}") + raise NotImplementedError( + f"Tried to get yaml representation of abstract class {cls.__name__}" + ) return {} @@ -199,13 +201,20 @@ def _match_instances( unmatched_instance_pair.prediction_arr, unmatched_instance_pair.reference_arr, ) - mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric) + mm_pairs = _calc_matching_metric_of_overlapping_labels( + pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric + ) # Loop through matched instances to compute PQ components for matching_score, (ref_label, pred_label) in mm_pairs: - if labelmap.contains_or(pred_label, ref_label) and not self._allow_many_to_one: + if ( + labelmap.contains_or(pred_label, ref_label) + and not self._allow_many_to_one + ): continue # -> doesnt make speed difference - if self._matching_metric.score_beats_threshold(matching_score, self._matching_threshold): + if self._matching_metric.score_beats_threshold( + matching_score, self._matching_threshold + ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) # map label ref_idx to pred_idx @@ -296,7 +305,9 @@ def _match_instances( if new_score > score_ref[ref_label]: labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = new_score - elif self._matching_metric.score_beats_threshold(matching_score, self._matching_threshold): + elif self._matching_metric.score_beats_threshold( + matching_score, self._matching_threshold + ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = matching_score diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index 7423592..fd5617c 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -54,9 +54,13 @@ def __init__( self.__segmentation_class_groups = segmentation_class_groups - self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + self.__edge_case_handler = ( + edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + ) if self.__decision_metric is not None: - assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" + assert ( + self.__decision_threshold is not None + ), "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -86,7 +90,9 @@ def evaluate( verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: processing_pair = self.__expected_input(prediction_arr, reference_arr) - assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}" + assert isinstance( + processing_pair, self.__expected_input.value + ), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: return { @@ -105,8 +111,12 @@ def evaluate( ) } - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True) - self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.prediction_arr, raise_error=True + ) + self.__segmentation_class_groups.has_defined_labels_for( + processing_pair.reference_arr, raise_error=True + ) result_grouped = {} for group_name, label_group in self.__segmentation_class_groups.items(): @@ -118,7 +128,9 @@ def evaluate( single_instance_mode = label_group.single_instance processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore decision_threshold = self.__decision_threshold - if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair): + if single_instance_mode and not isinstance( + processing_pair, MatchedInstancePair + ): processing_pair_grouped = MatchedInstancePair( prediction_arr=processing_pair_grouped.prediction_arr, reference_arr=processing_pair_grouped.reference_arr, @@ -142,7 +154,9 @@ def evaluate( def panoptic_evaluate( - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + processing_pair: ( + SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult + ), instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -198,7 +212,9 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" + assert ( + instance_approximator is not None + ), "Got SemanticPair but not InstanceApproximator" if verbose: print("-- Got SemanticPair, will approximate instances") start = perf_counter() @@ -218,7 +234,9 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): if verbose: print("-- Got UnmatchedInstancePair, will match instances") - assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert ( + instance_matcher is not None + ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index 3afc165..9545cae 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -71,13 +71,17 @@ def save_from_object(cls, obj: object, file: str | Path): @classmethod def load(cls, file: str | Path, registered_class=None): data = _load_yaml(file, registered_class) - assert isinstance(data, dict), f"The config at {file} is registered to a class. Use load_as_object() instead" + assert isinstance( + data, dict + ), f"The config at {file} is registered to a class. Use load_as_object() instead" return Configuration(data, registered_class=registered_class) @classmethod def load_as_object(cls, file: str | Path, registered_class=None): data = _load_yaml(file, registered_class) - assert not isinstance(data, dict), f"The config at {file} is not registered to a class. Use load() instead" + assert not isinstance( + data, dict + ), f"The config at {file} is not registered to a class. Use load() instead" return data def save(self, out_file: str | Path): @@ -148,7 +152,9 @@ def _register_permanently(cls): @classmethod def load_from_config(cls, path: str | Path): obj = _load_from_config(cls, path) - assert isinstance(obj, cls), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" + assert isinstance( + obj, cls + ), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" return obj @classmethod @@ -163,7 +169,9 @@ def save_to_config(self, path: str | Path): @classmethod def to_yaml(cls, representer, node): # cls._register_permanently() - assert hasattr(cls, "_yaml_repr"), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" + assert hasattr( + cls, "_yaml_repr" + ), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node)) @classmethod diff --git a/panoptica/utils/constants.py b/panoptica/utils/constants.py index 2889ad6..ecb7f63 100644 --- a/panoptica/utils/constants.py +++ b/panoptica/utils/constants.py @@ -1,5 +1,10 @@ from enum import Enum, auto -from panoptica.utils.config import _register_class_to_yaml, _load_from_config, _load_from_config_name, _save_to_config +from panoptica.utils.config import ( + _register_class_to_yaml, + _load_from_config, + _load_from_config_name, + _save_to_config, +) from pathlib import Path import numpy as np diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index dfe2ca2..1a865f0 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -41,12 +41,26 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( + empty_prediction_result + if empty_prediction_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( + empty_reference_result + if empty_reference_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( + no_instances_result if no_instances_result is not None else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( + normal if normal is not None else default_result + ) - def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: + def __call__( + self, tp: int, num_pred_instances, num_ref_instances + ) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -117,7 +131,9 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[ + Metric, MetricZeroTPEdgeCaseHandling + ] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -130,7 +146,9 @@ def handle_zero_tp( if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") + raise NotImplementedError( + f"Metric {metric} encountered zero TP, but no edge handling available" + ) return self.__listmetric_zeroTP_handling[metric]( tp=tp, @@ -167,7 +185,9 @@ def _yaml_repr(cls, node) -> dict: print() # print(handler.get_metric_zero_tp_handle(ListMetric.IOU)) - r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1) + r = handler.handle_zero_tp( + Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1 + ) print(r) iou_test = MetricZeroTPEdgeCaseHandling( diff --git a/panoptica/utils/filepath.py b/panoptica/utils/filepath.py index f668b75..a0d8c25 100644 --- a/panoptica/utils/filepath.py +++ b/panoptica/utils/filepath.py @@ -4,7 +4,9 @@ from pathlib import Path -def search_path(basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False) -> list[Path]: +def search_path( + basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False +) -> list[Path]: """Searches from basepath with query Args: basepath: ground path to look into @@ -16,7 +18,9 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres All found paths """ basepath = str(basepath) - assert os.path.exists(basepath), f"basepath for search_path() doesnt exist, got {basepath}" + assert os.path.exists( + basepath + ), f"basepath for search_path() doesnt exist, got {basepath}" if not basepath.endswith("/"): basepath += "/" print(f"search_path: in {basepath}{query}") if verbose else None @@ -28,9 +32,16 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres # Find config path def config_by_name(name: str) -> Path: - directory = Path(__file__.replace("////", "/").replace("\\\\", "/").replace("//", "/").replace("\\", "/")).parent.parent + directory = Path( + __file__.replace("////", "/") + .replace("\\\\", "/") + .replace("//", "/") + .replace("\\", "/") + ).parent.parent if not name.endswith(".yaml"): name += ".yaml" p = search_path(directory, query=f"**/{name}", suppress=True) - assert len(p) == 1, f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" + assert ( + len(p) == 1 + ), f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" return p[0] diff --git a/panoptica/utils/instancelabelmap.py b/panoptica/utils/instancelabelmap.py index 7480875..16fd33c 100644 --- a/panoptica/utils/instancelabelmap.py +++ b/panoptica/utils/instancelabelmap.py @@ -13,7 +13,9 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): if not isinstance(pred_labels, list): pred_labels = [pred_labels] assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" - assert np.all([isinstance(r, int) for r in pred_labels]), "add_labelmap_entry: got no int as pred_label" + assert np.all( + [isinstance(r, int) for r in pred_labels] + ), "add_labelmap_entry: got no int as pred_label" for p in pred_labels: if p in self.labelmap and self.labelmap[p] != ref_label: raise Exception( @@ -30,12 +32,16 @@ def contains_pred(self, pred_label: int): def contains_ref(self, ref_label: int): return ref_label in self.labelmap.values() - def contains_and(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + def contains_and( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in and ref_in - def contains_or(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + def contains_or( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in or ref_in @@ -47,7 +53,9 @@ def __str__(self) -> str: return str( list( [ - str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + " -> " + str(v) + str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + + " -> " + + str(v) for v in set(self.labelmap.values()) ] ) diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py index 6757500..430e2dc 100644 --- a/panoptica/utils/label_group.py +++ b/panoptica/utils/label_group.py @@ -20,12 +20,18 @@ def __init__( """ if isinstance(value_labels, int): value_labels = [value_labels] - assert len(value_labels) >= 1, f"You tried to define a LabelGroup without any specified labels, got {value_labels}" + assert ( + len(value_labels) >= 1 + ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" self.__value_labels = value_labels - assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}" + assert np.all( + [v > 0 for v in self.__value_labels] + ), f"Given value labels are not >0, got {value_labels}" self.__single_instance = single_instance if self.__single_instance: - assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" + assert ( + len(value_labels) == 1 + ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" LabelGroup._register_permanently() @@ -65,7 +71,10 @@ def __repr__(self) -> str: @classmethod def _yaml_repr(cls, node): - return {"value_labels": node.value_labels, "single_instance": node.single_instance} + return { + "value_labels": node.value_labels, + "single_instance": node.single_instance, + } # @classmethod # def to_yaml(cls, representer, node): diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 938a8bc..c64b1e9 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -320,5 +320,7 @@ class InputType(_Enum_Compare): UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + def __call__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray + ) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 24e7568..550b6aa 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -15,24 +15,36 @@ def __init__( # maps name of group to the group itself if isinstance(groups, list): - self.__group_dictionary = {f"group_{idx}": g for idx, g in enumerate(groups)} + self.__group_dictionary = { + f"group_{idx}": g for idx, g in enumerate(groups) + } elif isinstance(groups, dict): # transform dict into list of LabelGroups for i, g in groups.items(): name_lower = str(i).lower() if isinstance(g, LabelGroup): - self.__group_dictionary[name_lower] = LabelGroup(g.value_labels, g.single_instance) + self.__group_dictionary[name_lower] = LabelGroup( + g.value_labels, g.single_instance + ) else: self.__group_dictionary[name_lower] = LabelGroup(g[0], g[1]) # needs to check that each label is accounted for exactly ONCE - labels = [value_label for lg in self.__group_dictionary.values() for value_label in lg.value_labels] + labels = [ + value_label + for lg in self.__group_dictionary.values() + for value_label in lg.value_labels + ] duplicates = list_duplicates(labels) if len(duplicates) > 0: - raise AssertionError(f"The same label was assigned to two different labelgroups, got {str(self)}") + raise AssertionError( + f"The same label was assigned to two different labelgroups, got {str(self)}" + ) self.__labels = labels - def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): + def has_defined_labels_for( + self, arr: np.ndarray | list[int], raise_error: bool = False + ): if isinstance(arr, list): arr_labels = arr else: diff --git a/unit_tests/test_config.py b/unit_tests/test_config.py index 5d516db..59c0b46 100644 --- a/unit_tests/test_config.py +++ b/unit_tests/test_config.py @@ -14,7 +14,12 @@ ) from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup from panoptica.utils.constants import CCABackend -from panoptica.utils.edge_case_handling import EdgeCaseResult, EdgeCaseZeroTP, MetricZeroTPEdgeCaseHandling, EdgeCaseHandler +from panoptica.utils.edge_case_handling import ( + EdgeCaseResult, + EdgeCaseZeroTP, + MetricZeroTPEdgeCaseHandling, + EdgeCaseHandler, +) from panoptica import ConnectedComponentsInstanceApproximator, NaiveThresholdMatching from pathlib import Path import numpy as np @@ -90,7 +95,9 @@ def test_InstanceApproximator_config(self): print(t) print() t.save_to_config(test_file) - d: ConnectedComponentsInstanceApproximator = ConnectedComponentsInstanceApproximator.load_from_config(test_file) + d: ConnectedComponentsInstanceApproximator = ( + ConnectedComponentsInstanceApproximator.load_from_config(test_file) + ) os.remove(test_file) self.assertEqual(d.cca_backend, t.cca_backend) @@ -99,11 +106,17 @@ def test_NaiveThresholdMatching_config(self): for mm in [Metric.DSC, Metric.IOU, Metric.ASSD]: for mt in [0.1, 0.4, 0.5, 0.8, 1.0]: for amto in [False, True]: - t = NaiveThresholdMatching(matching_metric=mm, matching_threshold=mt, allow_many_to_one=amto) + t = NaiveThresholdMatching( + matching_metric=mm, + matching_threshold=mt, + allow_many_to_one=amto, + ) print(t) print() t.save_to_config(test_file) - d: NaiveThresholdMatching = NaiveThresholdMatching.load_from_config(test_file) + d: NaiveThresholdMatching = NaiveThresholdMatching.load_from_config( + test_file + ) os.remove(test_file) self.assertEqual(d._allow_many_to_one, t._allow_many_to_one) @@ -118,7 +131,9 @@ def test_MetricZeroTPEdgeCaseHandling_config(self): print(t) print() t.save_to_config(test_file) - d: MetricZeroTPEdgeCaseHandling = MetricZeroTPEdgeCaseHandling.load_from_config(test_file) + d: MetricZeroTPEdgeCaseHandling = ( + MetricZeroTPEdgeCaseHandling.load_from_config(test_file) + ) os.remove(test_file) for k, v in t._edgecase_dict.items(): diff --git a/unit_tests/test_labelgroup.py b/unit_tests/test_labelgroup.py index 292629a..4aee9f3 100644 --- a/unit_tests/test_labelgroup.py +++ b/unit_tests/test_labelgroup.py @@ -82,7 +82,9 @@ def test_segmentationclassgroup_easy(self): self.assertTrue(isinstance(lg, LabelGroup)) def test_segmentationclassgroup_decarations(self): - classgroups = SegmentationClassGroups(groups=[LabelGroup(i) for i in range(1, 5)]) + classgroups = SegmentationClassGroups( + groups=[LabelGroup(i) for i in range(1, 5)] + ) keys = classgroups.keys() for i in range(1, 5): From b4a1131fe3369f0a2d35719cb663f3a1c43a7a9f Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 5 Aug 2024 12:08:12 +0000 Subject: [PATCH 10/13] added save_to_config_by_name --- panoptica/__init__.py | 2 +- panoptica/utils/config.py | 10 +++++++++- panoptica/utils/filepath.py | 9 +++++++-- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/panoptica/__init__.py b/panoptica/__init__.py index dca6768..2f02659 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -6,9 +6,9 @@ from panoptica.panoptica_evaluator import Panoptica_Evaluator from panoptica.panoptica_result import PanopticaResult from panoptica.utils.processing_pair import ( + InputType, SemanticPair, UnmatchedInstancePair, MatchedInstancePair, - InputType, ) from panoptica.metrics import Metric, MetricMode, MetricType diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index 3afc165..a0ca8ba 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -1,6 +1,6 @@ from ruamel.yaml import YAML from pathlib import Path -from panoptica.utils.filepath import config_by_name +from panoptica.utils.filepath import config_by_name, config_dir_by_name from abc import ABC, abstractmethod supported_helper_classes = [] @@ -130,6 +130,11 @@ def _save_to_config(obj, path: str | Path): Configuration.save_from_object(obj, path) +def _save_to_config_by_name(obj, name: str): + dir, name = config_dir_by_name(name) + _save_to_config(obj, dir.joinpath(name)) + + class SupportsConfig: """Metaclass that allows a class to save and load objects by yaml configs""" @@ -160,6 +165,9 @@ def load_from_config_name(cls, name: str): def save_to_config(self, path: str | Path): _save_to_config(self, path) + def save_to_config_by_name(self, name: str): + _save_to_config_by_name(self, name) + @classmethod def to_yaml(cls, representer, node): # cls._register_permanently() diff --git a/panoptica/utils/filepath.py b/panoptica/utils/filepath.py index f668b75..70467c0 100644 --- a/panoptica/utils/filepath.py +++ b/panoptica/utils/filepath.py @@ -26,11 +26,16 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres return paths -# Find config path -def config_by_name(name: str) -> Path: +def config_dir_by_name(name: str) -> tuple[Path, str]: directory = Path(__file__.replace("////", "/").replace("\\\\", "/").replace("//", "/").replace("\\", "/")).parent.parent if not name.endswith(".yaml"): name += ".yaml" + return directory, name + + +# Find config path +def config_by_name(name: str) -> Path: + directory, name = config_dir_by_name(name) p = search_path(directory, query=f"**/{name}", suppress=True) assert len(p) == 1, f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" return p[0] From a2f5dd361974e4672971bd70874b8929b964b33f Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 5 Aug 2024 12:12:00 +0000 Subject: [PATCH 11/13] added save printout --- panoptica/utils/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index 1ad8994..241bc0a 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -40,6 +40,7 @@ def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class= # yaml.dump([registered_class(*data_dict)], out_file) else: yaml.dump(data_dict, out_file) + print(f"Saved config into {out_file}") #################### From 9ddbc2de5a0a53a28fe6063d0a014aba543dc2f5 Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 5 Aug 2024 12:18:20 +0000 Subject: [PATCH 12/13] updated readme with configs notebook example --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index a70f613..ade9ec4 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,16 @@ For this case, the matcher module can be utilized to match instances and the eva If your predicted instances already match the reference instances, you can directly compute metrics using the evaluator module. + +### Using Configs (saving and loading) + +You can construct Panoptica_Evaluator (among many others) objects and save their arguments, so you can save project-specific configurations and use them later. + +[Jupyter notebook tutorial](https://github.com/BrainLesion/tutorials/tree/main/panoptica/example_config.ipynb) + +It uses ruamel.yaml in a readable way. + + ## Citation If you use panoptica in your research, please cite it to support the development! From 52cfe7f5c61d079099a00b059fe0349bcef6532a Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 12:18:53 +0000 Subject: [PATCH 13/13] Autoformat with black --- panoptica/utils/filepath.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/panoptica/utils/filepath.py b/panoptica/utils/filepath.py index 70467c0..68ce65b 100644 --- a/panoptica/utils/filepath.py +++ b/panoptica/utils/filepath.py @@ -4,7 +4,9 @@ from pathlib import Path -def search_path(basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False) -> list[Path]: +def search_path( + basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False +) -> list[Path]: """Searches from basepath with query Args: basepath: ground path to look into @@ -16,7 +18,9 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres All found paths """ basepath = str(basepath) - assert os.path.exists(basepath), f"basepath for search_path() doesnt exist, got {basepath}" + assert os.path.exists( + basepath + ), f"basepath for search_path() doesnt exist, got {basepath}" if not basepath.endswith("/"): basepath += "/" print(f"search_path: in {basepath}{query}") if verbose else None @@ -27,7 +31,12 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres def config_dir_by_name(name: str) -> tuple[Path, str]: - directory = Path(__file__.replace("////", "/").replace("\\\\", "/").replace("//", "/").replace("\\", "/")).parent.parent + directory = Path( + __file__.replace("////", "/") + .replace("\\\\", "/") + .replace("//", "/") + .replace("\\", "/") + ).parent.parent if not name.endswith(".yaml"): name += ".yaml" return directory, name @@ -37,5 +46,7 @@ def config_dir_by_name(name: str) -> tuple[Path, str]: def config_by_name(name: str) -> Path: directory, name = config_dir_by_name(name) p = search_path(directory, query=f"**/{name}", suppress=True) - assert len(p) == 1, f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" + assert ( + len(p) == 1 + ), f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" return p[0]