diff --git a/README.md b/README.md index a70f613..ade9ec4 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,16 @@ For this case, the matcher module can be utilized to match instances and the eva If your predicted instances already match the reference instances, you can directly compute metrics using the evaluator module. + +### Using Configs (saving and loading) + +You can construct Panoptica_Evaluator (among many others) objects and save their arguments, so you can save project-specific configurations and use them later. + +[Jupyter notebook tutorial](https://github.com/BrainLesion/tutorials/tree/main/panoptica/example_config.ipynb) + +It uses ruamel.yaml in a readable way. + + ## Citation If you use panoptica in your research, please cite it to support the development! diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index f10cb34..bcea628 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -3,20 +3,17 @@ from auxiliary.nifti.io import read_nifti from auxiliary.turbopath import turbopath -from panoptica import MatchedInstancePair, Panoptica_Evaluator +from panoptica import Panoptica_Evaluator, InputType from panoptica.metrics import Metric from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups directory = turbopath(__file__).parent -ref_masks = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") - -pred_masks = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") - -sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks) +reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") +prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") evaluator = Panoptica_Evaluator( - expected_input=MatchedInstancePair, + expected_input=InputType.MATCHED_INSTANCE, eval_metrics=[Metric.DSC, Metric.IOU], segmentation_class_groups=SegmentationClassGroups( { @@ -28,12 +25,13 @@ ), decision_metric=Metric.DSC, decision_threshold=0.5, + log_times=True, ) with cProfile.Profile() as pr: if __name__ == "__main__": - results = evaluator.evaluate(sample, verbose=False) + results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False) for groupname, (result, debug) in results.items(): print() print("### Group", groupname) diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py new file mode 100644 index 0000000..56c61e4 --- /dev/null +++ b/examples/example_spine_instance_config.py @@ -0,0 +1,26 @@ +import cProfile + +from auxiliary.nifti.io import read_nifti +from auxiliary.turbopath import turbopath + +from panoptica import Panoptica_Evaluator + +directory = turbopath(__file__).parent + +reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz") +prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz") + +evaluator = Panoptica_Evaluator.load_from_config_name( + "panoptica_evaluator_unmatched_instance" +) + + +with cProfile.Profile() as pr: + if __name__ == "__main__": + results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False) + for groupname, (result, debug) in results.items(): + print() + print("### Group", groupname) + print(result) + + pr.dump_stats(directory + "/instance_example.log") diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 7385701..a2e5a32 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -7,18 +7,17 @@ ConnectedComponentsInstanceApproximator, NaiveThresholdMatching, Panoptica_Evaluator, - SemanticPair, + InputType, ) directory = turbopath(__file__).parent -ref_masks = read_nifti(directory + "/spine_seg/semantic/ref.nii.gz") -pred_masks = read_nifti(directory + "/spine_seg/semantic/pred.nii.gz") +reference_mask = read_nifti(directory + "/spine_seg/semantic/ref.nii.gz") +prediction_mask = read_nifti(directory + "/spine_seg/semantic/pred.nii.gz") -sample = SemanticPair(pred_masks, ref_masks) evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), verbose=True, @@ -26,7 +25,9 @@ with cProfile.Profile() as pr: if __name__ == "__main__": - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[ + "ungrouped" + ] print(result) pr.dump_stats(directory + "/semantic_example.log") diff --git a/panoptica/__init__.py b/panoptica/__init__.py index 9a4d080..2f02659 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -6,6 +6,7 @@ from panoptica.panoptica_evaluator import Panoptica_Evaluator from panoptica.panoptica_result import PanopticaResult from panoptica.utils.processing_pair import ( + InputType, SemanticPair, UnmatchedInstancePair, MatchedInstancePair, diff --git a/panoptica/configs/SegmentationClassGroups_example_unmatchedinstancepair.yaml b/panoptica/configs/SegmentationClassGroups_example_unmatchedinstancepair.yaml new file mode 100644 index 0000000..9c31720 --- /dev/null +++ b/panoptica/configs/SegmentationClassGroups_example_unmatchedinstancepair.yaml @@ -0,0 +1,14 @@ +!SegmentationClassGroups +groups: + endplate: !LabelGroup + single_instance: false + value_labels: [201, 202, 203, 204, 205, 206, 207, 208, 209, 210] + ivd: !LabelGroup + single_instance: false + value_labels: [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] + sacrum: !LabelGroup + single_instance: true + value_labels: [26] + vertebra: !LabelGroup + single_instance: false + value_labels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] diff --git a/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml new file mode 100644 index 0000000..ad8d9b0 --- /dev/null +++ b/panoptica/configs/panoptica_evaluator_unmatched_instance.yaml @@ -0,0 +1,42 @@ +!Panoptica_Evaluator +decision_metric: !Metric DSC +decision_threshold: 0.5 +edge_case_handler: !EdgeCaseHandler + empty_list_std: !EdgeCaseResult NAN + listmetric_zeroTP_handling: + !Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO, + empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult ZERO} + !Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF, + empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult INF} + !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN, + empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN, + normal: !EdgeCaseResult NAN} +eval_metrics: [!Metric DSC, !Metric IOU] +expected_input: !InputType UNMATCHED_INSTANCE +instance_approximator: null +instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU, + matching_threshold: 0.5} +log_times: true +segmentation_class_groups: !SegmentationClassGroups + groups: + endplate: !LabelGroup + single_instance: false + value_labels: [201, 202, 203, 204, 205, 206, 207, 208, 209, 210] + ivd: !LabelGroup + single_instance: false + value_labels: [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] + sacrum: !LabelGroup + single_instance: true + value_labels: [26] + vertebra: !LabelGroup + single_instance: false + value_labels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +verbose: false diff --git a/panoptica/instance_approximator.py b/panoptica/instance_approximator.py index e73f843..f8e061f 100644 --- a/panoptica/instance_approximator.py +++ b/panoptica/instance_approximator.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, ABCMeta import numpy as np @@ -10,9 +10,10 @@ SemanticPair, UnmatchedInstancePair, ) +from panoptica.utils.config import SupportsConfig -class InstanceApproximator(ABC): +class InstanceApproximator(SupportsConfig, metaclass=ABCMeta): """ Abstract base class for instance approximation algorithms in panoptic segmentation evaluation. @@ -56,6 +57,12 @@ def _approximate_instances( """ pass + def _yaml_repr(cls, node) -> dict: + raise NotImplementedError( + f"Tried to get yaml representation of abstract class {cls.__name__}" + ) + return {} + def approximate_instances( self, semantic_pair: SemanticPair, verbose: bool = False, **kwargs ) -> UnmatchedInstancePair | MatchedInstancePair: @@ -140,7 +147,7 @@ def _approximate_instances( UnmatchedInstancePair: The result of the instance approximation. """ cca_backend = self.cca_backend - if self.cca_backend is None: + if cca_backend is None: cca_backend = ( CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy ) @@ -164,3 +171,7 @@ def _approximate_instances( n_prediction_instance=n_prediction_instance, n_reference_instance=n_reference_instance, ) + + @classmethod + def _yaml_repr(cls, node) -> dict: + return {"cca_backend": node.cca_backend} diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 3b32e63..5bedd9e 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import ABCMeta, abstractmethod import numpy as np @@ -8,13 +8,14 @@ ) from panoptica.metrics import Metric from panoptica.utils.processing_pair import ( - InstanceLabelMap, MatchedInstancePair, UnmatchedInstancePair, ) +from panoptica.utils.instancelabelmap import InstanceLabelMap +from panoptica.utils.config import SupportsConfig -class InstanceMatchingAlgorithm(ABC): +class InstanceMatchingAlgorithm(SupportsConfig, metaclass=ABCMeta): """ Abstract base class for instance matching algorithms in panoptic segmentation evaluation. @@ -79,6 +80,12 @@ def match_instances( # print("instance_labelmap:", instance_labelmap) return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap) + def _yaml_repr(cls, node) -> dict: + raise NotImplementedError( + f"Tried to get yaml representation of abstract class {cls.__name__}" + ) + return {} + def map_instance_labels( processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap @@ -166,9 +173,9 @@ def __init__( Raises: AssertionError: If the specified IoU threshold is not within the valid range. """ - self.allow_many_to_one = allow_many_to_one - self.matching_metric = matching_metric - self.matching_threshold = matching_threshold + self._allow_many_to_one = allow_many_to_one + self._matching_metric = matching_metric + self._matching_threshold = matching_threshold def _match_instances( self, @@ -195,24 +202,32 @@ def _match_instances( unmatched_instance_pair.reference_arr, ) mm_pairs = _calc_matching_metric_of_overlapping_labels( - pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric + pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric ) # Loop through matched instances to compute PQ components for matching_score, (ref_label, pred_label) in mm_pairs: if ( labelmap.contains_or(pred_label, ref_label) - and not self.allow_many_to_one + and not self._allow_many_to_one ): continue # -> doesnt make speed difference - if self.matching_metric.score_beats_threshold( - matching_score, self.matching_threshold + if self._matching_metric.score_beats_threshold( + matching_score, self._matching_threshold ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) # map label ref_idx to pred_idx return labelmap + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "matching_metric": node._matching_metric, + "matching_threshold": node._matching_threshold, + "allow_many_to_one": node._allow_many_to_one, + } + class MaximizeMergeMatching(InstanceMatchingAlgorithm): """ @@ -241,8 +256,8 @@ def __init__( Raises: AssertionError: If the specified IoU threshold is not within the valid range. """ - self.matching_metric = matching_metric - self.matching_threshold = matching_threshold + self._matching_metric = matching_metric + self._matching_threshold = matching_threshold def _match_instances( self, @@ -274,7 +289,7 @@ def _match_instances( prediction_arr=pred_arr, reference_arr=ref_arr, ref_labels=ref_labels, - matching_metric=self.matching_metric, + matching_metric=self._matching_metric, ) # Loop through matched instances to compute PQ components @@ -290,8 +305,8 @@ def _match_instances( if new_score > score_ref[ref_label]: labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = new_score - elif self.matching_metric.score_beats_threshold( - matching_score, self.matching_threshold + elif self._matching_metric.score_beats_threshold( + matching_score, self._matching_threshold ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) @@ -307,7 +322,7 @@ def new_combination_score( unmatched_instance_pair: UnmatchedInstancePair, ): pred_labels.append(new_pred_label) - score = self.matching_metric( + score = self._matching_metric( unmatched_instance_pair.reference_arr, prediction_arr=unmatched_instance_pair.prediction_arr, ref_instance_idx=ref_label, @@ -315,6 +330,13 @@ def new_combination_score( ) return score + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "matching_metric": node._matching_metric, + "matching_threshold": node._matching_threshold, + } + class MatchUntilConvergenceMatching(InstanceMatchingAlgorithm): # Match like the naive matcher (so each to their best reference) and then again and again until no overlapping labels are left diff --git a/panoptica/metrics/relative_volume_difference.py b/panoptica/metrics/relative_volume_difference.py index 4d952b2..1cde13e 100644 --- a/panoptica/metrics/relative_volume_difference.py +++ b/panoptica/metrics/relative_volume_difference.py @@ -55,16 +55,15 @@ def _compute_relative_volume_difference( prediction (np.ndarray): Prediction binary mask. Returns: - float: Relative volume Error between the two binary masks. A value between 0 and 1, where higher values - indicate better overlap and similarity between masks. + float: Relative volume Error between the two binary masks. A value of zero means perfect volume match, while >0 means oversegmentation and <0 undersegmentation. """ - reference_mask = np.sum(reference) - prediction_mask = np.sum(prediction) + reference_mask = float(np.sum(reference)) + prediction_mask = float(np.sum(prediction)) # Handle division by zero if reference_mask == 0 and prediction_mask == 0: return 0.0 # Calculate Dice coefficient - rvd = float(prediction_mask - reference_mask) / reference_mask + rvd = (prediction_mask - reference_mask) / reference_mask return rvd diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index ca59b98..fd5617c 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -14,18 +14,18 @@ SemanticPair, UnmatchedInstancePair, _ProcessingPair, + InputType, ) +import numpy as np +from panoptica.utils.config import SupportsConfig from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup -class Panoptica_Evaluator: +class Panoptica_Evaluator(SupportsConfig): def __init__( self, - # TODO let users give prediction and reference arr instead of the processing pair, so let this create the processing pair itself - expected_input: ( - Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] - ) = MatchedInstancePair, + expected_input: InputType = InputType.MATCHED_INSTANCE, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, @@ -36,10 +36,10 @@ def __init__( log_times: bool = False, verbose: bool = False, ) -> None: - """Creates a Panoptic_Evaluator, that saves some parameters to be used for all subsequent evaluations + """Creates a Panoptica_Evaluator, that saves some parameters to be used for all subsequent evaluations Args: - expected_input (type, optional): Expected DataPair Input. Defaults to type(MatchedInstancePair). + expected_input (type, optional): Expected DataPair Input Type. Defaults to InputType.MATCHED_INSTANCE (which is type(MatchedInstancePair)). instance_approximator (InstanceApproximator | None, optional): Determines which instance approximator is used if necessary. Defaults to None. instance_matcher (InstanceMatchingAlgorithm | None, optional): Determines which instance matching algorithm is used if necessary. Defaults to None. iou_threshold (float, optional): Iou Threshold for evaluation. Defaults to 0.5. @@ -65,18 +65,33 @@ def __init__( self.__log_times = log_times self.__verbose = verbose + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "expected_input": node.__expected_input, + "instance_approximator": node.__instance_approximator, + "instance_matcher": node.__instance_matcher, + "edge_case_handler": node.__edge_case_handler, + "segmentation_class_groups": node.__segmentation_class_groups, + "eval_metrics": node.__eval_metrics, + "decision_metric": node.__decision_metric, + "decision_threshold": node.__decision_threshold, + "log_times": node.__log_times, + "verbose": node.__verbose, + } + @citation_reminder @measure_time def evaluate( self, - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult - ), + prediction_arr: np.ndarray, + reference_arr: np.ndarray, result_all: bool = True, verbose: bool | None = None, ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]: - assert ( - type(processing_pair) == self.__expected_input + processing_pair = self.__expected_input(prediction_arr, reference_arr) + assert isinstance( + processing_pair, self.__expected_input.value ), f"input not of expected type {self.__expected_input}" if self.__segmentation_class_groups is None: diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index c037e85..b831e04 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -43,7 +43,7 @@ def __init__( edge_case_handler (EdgeCaseHandler): EdgeCaseHandler object that handles various forms of edge cases """ self._edge_case_handler = edge_case_handler - empty_list_std = self._edge_case_handler.handle_empty_list_std() + empty_list_std = self._edge_case_handler.handle_empty_list_std().value self._prediction_arr = prediction_arr self._reference_arr = reference_arr ###################### diff --git a/panoptica/utils/__init__.py b/panoptica/utils/__init__.py index d4dfe79..0b72f2a 100644 --- a/panoptica/utils/__init__.py +++ b/panoptica/utils/__init__.py @@ -3,11 +3,11 @@ _unique_without_zeros, ) from panoptica.utils.processing_pair import ( - InstanceLabelMap, MatchedInstancePair, SemanticPair, UnmatchedInstancePair, ) +from panoptica.utils.instancelabelmap import InstanceLabelMap from panoptica.utils.edge_case_handling import ( EdgeCaseHandler, EdgeCaseResult, diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py new file mode 100644 index 0000000..241bc0a --- /dev/null +++ b/panoptica/utils/config.py @@ -0,0 +1,195 @@ +from ruamel.yaml import YAML +from pathlib import Path +from panoptica.utils.filepath import config_by_name, config_dir_by_name +from abc import ABC, abstractmethod + +supported_helper_classes = [] + + +def _register_helper_classes(yaml: YAML): + [yaml.register_class(s) for s in supported_helper_classes] + + +def _load_yaml(file: str | Path, registered_class=None): + if isinstance(file, str): + file = Path(file) + yaml = YAML(typ="safe") + _register_helper_classes(yaml) + if registered_class is not None: + yaml.register_class(registered_class) + yaml.default_flow_style = None + data = yaml.load(file) + assert isinstance(data, dict) or isinstance(data, object) + return data + + +def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class=None): + if isinstance(out_file, str): + out_file = Path(out_file) + + yaml = YAML(typ="safe") + yaml.default_flow_style = None + yaml.representer.ignore_aliases = lambda *data: True + _register_helper_classes(yaml) + if registered_class is not None: + yaml.register_class(registered_class) + assert isinstance(data_dict, registered_class) + # if isinstance(data_dict, object): + yaml.dump(data_dict, out_file) + # else: + # yaml.dump([registered_class(*data_dict)], out_file) + else: + yaml.dump(data_dict, out_file) + print(f"Saved config into {out_file}") + + +#################### +# TODO Merge into SupportsConfig +class Configuration: + """General Configuration class that handles yaml""" + + _data_dict: dict + _registered_class = None + + def __init__(self, data_dict: dict, registered_class=None) -> None: + assert isinstance(data_dict, dict) + self._data_dict = data_dict + if registered_class is not None: + self.register_to_class(registered_class) + + def register_to_class(self, cls): + global supported_helper_classes + if cls not in supported_helper_classes: + supported_helper_classes.append(cls) + self._registered_class = cls + return self + + @classmethod + def save_from_object(cls, obj: object, file: str | Path): + _save_yaml(obj, file, registered_class=type(obj)) + # return Configuration.load(file, registered_class=type(obj)) + + @classmethod + def load(cls, file: str | Path, registered_class=None): + data = _load_yaml(file, registered_class) + assert isinstance( + data, dict + ), f"The config at {file} is registered to a class. Use load_as_object() instead" + return Configuration(data, registered_class=registered_class) + + @classmethod + def load_as_object(cls, file: str | Path, registered_class=None): + data = _load_yaml(file, registered_class) + assert not isinstance( + data, dict + ), f"The config at {file} is not registered to a class. Use load() instead" + return data + + def save(self, out_file: str | Path): + _save_yaml(self._data_dict, out_file) + + def cls_object_from_this(self): + assert self._registered_class is not None + return self._registered_class(**self._data_dict) + + @property + def data_dict(self): + return self._data_dict + + @property + def cls(self): + return self._registered_class + + def __str__(self) -> str: + return f"Config({self.cls.__name__ if self.cls is not None else 'NoClass'} = {self.data_dict})" # type: ignore + + +######### +# Universal Functions +######### +def _register_class_to_yaml(cls): + global supported_helper_classes + if cls not in supported_helper_classes: + supported_helper_classes.append(cls) + + +def _load_from_config(cls, path: str | Path): + # cls._register_permanently() + if isinstance(path, str): + path = Path(path) + assert path.exists(), f"load_from_config: {path} does not exist" + obj = Configuration.load_as_object(path, registered_class=cls) + assert isinstance(obj, cls), f"Loaded config was not for class {cls.__name__}" + return obj + + +def _load_from_config_name(cls, name: str): + path = config_by_name(name) + assert path.exists(), f"load_from_config: {path} does not exist" + return _load_from_config(cls, path) + + +def _save_to_config(obj, path: str | Path): + if isinstance(path, str): + path = Path(path) + Configuration.save_from_object(obj, path) + + +def _save_to_config_by_name(obj, name: str): + dir, name = config_dir_by_name(name) + _save_to_config(obj, dir.joinpath(name)) + + +class SupportsConfig: + """Metaclass that allows a class to save and load objects by yaml configs""" + + def __init__(self) -> None: + raise NotImplementedError(f"Tried to instantiate abstract class {type(self)}") + + def __init_subclass__(cls, **kwargs): + # Registers all subclasses of this + super().__init_subclass__(**kwargs) + cls._register_permanently() + + @classmethod + def _register_permanently(cls): + _register_class_to_yaml(cls) + + @classmethod + def load_from_config(cls, path: str | Path): + obj = _load_from_config(cls, path) + assert isinstance( + obj, cls + ), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" + return obj + + @classmethod + def load_from_config_name(cls, name: str): + obj = _load_from_config_name(cls, name) + assert isinstance(obj, cls) + return obj + + def save_to_config(self, path: str | Path): + _save_to_config(self, path) + + def save_to_config_by_name(self, name: str): + _save_to_config_by_name(self, name) + + @classmethod + def to_yaml(cls, representer, node): + # cls._register_permanently() + assert hasattr( + cls, "_yaml_repr" + ), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" + return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node)) + + @classmethod + def from_yaml(cls, constructor, node): + # cls._register_permanently() + data = constructor.construct_mapping(node, deep=True) + return cls(**data) + + @classmethod + @abstractmethod + def _yaml_repr(cls, node) -> dict: + pass # return {"groups": node.__group_dictionary} diff --git a/panoptica/utils/constants.py b/panoptica/utils/constants.py index d4a1faa..ecb7f63 100644 --- a/panoptica/utils/constants.py +++ b/panoptica/utils/constants.py @@ -1,10 +1,30 @@ from enum import Enum, auto +from panoptica.utils.config import ( + _register_class_to_yaml, + _load_from_config, + _load_from_config_name, + _save_to_config, +) +from pathlib import Path +import numpy as np class _Enum_Compare(Enum): def __eq__(self, __value: object) -> bool: if isinstance(__value, Enum): - return self.name == __value.name and self.value == __value.value + namecheck = self.name == __value.name + if not namecheck: + return False + if self.value is None: + return __value.value is None + + try: + if np.isnan(self.value): + return np.isnan(__value.value) + except Exception: + pass + + return self.value == __value.value elif isinstance(__value, str): return self.name == __value else: @@ -16,6 +36,36 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) + def __init_subclass__(cls, **kwargs): + # Registers all subclasses of this + super().__init_subclass__(**kwargs) + cls._register_permanently() + + @classmethod + def _register_permanently(cls): + _register_class_to_yaml(cls) + + @classmethod + def load_from_config(cls, path: str | Path): + return _load_from_config(cls, path) + + @classmethod + def load_from_config_name(cls, name: str): + return _load_from_config_name(cls, name) + + def save_to_config(self, path: str | Path): + _save_to_config(self, path) + + @classmethod + def to_yaml(cls, representer, node): + # cls._register_permanently() + # assert hasattr(cls, "_yaml_repr"), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" + return representer.represent_scalar("!" + cls.__name__, str(node.name)) + + @classmethod + def from_yaml(cls, constructor, node): + return cls[node.value] + class CCABackend(_Enum_Compare): """ diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index e0ecc7e..1a865f0 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -1,8 +1,7 @@ import numpy as np -from typing import TYPE_CHECKING - from panoptica.metrics import Metric from panoptica.utils.constants import _Enum_Compare, auto +from panoptica.utils.config import SupportsConfig class EdgeCaseResult(_Enum_Compare): @@ -23,30 +22,39 @@ def __hash__(self) -> int: return self.value -class MetricZeroTPEdgeCaseHandling(object): +class MetricZeroTPEdgeCaseHandling(SupportsConfig): + def __init__( self, - default_result: EdgeCaseResult, + default_result: EdgeCaseResult | None = None, no_instances_result: EdgeCaseResult | None = None, empty_prediction_result: EdgeCaseResult | None = None, empty_reference_result: EdgeCaseResult | None = None, normal: EdgeCaseResult | None = None, ) -> None: - self.edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self.edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( + assert default_result is not None or ( + no_instances_result is not None + and empty_prediction_result is not None + and empty_reference_result is not None + and normal is not None + ), "default_result is None and the rest is not fully specified" + + self._default_result = default_result + self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( empty_prediction_result if empty_prediction_result is not None else default_result ) - self.edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( empty_reference_result if empty_reference_result is not None else default_result ) - self.edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( no_instances_result if no_instances_result is not None else default_result ) - self.edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( normal if normal is not None else default_result ) @@ -57,27 +65,44 @@ def __call__( return False, EdgeCaseResult.NONE.value # elif num_pred_instances + num_ref_instances == 0: - return True, self.edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES].value + return True, self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES].value elif num_ref_instances == 0: - return True, self.edgecase_dict[EdgeCaseZeroTP.EMPTY_REF].value + return True, self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF].value elif num_pred_instances == 0: - return True, self.edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED].value + return True, self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED].value elif num_pred_instances > 0 and num_ref_instances > 0: - return True, self.edgecase_dict[EdgeCaseZeroTP.NORMAL].value + return True, self._edgecase_dict[EdgeCaseZeroTP.NORMAL].value raise NotImplementedError( f"MetricZeroTPEdgeCaseHandling: couldn't handle case, got tp {tp}, n_pred_instances {num_pred_instances}, n_ref_instances {num_ref_instances}" ) + def __eq__(self, __value: object) -> bool: + if isinstance(__value, MetricZeroTPEdgeCaseHandling): + for s, k in self._edgecase_dict.items(): + if s not in __value._edgecase_dict or k != __value._edgecase_dict[s]: + return False + return True + return False + def __str__(self) -> str: txt = "" - for k, v in self.edgecase_dict.items(): + for k, v in self._edgecase_dict.items(): if v is not None: txt += str(k) + ": " + str(v) + "\n" return txt + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "no_instances_result": node._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES], + "empty_prediction_result": node._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED], + "empty_reference_result": node._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF], + "normal": node._edgecase_dict[EdgeCaseZeroTP.NORMAL], + } -class EdgeCaseHandler: + +class EdgeCaseHandler(SupportsConfig): def __init__( self, @@ -131,11 +156,15 @@ def handle_zero_tp( num_ref_instances=num_ref_instances, ) + @property + def listmetric_zeroTP_handling(self): + return self.__listmetric_zeroTP_handling + def get_metric_zero_tp_handle(self, metric: Metric): return self.__listmetric_zeroTP_handling[metric] - def handle_empty_list_std(self) -> float | None: - return self.__empty_list_std.value + def handle_empty_list_std(self) -> EdgeCaseResult | None: + return self.__empty_list_std def __str__(self) -> str: txt = f"EdgeCaseHandler:\n - Standard Deviation of Empty = {self.__empty_list_std}" @@ -143,6 +172,13 @@ def __str__(self) -> str: txt += f"\n- {k}: {str(v)}" return str(txt) + @classmethod + def _yaml_repr(cls, node) -> dict: + return { + "listmetric_zeroTP_handling": node.__listmetric_zeroTP_handling, + "empty_list_std": node.__empty_list_std, + } + if __name__ == "__main__": handler = EdgeCaseHandler() diff --git a/panoptica/utils/filepath.py b/panoptica/utils/filepath.py new file mode 100644 index 0000000..68ce65b --- /dev/null +++ b/panoptica/utils/filepath.py @@ -0,0 +1,52 @@ +import os +import warnings +from itertools import chain +from pathlib import Path + + +def search_path( + basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False +) -> list[Path]: + """Searches from basepath with query + Args: + basepath: ground path to look into + query: search query, can contain wildcards like *.npz or **/*.npz + verbose: + suppress: if true, will not throwing warnings if nothing is found + + Returns: + All found paths + """ + basepath = str(basepath) + assert os.path.exists( + basepath + ), f"basepath for search_path() doesnt exist, got {basepath}" + if not basepath.endswith("/"): + basepath += "/" + print(f"search_path: in {basepath}{query}") if verbose else None + paths = sorted(list(chain(list(Path(f"{basepath}").glob(f"{query}"))))) + if len(paths) == 0 and not suppress: + warnings.warn(f"did not find any paths in {basepath}{query}", UserWarning) + return paths + + +def config_dir_by_name(name: str) -> tuple[Path, str]: + directory = Path( + __file__.replace("////", "/") + .replace("\\\\", "/") + .replace("//", "/") + .replace("\\", "/") + ).parent.parent + if not name.endswith(".yaml"): + name += ".yaml" + return directory, name + + +# Find config path +def config_by_name(name: str) -> Path: + directory, name = config_dir_by_name(name) + p = search_path(directory, query=f"**/{name}", suppress=True) + assert ( + len(p) == 1 + ), f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}" + return p[0] diff --git a/panoptica/utils/instancelabelmap.py b/panoptica/utils/instancelabelmap.py new file mode 100644 index 0000000..16fd33c --- /dev/null +++ b/panoptica/utils/instancelabelmap.py @@ -0,0 +1,72 @@ +import numpy as np + + +# Many-to-One Mapping +class InstanceLabelMap(object): + # Mapping ((prediction_label, ...), (reference_label, ...)) + labelmap: dict[int, int] + + def __init__(self) -> None: + self.labelmap = {} + + def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): + if not isinstance(pred_labels, list): + pred_labels = [pred_labels] + assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" + assert np.all( + [isinstance(r, int) for r in pred_labels] + ), "add_labelmap_entry: got no int as pred_label" + for p in pred_labels: + if p in self.labelmap and self.labelmap[p] != ref_label: + raise Exception( + f"You are mapping a prediction label to a reference label that was already assigned differently, got {self.__str__} and you tried {pred_labels}, {ref_label}" + ) + self.labelmap[p] = ref_label + + def get_pred_labels_matched_to_ref(self, ref_label: int): + return [k for k, v in self.labelmap.items() if v == ref_label] + + def contains_pred(self, pred_label: int): + return pred_label in self.labelmap + + def contains_ref(self, ref_label: int): + return ref_label in self.labelmap.values() + + def contains_and( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: + pred_in = True if pred_label is None else pred_label in self.labelmap + ref_in = True if ref_label is None else ref_label in self.labelmap.values() + return pred_in and ref_in + + def contains_or( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: + pred_in = True if pred_label is None else pred_label in self.labelmap + ref_in = True if ref_label is None else ref_label in self.labelmap.values() + return pred_in or ref_in + + def get_one_to_one_dictionary(self): + return self.labelmap + + def __str__(self) -> str: + return str( + list( + [ + str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + + " -> " + + str(v) + for v in set(self.labelmap.values()) + ] + ) + ) + + def __repr__(self) -> str: + return str(self) + + # Make all variables read-only! + def __setattr__(self, attr, value): + if hasattr(self, attr): + raise Exception("Attempting to alter read-only value") + + self.__dict__[attr] = value diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py new file mode 100644 index 0000000..430e2dc --- /dev/null +++ b/panoptica/utils/label_group.py @@ -0,0 +1,86 @@ +import numpy as np +from panoptica.utils.config import SupportsConfig + +# + + +class LabelGroup(SupportsConfig): + """Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other""" + + def __init__( + self, + value_labels: list[int] | int, + single_instance: bool = False, + ) -> None: + """Defines a group of labels that semantically belong to each other + + Args: + value_labels (list[int]): Actually labels in the prediction and reference mask in this group. Defines the labels that can be matched to each other + single_instance (bool, optional): If true, will not use the matching_threshold as there is only one instance (large organ, ...). Defaults to False. + """ + if isinstance(value_labels, int): + value_labels = [value_labels] + assert ( + len(value_labels) >= 1 + ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" + self.__value_labels = value_labels + assert np.all( + [v > 0 for v in self.__value_labels] + ), f"Given value labels are not >0, got {value_labels}" + self.__single_instance = single_instance + if self.__single_instance: + assert ( + len(value_labels) == 1 + ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" + + LabelGroup._register_permanently() + + @property + def value_labels(self) -> list[int]: + return self.__value_labels + + @property + def single_instance(self) -> bool: + return self.__single_instance + + def __call__( + self, + array: np.ndarray, + set_to_binary: bool = False, + ) -> np.ndarray: + """Extracts the labels of this class + + Args: + array (np.ndarray): Array to extract the segmentation group labels from + set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. + + Returns: + np.ndarray: Array containing only the labels of this segmentation group + """ + array = array.copy() + array[np.isin(array, self.value_labels, invert=True)] = 0 + if set_to_binary: + array[array != 0] = 1 + return array + + def __str__(self) -> str: + return f"LabelGroup {self.value_labels}, single_instance={self.single_instance}" + + def __repr__(self) -> str: + return str(self) + + @classmethod + def _yaml_repr(cls, node): + return { + "value_labels": node.value_labels, + "single_instance": node.single_instance, + } + + # @classmethod + # def to_yaml(cls, representer, node): + # return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node)) + + # @classmethod + # def from_yaml(cls, constructor, node): + # data = constructor.construct_mapping(node, deep=True) + # return cls(**data) diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 5ed1a7c..c64b1e9 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -4,6 +4,7 @@ from panoptica._functionals import _get_paired_crop from panoptica.utils import _count_unique_without_zeros, _unique_without_zeros +from panoptica.utils.constants import _Enum_Compare uint_type: type = np.unsignedinteger int_type: type = np.integer @@ -314,83 +315,12 @@ def copy(self): ) -# Many-to-One Mapping -class InstanceLabelMap(object): - # Mapping ((prediction_label, ...), (reference_label, ...)) - labelmap: dict[int, int] - - def __init__(self) -> None: - self.labelmap = {} - - def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): - if not isinstance(pred_labels, list): - pred_labels = [pred_labels] - assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" - assert np.all( - [isinstance(r, int) for r in pred_labels] - ), "add_labelmap_entry: got no int as pred_label" - for p in pred_labels: - if p in self.labelmap and self.labelmap[p] != ref_label: - raise Exception( - f"You are mapping a prediction label to a reference label that was already assigned differently, got {self.__str__} and you tried {pred_labels}, {ref_label}" - ) - self.labelmap[p] = ref_label - - def get_pred_labels_matched_to_ref(self, ref_label: int): - return [k for k, v in self.labelmap.items() if v == ref_label] - - def contains_pred(self, pred_label: int): - return pred_label in self.labelmap - - def contains_ref(self, ref_label: int): - return ref_label in self.labelmap.values() - - def contains_and( - self, pred_label: int | None = None, ref_label: int | None = None - ) -> bool: - pred_in = True if pred_label is None else pred_label in self.labelmap - ref_in = True if ref_label is None else ref_label in self.labelmap.values() - return pred_in and ref_in - - def contains_or( - self, pred_label: int | None = None, ref_label: int | None = None - ) -> bool: - pred_in = True if pred_label is None else pred_label in self.labelmap - ref_in = True if ref_label is None else ref_label in self.labelmap.values() - return pred_in or ref_in - - def get_one_to_one_dictionary(self): - return self.labelmap - - def __str__(self) -> str: - return str( - list( - [ - str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) - + " -> " - + str(v) - for v in set(self.labelmap.values()) - ] - ) - ) - - def __repr__(self) -> str: - return str(self) - - # Make all variables read-only! - def __setattr__(self, attr, value): - if hasattr(self, attr): - raise Exception("Attempting to alter read-only value") - - self.__dict__[attr] = value - - -if __name__ == "__main__": - n = np.zeros([50, 50], dtype=np.int32) - a = SemanticPair(n, n) - print(a) - # print(a.prediction_arr) +class InputType(_Enum_Compare): + SEMANTIC = SemanticPair + UNMATCHED_INSTANCE = UnmatchedInstancePair + MATCHED_INSTANCE = MatchedInstancePair - map = InstanceLabelMap() - map.labelmap = {2: 3, 3: 3, 4: 6} - print(map) + def __call__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray + ) -> _ProcessingPair: + return self.value(prediction_arr, reference_arr) diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 4535ad3..550b6aa 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -1,71 +1,11 @@ import numpy as np +from pathlib import Path +from panoptica.utils.config import SupportsConfig +from panoptica.utils.label_group import LabelGroup -class LabelGroup: - """Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other""" - - def __init__( - self, - value_labels: list[int] | int, - single_instance: bool = False, - ) -> None: - """Defines a group of labels that semantically belong to each other - - Args: - value_labels (list[int]): Actually labels in the prediction and reference mask in this group. Defines the labels that can be matched to each other - single_instance (bool, optional): If true, will not use the matching_threshold as there is only one instance (large organ, ...). Defaults to False. - """ - if isinstance(value_labels, int): - value_labels = [value_labels] - assert ( - len(value_labels) >= 1 - ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" - self.__value_labels = value_labels - assert np.all( - [v > 0 for v in self.__value_labels] - ), f"Given value labels are not >0, got {value_labels}" - self.__single_instance = single_instance - if self.__single_instance: - assert ( - len(value_labels) == 1 - ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" - - @property - def value_labels(self) -> list[int]: - return self.__value_labels - - @property - def single_instance(self) -> bool: - return self.__single_instance - - def __call__( - self, - array: np.ndarray, - set_to_binary: bool = False, - ) -> np.ndarray: - """Extracts the labels of this class - - Args: - array (np.ndarray): Array to extract the segmentation group labels from - set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. - - Returns: - np.ndarray: Array containing only the labels of this segmentation group - """ - array = array.copy() - array[np.isin(array, self.value_labels, invert=True)] = 0 - if set_to_binary: - array[array != 0] = 1 - return array - - def __str__(self) -> str: - return f"LabelGroup {self.value_labels}, single_instance={self.single_instance}" - - def __repr__(self) -> str: - return str(self) - - -class SegmentationClassGroups: +class SegmentationClassGroups(SupportsConfig): + # def __init__( self, groups: list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]], @@ -78,7 +18,7 @@ def __init__( self.__group_dictionary = { f"group_{idx}": g for idx, g in enumerate(groups) } - else: + elif isinstance(groups, dict): # transform dict into list of LabelGroups for i, g in groups.items(): name_lower = str(i).lower() @@ -100,7 +40,6 @@ def __init__( raise AssertionError( f"The same label was assigned to two different labelgroups, got {str(self)}" ) - self.__labels = labels def has_defined_labels_for( @@ -137,10 +76,18 @@ def __iter__(self): def keys(self) -> list[str]: return list(self.__group_dictionary.keys()) + @property + def labels(self): + return self.__labels + def items(self): for k in self: yield k, self[k] + @classmethod + def _yaml_repr(cls, node): + return {"groups": node.__group_dictionary} + def list_duplicates(seq): seen = set() diff --git a/pyproject.toml b/pyproject.toml index 2a30f1f..40fc840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ connected-components-3d = "^3.12.3" scipy = "^1.7.0" rich = "^13.6.0" scikit-image = "^0.22.0" +"ruamel.yaml" = "^0.18.6" [tool.poetry.dev-dependencies] pytest = "^6.2.5" diff --git a/unit_tests/test_config.py b/unit_tests/test_config.py new file mode 100644 index 0000000..59c0b46 --- /dev/null +++ b/unit_tests/test_config.py @@ -0,0 +1,159 @@ +# Call 'python -m unittest' on this folder +# coverage run -m unittest +# coverage report +# coverage html +import os +import unittest + +from panoptica.metrics import ( + Metric, + _Metric, + Evaluation_List_Metric, + MetricMode, + MetricCouldNotBeComputedException, +) +from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup +from panoptica.utils.constants import CCABackend +from panoptica.utils.edge_case_handling import ( + EdgeCaseResult, + EdgeCaseZeroTP, + MetricZeroTPEdgeCaseHandling, + EdgeCaseHandler, +) +from panoptica import ConnectedComponentsInstanceApproximator, NaiveThresholdMatching +from pathlib import Path +import numpy as np +import random + +test_file = Path(__file__).parent.joinpath("test.yaml") + + +class Test_Datatypes(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_enum_config(self): + a = CCABackend.cc3d + a.save_to_config(test_file) + print(a) + b = CCABackend.load_from_config(test_file) + print(b) + os.remove(test_file) + + self.assertEqual(a, b) + + def test_enum_config_all(self): + for enum in [CCABackend, EdgeCaseZeroTP, EdgeCaseResult, Metric]: + for a in enum: + a.save_to_config(test_file) + print(a) + b = enum.load_from_config(test_file) + print(b) + os.remove(test_file) + + self.assertEqual(a, b) + self.assertEqual(a.name, b.name) + self.assertEqual(str(a), str(b)) + + # check for value equality + if a.value is None: + self.assertTrue(b.value is None) + else: + # if it is _Metric object, just check name + if not isinstance(a.value, _Metric): + if np.isnan(a.value): + self.assertTrue(np.isnan(b.value)) + else: + self.assertEqual(a.value, b.value) + else: + self.assertEqual(a.value.name, b.value.name) + + def test_SegmentationClassGroups_config(self): + e = { + "groups": { + "vertebra": LabelGroup([i for i in range(1, 11)], False), + "ivd": LabelGroup([i for i in range(101, 111)]), + "sacrum": LabelGroup(26, True), + "endplate": LabelGroup([i for i in range(201, 211)]), + } + } + t = SegmentationClassGroups(**e) + print(t) + print() + t.save_to_config(test_file) + d: SegmentationClassGroups = SegmentationClassGroups.load_from_config(test_file) + os.remove(test_file) + + for k, v in d.items(): + self.assertEqual(t[k].single_instance, v.single_instance) + self.assertEqual(len(t[k].value_labels), len(v.value_labels)) + + def test_InstanceApproximator_config(self): + for backend in [None, CCABackend.cc3d, CCABackend.scipy]: + t = ConnectedComponentsInstanceApproximator(cca_backend=backend) + print(t) + print() + t.save_to_config(test_file) + d: ConnectedComponentsInstanceApproximator = ( + ConnectedComponentsInstanceApproximator.load_from_config(test_file) + ) + os.remove(test_file) + + self.assertEqual(d.cca_backend, t.cca_backend) + + def test_NaiveThresholdMatching_config(self): + for mm in [Metric.DSC, Metric.IOU, Metric.ASSD]: + for mt in [0.1, 0.4, 0.5, 0.8, 1.0]: + for amto in [False, True]: + t = NaiveThresholdMatching( + matching_metric=mm, + matching_threshold=mt, + allow_many_to_one=amto, + ) + print(t) + print() + t.save_to_config(test_file) + d: NaiveThresholdMatching = NaiveThresholdMatching.load_from_config( + test_file + ) + os.remove(test_file) + + self.assertEqual(d._allow_many_to_one, t._allow_many_to_one) + self.assertEqual(d._matching_metric, t._matching_metric) + self.assertEqual(d._matching_threshold, t._matching_threshold) + + def test_MetricZeroTPEdgeCaseHandling_config(self): + for iter in range(10): + args = [random.choice(list(EdgeCaseResult)) for i in range(5)] + + t = MetricZeroTPEdgeCaseHandling(*args) + print(t) + print() + t.save_to_config(test_file) + d: MetricZeroTPEdgeCaseHandling = ( + MetricZeroTPEdgeCaseHandling.load_from_config(test_file) + ) + os.remove(test_file) + + for k, v in t._edgecase_dict.items(): + self.assertEqual(v, d._edgecase_dict[k]) + # self.assertEqual(d.cca_backend, t.cca_backend) + + def test_EdgeCaseHandler_config(self): + t = EdgeCaseHandler() + print(t) + print() + t.save_to_config(test_file) + d: EdgeCaseHandler = EdgeCaseHandler.load_from_config(test_file) + # os.remove(test_file) + + self.assertEqual(t.handle_empty_list_std(), d.handle_empty_list_std()) + for k, v in t.listmetric_zeroTP_handling.items(): + # v is dict[Metric, MetricZeroTPEdgeCaseHandling] + v2 = d.listmetric_zeroTP_handling[k] + + print(v) + print(v2) + + self.assertEqual(v, v2) diff --git a/unit_tests/test_datatype.py b/unit_tests/test_datatype.py index 4d6b287..c892aae 100644 --- a/unit_tests/test_datatype.py +++ b/unit_tests/test_datatype.py @@ -11,6 +11,7 @@ MetricMode, MetricCouldNotBeComputedException, ) +from panoptica.utils.edge_case_handling import EdgeCaseResult class Test_Datatypes(unittest.TestCase): @@ -29,6 +30,10 @@ def test_metrics_enum(self): self.assertNotEqual(Metric.DSC, Metric.IOU) self.assertNotEqual(Metric.DSC, "IOU") + def test_EdgeCaseResult_enum(self): + for e in EdgeCaseResult: + self.assertEqual(e, e) + def test_matching_metric(self): dsc_metric = Metric.DSC diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index a2cf755..5fe9c0e 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -7,6 +7,7 @@ import numpy as np +from panoptica import InputType from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator from panoptica.instance_matcher import MaximizeMergeMatching, NaiveThresholdMatching from panoptica.metrics import Metric @@ -16,7 +17,7 @@ from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup -class Test_Panoptic_Evaluator(unittest.TestCase): +class Test_Panoptica_Evaluator(unittest.TestCase): def setUp(self) -> None: os.environ["PANOPTICA_CITATION_REMINDER"] = "False" return super().setUp() @@ -27,15 +28,13 @@ def test_simple_evaluation(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -48,15 +47,13 @@ def test_simple_evaluation_DSC(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -69,16 +66,14 @@ def test_simple_evaluation_DSC_partial(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(matching_metric=Metric.DSC), eval_metrics=[Metric.DSC], ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -95,10 +90,8 @@ def test_simple_evaluation_ASSD(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching( matching_metric=Metric.ASSD, @@ -106,7 +99,7 @@ def test_simple_evaluation_ASSD(self): ), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -119,10 +112,8 @@ def test_simple_evaluation_ASSD_negative(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching( matching_metric=Metric.ASSD, @@ -130,7 +121,7 @@ def test_simple_evaluation_ASSD_negative(self): ), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -144,15 +135,13 @@ def test_pred_empty(self): a[20:40, 10:20] = 1 # b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 0) @@ -167,15 +156,13 @@ def test_ref_empty(self): # a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -190,15 +177,13 @@ def test_both_empty(self): # a[20:40, 10:20] = 1 # b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 0) @@ -228,15 +213,14 @@ def test_dtype_evaluation(self): if da != db: self.assertRaises(AssertionError, SemanticPair, b, a) else: - sample = SemanticPair(b, a) evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -249,15 +233,13 @@ def test_simple_evaluation_maximize_matcher(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -271,15 +253,13 @@ def test_simple_evaluation_maximize_matcher_overlaptwo(self): b[20:35, 10:20] = 2 b[36:38, 10:20] = 3 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -295,15 +275,13 @@ def test_simple_evaluation_maximize_matcher_overlap(self): # match the two above to 1 and the 4 to nothing (FP) b[39:47, 10:20] = 4 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(sample)["ungrouped"] + result, debug_data = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 1) @@ -318,16 +296,14 @@ def test_single_instance_mode(self): a[20:40, 10:20] = 5 b[20:35, 10:20] = 5 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}), ) - result, debug_data = evaluator.evaluate(sample)["organ"] + result, debug_data = evaluator.evaluate(b, a)["organ"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -340,16 +316,14 @@ def test_single_instance_mode_nooverlap(self): a[20:40, 10:20] = 5 b[5:15, 30:50] = 5 - sample = SemanticPair(b, a) - evaluator = Panoptica_Evaluator( - expected_input=SemanticPair, + expected_input=InputType.SEMANTIC, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}), ) - result, debug_data = evaluator.evaluate(sample)["organ"] + result, debug_data = evaluator.evaluate(b, a)["organ"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py index e88b2b2..f5087ea 100644 --- a/unit_tests/test_panoptic_result.py +++ b/unit_tests/test_panoptic_result.py @@ -15,7 +15,7 @@ from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult -class Test_Panoptic_Evaluator(unittest.TestCase): +class Test_Panoptica_Evaluator(unittest.TestCase): def setUp(self) -> None: os.environ["PANOPTICA_CITATION_REMINDER"] = "False" return super().setUp()