diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index 3b6a6b1..4c8f10c 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -141,6 +141,23 @@ def _get_paired_crop( reference_arr: np.ndarray, px_pad: int = 2, ): + """Calculates a bounding box based on paired prediction and reference arrays. + + This function combines the prediction and reference arrays, checks if they are identical, + and computes a bounding box around the non-zero regions. If both arrays are completely zero, + a small value is added to ensure the bounding box is valid. + + Args: + prediction_arr (np.ndarray): The predicted segmentation array. + reference_arr (np.ndarray): The ground truth segmentation array. + px_pad (int, optional): Padding to apply around the bounding box. Defaults to 2. + + Returns: + np.ndarray: The bounding box coordinates around the combined non-zero regions. + + Raises: + AssertionError: If the prediction and reference arrays do not have the same shape. + """ assert prediction_arr.shape == reference_arr.shape combined = prediction_arr + reference_arr @@ -150,6 +167,19 @@ def _get_paired_crop( def _round_to_n(value: float | int, n_significant_digits: int = 2): + """Rounds a number to a specified number of significant digits. + + This function rounds the given value to the specified number of significant digits. + If the value is zero, it is returned unchanged. + + Args: + value (float | int): The number to be rounded. + n_significant_digits (int, optional): The number of significant digits to round to. + Defaults to 2. + + Returns: + float: The rounded value. + """ return ( value if value == 0 diff --git a/panoptica/metrics/assd.py b/panoptica/metrics/assd.py index b98b407..41d6937 100644 --- a/panoptica/metrics/assd.py +++ b/panoptica/metrics/assd.py @@ -107,6 +107,29 @@ def _distance_transform_edt( return_distances=True, return_indices=False, ): + """Computes the Euclidean distance transform and/or feature transform of a binary array. 
+ + This function calculates the Euclidean distance transform (EDT) of a binary array, + which gives the distance from each non-zero point to the nearest zero point. It can + also return the feature transform, which provides indices to the nearest non-zero point. + + Args: + input_array (np.ndarray): The input binary array where non-zero values are considered + foreground. + sampling (optional): A sequence or array that specifies the spacing along each dimension. + If provided, scales the distances by the sampling value along each axis. + return_distances (bool, optional): If True, returns the distance transform. Default is True. + return_indices (bool, optional): If True, returns the feature transform with indices to + the nearest foreground points. Default is False. + + Returns: + np.ndarray or tuple[np.ndarray, ...]: If `return_distances` is True, returns the distance + transform as an array. If `return_indices` is True, returns the feature transform. If both + are True, returns a tuple with the distance and feature transforms. + + Raises: + ValueError: If the input array is empty or has unsupported dimensions. + """ # calculate the feature transform # input = np.atleast_1d(np.where(input, 1, 0).astype(np.int8)) # if sampling is not None: diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 5cbce45..7e30c53 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -20,7 +20,25 @@ @dataclass class _Metric: - """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array""" + """Represents a metric with a name, direction (increasing or decreasing), and a calculation function. + + This class provides a framework for defining and calculating metrics, which can be used + to evaluate the similarity or performance between reference and prediction arrays. + The metric direction indicates whether higher or lower values are better. 
+ + Attributes: + name (str): Short name of the metric. + long_name (str): Full descriptive name of the metric. + decreasing (bool): If True, lower metric values are better; otherwise, higher values are preferred. + _metric_function (Callable): A callable function that computes the metric + between two input arrays. + + Example: + >>> my_metric = _Metric(name="accuracy", long_name="Accuracy", decreasing=False, _metric_function=accuracy_function) + >>> score = my_metric(reference_array, prediction_array) + >>> print(score) + + """ name: str long_name: str @@ -36,6 +54,20 @@ def __call__( *args, **kwargs, ) -> int | float: + """Calculates the metric between reference and prediction arrays. + + Args: + reference_arr (np.ndarray): The reference array. + prediction_arr (np.ndarray): The prediction array. + ref_instance_idx (int, optional): The instance index to filter in the reference array. + pred_instance_idx (int | list[int], optional): Instance index or indices to filter in + the prediction array. + *args: Additional positional arguments for the metric function. + **kwargs: Additional keyword arguments for the metric function. + + Returns: + int | float: The computed metric value. + """ if ref_instance_idx is not None and pred_instance_idx is not None: reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): @@ -60,15 +92,35 @@ def __repr__(self) -> str: return str(self) def __hash__(self) -> int: + """Hash based on metric name, constrained to fit within 8 digits. + + Returns: + int: The hash value of the metric. + """ return abs(hash(self.name)) % (10**8) @property def increasing(self): + """Indicates if higher values of the metric are better. + + Returns: + bool: True if increasing values are preferred, otherwise False. + """ return not self.decreasing def score_beats_threshold( self, matching_score: float, matching_threshold: float ) -> bool: + """Determines if a matching score meets a specified threshold. 
+ + Args: + matching_score (float): The score to evaluate. + matching_threshold (float): The threshold value to compare against. + + Returns: + bool: True if the score meets the threshold, taking into account the + metric's preferred direction. + """ return (self.increasing and matching_score >= matching_threshold) or ( self.decreasing and matching_score <= matching_threshold ) @@ -206,6 +258,16 @@ def __init__(self, *args: object) -> None: class Evaluation_Metric: + """This represents a metric in the evaluation derived from other metrics or list metrics (no circular dependancies!) + + Args: + name_id (str): code-name of this metric, must be same as the member variable of PanopticResult + calc_func (Callable): the function to calculate this metric based on the PanopticResult object + long_name (str | None, optional): A longer descriptive name for printing/logging purposes. Defaults to None. + was_calculated (bool, optional): Whether this metric has been calculated or not. Defaults to False. + error (bool, optional): If true, means the metric could not have been calculated (because dependancies do not exist or have this flag set to True). Defaults to False. + """ + def __init__( self, name_id: str, @@ -215,15 +277,7 @@ def __init__( was_calculated: bool = False, error: bool = False, ): - """This represents a metric in the evaluation derived from other metrics or list metrics (no circular dependancies!) - Args: - name_id (str): code-name of this metric, must be same as the member variable of PanopticResult - calc_func (Callable): the function to calculate this metric based on the PanopticResult object - long_name (str | None, optional): A longer descriptive name for printing/logging purposes. Defaults to None. - was_calculated (bool, optional): Whether this metric has been calculated or not. Defaults to False. - error (bool, optional): If true, means the metric could not have been calculated (because dependancies do not exist or have this flag set to True). 
Defaults to False. - """ self.id = name_id self.metric_type = metric_type self._calc_func = calc_func diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py index 78e7463..0c0c527 100644 --- a/panoptica/panoptica_aggregator.py +++ b/panoptica/panoptica_aggregator.py @@ -30,9 +30,11 @@ # class Panoptica_Aggregator: - # internal_list_lock = Lock() - # - """Aggregator that calls evaluations and saves the resulting metrics per sample. Can be used to create statistics, ...""" + """Aggregator that manages evaluations and saves resulting metrics per sample. + + This class interfaces with the `Panoptica_Evaluator` to perform evaluations, + store results, and manage file outputs for statistical analysis. + """ def __init__( self, @@ -41,10 +43,18 @@ def __init__( log_times: bool = False, continue_file: bool = True, ): - """ + """Initializes the Panoptica_Aggregator. + Args: - panoptica_evaluator (Panoptica_Evaluator): The Panoptica_Evaluator used for the pipeline. - output_file (Path | None, optional): If given, will stream the sample results into this file. If the file is existent, will append results if not already there. Defaults to None. + panoptica_evaluator (Panoptica_Evaluator): The evaluator used for performing evaluations. + output_file (Path | str): Path to the output file for storing results. If the file exists, + results will be appended. If it doesn't, a new file will be created. + log_times (bool, optional): If True, computation times will be logged. Defaults to False. + continue_file (bool, optional): If True, results will continue from existing entries in the file. + Defaults to True. + + Raises: + AssertionError: If the output directory does not exist or if the file extension is not `.tsv`. 
""" self.__panoptica_evaluator = panoptica_evaluator self.__class_group_names = panoptica_evaluator.segmentation_class_groups_names @@ -115,10 +125,16 @@ def __init__( atexit.register(self.__exist_handler) def __exist_handler(self): + """Handles cleanup upon program exit by removing the temporary output buffer file.""" if self.__output_buffer_file is not None and self.__output_buffer_file.exists(): os.remove(str(self.__output_buffer_file)) def make_statistic(self) -> Panoptica_Statistic: + """Generates statistics from the aggregated evaluation results. + + Returns: + Panoptica_Statistic: The statistics object containing the results. + """ with filelock: obj = Panoptica_Statistic.from_file(self.__output_file) return obj @@ -129,14 +145,16 @@ def evaluate( reference_arr: np.ndarray, subject_name: str, ): - """Evaluates one case + """Evaluates a single case using the provided prediction and reference arrays. Args: - prediction_arr (np.ndarray): Prediction array - reference_arr (np.ndarray): reference array - subject_name (str | None, optional): Unique name of the sample. If none, will give it a name based on count. Defaults to None. - skip_already_existent (bool): If true, will skip subjects which were already evaluated instead of crashing. Defaults to False. - verbose (bool | None, optional): Verbose. Defaults to None. + prediction_arr (np.ndarray): The array containing the predicted segmentation. + reference_arr (np.ndarray): The array containing the ground truth segmentation. + subject_name (str): A unique name for the sample being evaluated. If none is provided, + a name will be generated based on the count. + + Raises: + ValueError: If the subject name has already been evaluated or is in process. """ # Read tmp file to see which sample names are blocked with inevalfilelock: @@ -164,6 +182,12 @@ def evaluate( self._save_one_subject(subject_name, res) def _save_one_subject(self, subject_name, result_grouped): + """Saves the evaluation results for a single subject. 
+ + Args: + subject_name (str): The name of the subject whose results are being saved. + result_grouped (dict): A dictionary of grouped results from the evaluation. + """ with filelock: # content = [subject_name] @@ -186,9 +210,19 @@ def panoptica_evaluator(self): def _read_first_row(file: str | Path): + """Reads the first row of a TSV file. + + NOT THREAD SAFE BY ITSELF! + + Args: + file (str | Path): The path to the file from which to read the first row. + + Returns: + list: The first row of the file as a list of strings. + """ if isinstance(file, Path): file = str(file) - # NOT THREAD SAFE BY ITSELF! + # with open(str(file), "r", encoding="utf8", newline="") as tsvfile: rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n") @@ -202,7 +236,19 @@ def _read_first_row(file: str | Path): def _load_first_column_entries(file: str | Path): - # NOT THREAD SAFE BY ITSELF! + """Loads the entries from the first column of a TSV file. + + NOT THREAD SAFE BY ITSELF! + + Args: + file (str | Path): The path to the file from which to load entries. + + Returns: + list: A list of entries from the first column of the file. + + Raises: + AssertionError: If the file contains duplicate entries. + """ if isinstance(file, Path): file = str(file) with open(str(file), "r", encoding="utf8", newline="") as tsvfile: @@ -221,6 +267,12 @@ def _load_first_column_entries(file: str | Path): def _write_content(file: str | Path, content: list[list[str]]): + """Writes content to a TSV file. + + Args: + file (str | Path): The path to the file where content will be written. + content (list[list[str]]): A list of lists containing the rows of data to write. + """ if isinstance(file, Path): file = str(file) # NOT THREAD SAFE BY ITSELF! 
diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index 0db17c2..da9c884 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -286,6 +286,21 @@ def _calc_global_bin_metric( reference_arr, do_binarize: bool = True, ): + """ + Calculates a global binary metric based on predictions and references. + + Args: + metric (Metric): The metric to compute. + prediction_arr: The predicted values. + reference_arr: The ground truth values. + do_binarize (bool): Whether to binarize the input arrays. Defaults to True. + + Returns: + The calculated metric value. + + Raises: + MetricCouldNotBeComputedException: If the specified metric is not set. + """ if metric not in self._global_metrics: raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") @@ -321,6 +336,20 @@ def _add_metric( default_value=None, was_calculated: bool = False, ): + """ + Adds a new metric to the evaluation metrics. + + Args: + name_id (str): The unique identifier for the metric. + metric_type (MetricType): The type of the metric. + calc_func (Callable | None): The function to calculate the metric. + long_name (str | None): A longer, descriptive name for the metric. + default_value: The default value for the metric. + was_calculated (bool): Indicates if the metric has been calculated. + + Returns: + The default value of the metric. + """ setattr(self, name_id, default_value) # assert hasattr(self, name_id), f"added metric {name_id} but it is not a member variable of this class" if calc_func is None: @@ -338,7 +367,8 @@ def _add_metric( return default_value def calculate_all(self, print_errors: bool = False): - """Calculates all possible metrics that can be derived + """ + Calculates all possible metrics that can be derived. Args: print_errors (bool, optional): If true, will print every metric that could not be computed and its reason. Defaults to False. 
@@ -356,6 +386,16 @@ def calculate_all(self, print_errors: bool = False): print(f"Metric {k}: {v}") def _calc(self, k, v): + """ + Attempts to get the value of a metric and captures any exceptions. + + Args: + k: The metric key. + v: The metric value. + + Returns: + A tuple indicating success or failure and the corresponding value or exception. + """ try: v = getattr(self, k) return False, v @@ -389,6 +429,12 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: + """ + Converts the metrics to a dictionary format. + + Returns: + A dictionary containing metric names and their values. + """ return { k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() @@ -400,6 +446,19 @@ def evaluation_metrics(self): return self._evaluation_metrics def get_list_metric(self, metric: Metric, mode: MetricMode): + """ + Retrieves a list of metrics based on the given metric type and mode. + + Args: + metric (Metric): The metric to retrieve. + mode (MetricMode): The mode of the metric. + + Returns: + The corresponding list of metrics. + + Raises: + MetricCouldNotBeComputedException: If the metric cannot be found. + """ if metric in self._list_metrics: return self._list_metrics[metric][mode] else: @@ -408,6 +467,19 @@ def get_list_metric(self, metric: Metric, mode: MetricMode): ) def _calc_metric(self, metric_name: str, supress_error: bool = False): + """ + Calculates a specific metric by its name. + + Args: + metric_name (str): The name of the metric to calculate. + supress_error (bool): If true, suppresses errors during calculation. + + Returns: + The calculated metric value or raises an exception if it cannot be computed. + + Raises: + MetricCouldNotBeComputedException: If the metric cannot be found. 
+ """ if metric_name in self._evaluation_metrics: try: value = self._evaluation_metrics[metric_name](self) @@ -426,6 +498,18 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): ) def __getattribute__(self, __name: str) -> Any: + """ + Retrieves an attribute, with special handling for evaluation metrics. + + Args: + __name (str): The name of the attribute to retrieve. + + Returns: + The attribute value. + + Raises: + MetricCouldNotBeComputedException: If the requested metric could not be computed. + """ attr = None try: attr = object.__getattribute__(self, __name) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index c651fb1..96489b3 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -174,6 +174,7 @@ def get_summary_figure( self, metric: str, horizontal: bool = True, + sort: bool = True, # title overwrite? ): """Returns a figure object that shows the given metric for each group and its std @@ -191,6 +192,7 @@ def get_summary_figure( data=data_plot, orientation=orientation, score=metric, + sort=sort, ) # groupwise or in total diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index a3d5157..40c511c 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -7,10 +7,24 @@ def _register_helper_classes(yaml: YAML): + """Registers globally supported helper classes to a YAML instance. + + Args: + yaml (YAML): The YAML instance to register helper classes to. + """ [yaml.register_class(s) for s in supported_helper_classes] def _load_yaml(file: str | Path, registered_class=None): + """Loads a YAML file into a Python dictionary or object, with optional class registration. + + Args: + file (str | Path): Path to the YAML file. + registered_class (optional): Optional class to register with the YAML parser. + + Returns: + dict | object: Parsed content from the YAML file. 
+ """ if isinstance(file, str): file = Path(file) yaml = YAML(typ="safe") @@ -24,6 +38,13 @@ def _load_yaml(file: str | Path, registered_class=None): def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class=None): + """Saves a Python dictionary or object to a YAML file, with optional class registration. + + Args: + data_dict (dict | object): Data to save. + out_file (str | Path): Output file path. + registered_class (optional): Class type to register with YAML if saving an object. + """ if isinstance(out_file, str): out_file = Path(out_file) @@ -47,13 +68,26 @@ def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class= # Universal Functions ######### def _register_class_to_yaml(cls): + """Registers a class to the global supported helper classes for YAML serialization. + + Args: + cls: The class to register. + """ global supported_helper_classes if cls not in supported_helper_classes: supported_helper_classes.append(cls) def _load_from_config(cls, path: str | Path): - # cls._register_permanently() + """Loads an instance of a class from a YAML configuration file. + + Args: + cls: The class type to instantiate. + path (str | Path): Path to the YAML configuration file. + + Returns: + An instance of the specified class, loaded from configuration. + """ if isinstance(path, str): path = Path(path) assert path.exists(), f"load_from_config: {path} does not exist" @@ -63,39 +97,82 @@ def _load_from_config(cls, path: str | Path): def _load_from_config_name(cls, name: str): + """Loads an instance of a class from a configuration file identified by name. + + Args: + cls: The class type to instantiate. + name (str): The name used to find the configuration file. + + Returns: + An instance of the specified class. 
+ """ path = config_by_name(name) assert path.exists(), f"load_from_config: {path} does not exist" return _load_from_config(cls, path) def _save_to_config(obj, path: str | Path): + """Saves an instance of a class to a YAML configuration file. + + Args: + obj: The object to save. + path (str | Path): The file path to save the configuration. + """ if isinstance(path, str): path = Path(path) _save_yaml(obj, path, registered_class=type(obj)) def _save_to_config_by_name(obj, name: str): + """Saves an instance of a class to a configuration file by name. + + Args: + obj: The object to save. + name (str): The name used to determine the configuration file path. + """ dir, name = config_dir_by_name(name) _save_to_config(obj, dir.joinpath(name)) class SupportsConfig: - """Metaclass that allows a class to save and load objects by yaml configs""" + """Base class that provides methods for loading and saving instances as YAML configurations. + + This class should be inherited by classes that wish to have load and save functionality for YAML + configurations, with class registration to enable custom serialization and deserialization. + + Methods: + load_from_config(cls, path): Loads a class instance from a YAML file. + load_from_config_name(cls, name): Loads a class instance from a configuration file identified by name. + save_to_config(path): Saves the instance to a YAML file. + save_to_config_by_name(name): Saves the instance to a configuration file identified by name. + to_yaml(cls, representer, node): YAML serialization method (requires _yaml_repr). + from_yaml(cls, constructor, node): YAML deserialization method. 
+ """ def __init__(self) -> None: + """Prevents instantiation of SupportsConfig as it is intended to be a metaclass.""" raise NotImplementedError(f"Tried to instantiate abstract class {type(self)}") def __init_subclass__(cls, **kwargs): - # Registers all subclasses of this + """Registers subclasses of SupportsConfig to enable YAML support.""" super().__init_subclass__(**kwargs) cls._register_permanently() @classmethod def _register_permanently(cls): + """Registers the class to globally supported helper classes.""" _register_class_to_yaml(cls) @classmethod def load_from_config(cls, path: str | Path): + """Loads an instance of the class from a YAML file. + + Args: + path (str | Path): The file path to load the configuration. + + Returns: + An instance of the class. + """ obj = _load_from_config(cls, path) assert isinstance( obj, cls @@ -104,19 +181,45 @@ def load_from_config(cls, path: str | Path): @classmethod def load_from_config_name(cls, name: str): + """Loads an instance of the class from a configuration file identified by name. + + Args: + name (str): The name used to find the configuration file. + + Returns: + An instance of the class. + """ obj = _load_from_config_name(cls, name) assert isinstance(obj, cls) return obj def save_to_config(self, path: str | Path): + """Saves the instance to a YAML configuration file. + + Args: + path (str | Path): The file path to save the configuration. + """ _save_to_config(self, path) def save_to_config_by_name(self, name: str): + """Saves the instance to a configuration file identified by name. + + Args: + name (str): The name used to determine the configuration file path. + """ _save_to_config_by_name(self, name) @classmethod def to_yaml(cls, representer, node): - # cls._register_permanently() + """Serializes the class to YAML format. + + Args: + representer: YAML representer instance. + node: The object instance to serialize. + + Returns: + YAML node: YAML-compatible node representation of the object. 
+ """ assert hasattr( cls, "_yaml_repr" ), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" @@ -124,11 +227,27 @@ def to_yaml(cls, representer, node): @classmethod def from_yaml(cls, constructor, node): - # cls._register_permanently() + """Deserializes a YAML node to an instance of the class. + + Args: + constructor: YAML constructor instance. + node: YAML node to deserialize. + + Returns: + An instance of the class with attributes populated from YAML data. + """ data = constructor.construct_mapping(node, deep=True) return cls(**data) @classmethod @abstractmethod def _yaml_repr(cls, node) -> dict: + """Abstract method for representing the class in YAML. + + Args: + node: The object instance to represent in YAML. + + Returns: + dict: A dictionary representation of the class. + """ pass # return {"groups": node.__group_dictionary} diff --git a/panoptica/utils/constants.py b/panoptica/utils/constants.py index ecb7f63..dceb70d 100644 --- a/panoptica/utils/constants.py +++ b/panoptica/utils/constants.py @@ -10,6 +10,22 @@ class _Enum_Compare(Enum): + """An extended Enum class that supports additional comparison and YAML configuration functionality. + + This class enhances standard `Enum` capabilities, allowing comparisons with other enums or strings by + name and adding support for YAML serialization and deserialization methods. + + Methods: + __eq__(__value): Checks equality with another Enum or string. + __str__(): Returns a string representation of the Enum instance. + __repr__(): Returns a string representation for debugging. + load_from_config(cls, path): Loads an Enum instance from a configuration file. + load_from_config_name(cls, name): Loads an Enum instance from a configuration file identified by name. + save_to_config(path): Saves the Enum instance to a configuration file. + to_yaml(cls, representer, node): Serializes the Enum to YAML. + from_yaml(cls, constructor, node): Deserializes YAML data into an Enum instance. 
+ """ + def __eq__(self, __value: object) -> bool: if isinstance(__value, Enum): namecheck = self.name == __value.name diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index e7dd5d1..35267bd 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -5,6 +5,23 @@ class EdgeCaseResult(_Enum_Compare): + """Enumeration of edge case values used for handling specific metric situations. + + This enum defines several common edge case values for handling zero-true-positive (zero-TP) + situations in various metrics. The values include infinity, NaN, zero, one, and None. + + Attributes: + INF: Represents infinity (`np.inf`). + NAN: Represents not-a-number (`np.nan`). + ZERO: Represents zero (0.0). + ONE: Represents one (1.0). + NONE: Represents a None value. + + Methods: + value: Returns the value associated with the edge case. + __call__(): Returns the numeric or None representation of the enum member. + """ + INF = auto() # np.inf NAN = auto() # np.nan ZERO = auto() # 0.0 @@ -29,6 +46,15 @@ def __call__(self): class EdgeCaseZeroTP(_Enum_Compare): + """Enum defining scenarios that could produce zero true positives (zero-TP) in metrics. + + Attributes: + NO_INSTANCES: No instances in both the prediction and reference. + EMPTY_PRED: The prediction is empty. + EMPTY_REF: The reference is empty. + NORMAL: A typical scenario with non-zero instances. + """ + NO_INSTANCES = auto() EMPTY_PRED = auto() EMPTY_REF = auto() @@ -39,6 +65,21 @@ def __hash__(self) -> int: class MetricZeroTPEdgeCaseHandling(SupportsConfig): + """Handles zero-TP edge cases for metrics, mapping different zero-TP scenarios to specific results. + + Attributes: + default_result (EdgeCaseResult | None): Default result if specific edge cases are not provided. + no_instances_result (EdgeCaseResult | None): Result when no instances are present. + empty_prediction_result (EdgeCaseResult | None): Result when prediction is empty. 
+ empty_reference_result (EdgeCaseResult | None): Result when reference is empty. + normal (EdgeCaseResult | None): Result when a normal zero-TP scenario occurs. + + Methods: + __call__(tp, num_pred_instances, num_ref_instances): Determines if an edge case is detected and returns its result. + __eq__(value): Compares this handling object to another. + __str__(): String representation of edge cases. + _yaml_repr(cls, node): YAML representation for the edge case. + """ def __init__( self, @@ -119,6 +160,19 @@ def _yaml_repr(cls, node) -> dict: class EdgeCaseHandler(SupportsConfig): + """Manages edge cases across multiple metrics, including standard deviation handling for empty lists. + + Attributes: + listmetric_zeroTP_handling (dict): Dictionary mapping metrics to their zero-TP edge case handling. + empty_list_std (EdgeCaseResult): Default edge case for handling standard deviation of empty lists. + + Methods: + handle_zero_tp(metric, tp, num_pred_instances, num_ref_instances): Checks if an edge case exists and returns its result. + listmetric_zeroTP_handling: Returns the edge case handling dictionary. + get_metric_zero_tp_handle(metric): Returns the zero-TP handler for a specific metric. + handle_empty_list_std(): Handles standard deviation of empty lists. + _yaml_repr(cls, node): YAML representation of the handler. + """ def __init__( self, diff --git a/panoptica/utils/instancelabelmap.py b/panoptica/utils/instancelabelmap.py index 16fd33c..f0c4b8c 100644 --- a/panoptica/utils/instancelabelmap.py +++ b/panoptica/utils/instancelabelmap.py @@ -3,13 +3,42 @@ # Many-to-One Mapping class InstanceLabelMap(object): - # Mapping ((prediction_label, ...), (reference_label, ...)) + """Creates a mapping between prediction labels and reference labels in a many-to-one relationship. + + This class allows mapping multiple prediction labels to a single reference label. 
+ It includes methods for adding new mappings, checking containment, retrieving + predictions mapped to a reference, and exporting the mapping as a dictionary. + + Attributes: + labelmap (dict[int, int]): Dictionary storing the prediction-to-reference label mappings. + + Methods: + add_labelmap_entry(pred_labels, ref_label): Adds a new entry mapping prediction labels to a reference label. + get_pred_labels_matched_to_ref(ref_label): Retrieves prediction labels mapped to a given reference label. + contains_pred(pred_label): Checks if a prediction label exists in the map. + contains_ref(ref_label): Checks if a reference label exists in the map. + contains_and(pred_label, ref_label): Checks if both a prediction and a reference label are in the map. + contains_or(pred_label, ref_label): Checks if either a prediction or reference label is in the map. + get_one_to_one_dictionary(): Returns the labelmap dictionary for a one-to-one view. + """ + labelmap: dict[int, int] def __init__(self) -> None: self.labelmap = {} def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): + """Adds an entry that maps prediction labels to a single reference label. + + Args: + pred_labels (list[int] | int): List of prediction labels or a single prediction label. + ref_label (int): Reference label to map to. + + Raises: + AssertionError: If `ref_label` is not an integer. + AssertionError: If any `pred_labels` are not integers. + Exception: If a prediction label is already mapped to a different reference label. + """ if not isinstance(pred_labels, list): pred_labels = [pred_labels] assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" @@ -24,17 +53,50 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): self.labelmap[p] = ref_label def get_pred_labels_matched_to_ref(self, ref_label: int): + """Retrieves all prediction labels that map to a specified reference label. 
+ + Args: + ref_label (int): The reference label to search. + + Returns: + list[int]: List of prediction labels mapped to `ref_label`. + """ return [k for k, v in self.labelmap.items() if v == ref_label] def contains_pred(self, pred_label: int): + """Checks if a prediction label exists in the map. + + Args: + pred_label (int): The prediction label to search. + + Returns: + bool: True if `pred_label` is in `labelmap`, otherwise False. + """ return pred_label in self.labelmap def contains_ref(self, ref_label: int): + """Checks if a reference label exists in the map. + + Args: + ref_label (int): The reference label to search. + + Returns: + bool: True if `ref_label` is in `labelmap` values, otherwise False. + """ return ref_label in self.labelmap.values() def contains_and( self, pred_label: int | None = None, ref_label: int | None = None ) -> bool: + """Checks if both a prediction and a reference label are in the map. + + Args: + pred_label (int | None): The prediction label to check. + ref_label (int | None): The reference label to check. + + Returns: + bool: True if both `pred_label` and `ref_label` are in the map; otherwise, False. + """ pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in and ref_in @@ -42,11 +104,25 @@ def contains_and( def contains_or( self, pred_label: int | None = None, ref_label: int | None = None ) -> bool: + """Checks if either a prediction or reference label is in the map. + + Args: + pred_label (int | None): The prediction label to check. + ref_label (int | None): The reference label to check. + + Returns: + bool: True if either `pred_label` or `ref_label` are in the map; otherwise, False. 
+ """ pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in or ref_in def get_one_to_one_dictionary(self): + """Returns a copy of the labelmap dictionary for a one-to-one view. + + Returns: + dict[int, int]: The prediction-to-reference label mapping. + """ return self.labelmap def __str__(self) -> str: @@ -66,6 +142,15 @@ def __repr__(self) -> str: # Make all variables read-only! def __setattr__(self, attr, value): + """Overrides attribute setting to make attributes read-only after initialization. + + Args: + attr (str): Attribute name. + value (Any): Attribute value. + + Raises: + Exception: If trying to alter an existing attribute. + """ if hasattr(self, attr): raise Exception("Attempting to alter read-only value") diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py index 34e0a08..0008450 100644 --- a/panoptica/utils/label_group.py +++ b/panoptica/utils/label_group.py @@ -5,18 +5,30 @@ class LabelGroup(SupportsConfig): - """Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other""" + """Defines a group of labels that semantically belong together for segmentation purposes. + + Groups of labels define label sets that can be matched with each other. + For example, labels might represent different parts of a segmented object, and only those within the group are eligible for matching. + + Attributes: + value_labels (list[int]): List of integer labels representing segmentation group labels. + single_instance (bool): If True, the group represents a single instance without matching threshold consideration. + """ def __init__( self, value_labels: list[int] | int, single_instance: bool = False, ) -> None: - """Defines a group of labels that semantically belong to each other + """Initializes a LabelGroup with specified labels and single instance setting. 
Args: - value_labels (list[int]): Actually labels in the prediction and reference mask in this group. Defines the labels that can be matched to each other - single_instance (bool, optional): If true, will not use the matching_threshold as there is only one instance (large organ, ...). Defaults to False. + value_labels (list[int] | int): Labels in the prediction and reference mask for this group. + single_instance (bool, optional): If True, ignores matching threshold as only one instance exists. Defaults to False. + + Raises: + AssertionError: If `value_labels` is empty or if labels are not positive integers. + AssertionError: If `single_instance` is True but more than one label is provided. """ if isinstance(value_labels, int): value_labels = [value_labels] @@ -40,10 +52,12 @@ def __init__( @property def value_labels(self) -> list[int]: + """List of integer labels for this segmentation group.""" return self.__value_labels @property def single_instance(self) -> bool: + """Indicates if this group is treated as a single instance.""" return self.__single_instance def extract_label( @@ -51,14 +65,14 @@ def extract_label( array: np.ndarray, set_to_binary: bool = False, ): - """Extracts the labels of this class + """Extracts an array of the labels specific to this segmentation group. Args: - array (np.ndarray): Array to extract the segmentation group labels from - set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. + array (np.ndarray): The array to filter for segmentation group labels. + set_to_binary (bool, optional): If True, outputs a binary array. Defaults to False. Returns: - np.ndarray: Array containing only the labels of this segmentation group + np.ndarray: An array with only the labels of this segmentation group. 
""" array = array.copy() array[np.isin(array, self.value_labels, invert=True)] = 0 @@ -70,6 +84,14 @@ def __call__( self, array: np.ndarray, ) -> np.ndarray: + """Extracts labels from an array for this group when the instance is called. + + Args: + array (np.ndarray): Array to filter for segmentation group labels. + + Returns: + np.ndarray: Array containing only the labels for this segmentation group. + """ return self.extract_label(array, set_to_binary=False) def __str__(self) -> str: @@ -87,6 +109,14 @@ def _yaml_repr(cls, node): class LabelMergeGroup(LabelGroup): + """Defines a group of labels that will be merged into a single label when extracted. + + Inherits from LabelGroup and sets extracted labels to binary format. + + Methods: + __call__(array): Extracts the label group as a binary array. + """ + def __init__( self, value_labels: list[int] | int, single_instance: bool = False ) -> None: @@ -96,14 +126,13 @@ def __call__( self, array: np.ndarray, ) -> np.ndarray: - """Extracts the labels of this class + """Extracts the labels of this group as a binary array. Args: - array (np.ndarray): Array to extract the segmentation group labels from - set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. + array (np.ndarray): Array to filter for segmentation group labels. Returns: - np.ndarray: Array containing only the labels of this segmentation group + np.ndarray: Binary array representing presence or absence of group labels. """ return self.extract_label(array, set_to_binary=True) @@ -112,6 +141,14 @@ def __str__(self) -> str: class _LabelGroupAny(LabelGroup): + """Represents a group that includes all labels in the array with no specific segmentation constraints. + + Used to represent a group that does not restrict labels. + + Methods: + __call__(array, set_to_binary): Returns the unfiltered array. 
+ """ + def __init__(self) -> None: pass @@ -128,14 +165,14 @@ def __call__( array: np.ndarray, set_to_binary: bool = False, ) -> np.ndarray: - """Extracts the labels of this class + """Returns the original array, unfiltered. Args: - array (np.ndarray): Array to extract the segmentation group labels from - set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. + array (np.ndarray): The original array to return. + set_to_binary (bool, optional): Ignored in this implementation. Returns: - np.ndarray: Array containing only the labels of this segmentation group + np.ndarray: The original, unmodified array. """ array = array.copy() return array diff --git a/panoptica/utils/parallel_processing.py b/panoptica/utils/parallel_processing.py index 50e3691..b7a7bac 100644 --- a/panoptica/utils/parallel_processing.py +++ b/panoptica/utils/parallel_processing.py @@ -4,6 +4,20 @@ class NoDaemonProcess(Process): + """A subclass of `multiprocessing.Process` that overrides daemon behavior to always be non-daemonic. + + Useful for creating a process that allows child processes to spawn their own children, + as daemonic processes in Python cannot create further subprocesses. + + Attributes: + group (None): Reserved for future extension when using process groups. + target (Callable[..., object] | None): The callable object to be invoked by the process. + name (str | None): The name of the process, for identification. + args (tuple): Arguments to pass to the `target` function. + kwargs (dict): Keyword arguments to pass to the `target` function. + daemon (bool | None): Indicates if the process is daemonic (overridden to always be False). + """ + def __init__( self, group: None = None, @@ -33,4 +47,9 @@ def _set_daemon(self, value): # We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool # because the latter is only a wrapper function, not a proper class. 
class NonDaemonicPool(multiprocessing.pool.Pool): + """A version of `multiprocessing.pool.Pool` using non-daemonic processes, allowing child processes to spawn their own children. + + This class creates a pool of worker processes using `NoDaemonProcess` for situations where nested child processes are needed. + """ + Process = NoDaemonProcess diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 6ef4640..f6901b7 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -13,9 +13,17 @@ class _ProcessingPair(ABC): - """ - Represents a general processing pair consisting of a reference array and a prediction array. Type of array can be arbitrary (integer recommended) - Every member is read-only! + """Represents a pair of processing arrays, typically prediction and reference arrays. + + This base class provides core functionality for processing and comparing prediction + and reference data arrays. Each instance contains two arrays and supports cropping and + data integrity checks. + + Attributes: + n_dim (int): The number of dimensions in the reference array. + crop (tuple[slice, ...] | None): The crop region applied to both arrays, if any. + is_cropped (bool): Indicates whether the arrays have been cropped. + uncropped_shape (tuple[int, ...]): The original shape of the arrays before cropping. """ _prediction_arr: np.ndarray @@ -28,12 +36,12 @@ class _ProcessingPair(ABC): def __init__( self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None ) -> None: - """Initializes a general Processing Pair + """Initializes the processing pair with prediction and reference arrays. Args: - prediction_arr (np.ndarray): Numpy array containig the prediction labels - reference_arr (np.ndarray): Numpy array containig the reference labels - dtype (type | None): Datatype that is asserted. None for no assertion + prediction_arr (np.ndarray): Numpy array of prediction labels. 
+            reference_arr (np.ndarray): Numpy array of reference labels.
+            dtype (type | None): The expected datatype of arrays. If None, no datatype check is performed.
         """
         _check_array_integrity(prediction_arr, reference_arr, dtype=dtype)
         self._prediction_arr = prediction_arr
@@ -51,6 +59,11 @@ def __init__(
         self.uncropped_shape: tuple[int, ...] = reference_arr.shape
 
     def crop_data(self, verbose: bool = False):
+        """Crops prediction and reference arrays to non-zero regions.
+
+        Args:
+            verbose (bool, optional): If True, prints cropping details. Defaults to False.
+        """
         if self.is_cropped:
             return
         if self.crop is None:
@@ -72,6 +85,11 @@ def crop_data(self, verbose: bool = False):
         self.is_cropped = True
 
     def uncrop_data(self, verbose: bool = False):
+        """Restores the arrays to their original, uncropped shape.
+
+        Args:
+            verbose (bool, optional): If True, prints uncropping details. Defaults to False.
+        """
         if self.is_cropped == False:
             return
         assert (
@@ -94,6 +112,11 @@ def uncrop_data(self, verbose: bool = False):
         self.is_cropped = False
 
     def set_dtype(self, type):
+        """Sets the data type for both prediction and reference arrays.
+
+        Args:
+            type (type): Expected integer type for the arrays.
+        """
         assert np.issubdtype(
             type, int_type
         ), "set_dtype: tried to set dtype to something other than integers"
@@ -136,8 +159,13 @@ def copy(self):
 
 
 class _ProcessingPairInstanced(_ProcessingPair):
-    """
-    A ProcessingPair that contains instances, additionally has number of instances available
+    """Represents a processing pair with labeled instances, including unique label counts.
+
+    This subclass tracks additional details about the number of unique instances in each array.
+
+    Attributes:
+        n_prediction_instance (int): Number of unique prediction instances.
+        n_reference_instance (int): Number of unique reference instances.
""" n_prediction_instance: int @@ -151,7 +179,15 @@ def __init__( n_prediction_instance: int | None = None, n_reference_instance: int | None = None, ) -> None: - # reduce to lowest uint + """Initializes a processing pair for instances. + + Args: + prediction_arr (np.ndarray): Array of predicted instance labels. + reference_arr (np.ndarray): Array of reference instance labels. + dtype (type | None): Expected data type of the arrays. + n_prediction_instance (int | None, optional): Pre-calculated number of prediction instances. + n_reference_instance (int | None, optional): Pre-calculated number of reference instances. + """ super().__init__(prediction_arr, reference_arr, dtype) if n_prediction_instance is None: self.n_prediction_instance = _count_unique_without_zeros(prediction_arr) @@ -178,20 +214,19 @@ def copy(self): def _check_array_integrity( prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None ): - """ - Check the integrity of two numpy arrays. + """Validates integrity between two arrays, checking shape, dtype, and consistency with `dtype`. - Parameters: - - prediction_arr (np.ndarray): The array to be checked. - - reference_arr (np.ndarray): The reference array for comparison. - - dtype (type | None): The expected data type for both arrays. Defaults to None. + Args: + prediction_arr (np.ndarray): The array to be validated. + reference_arr (np.ndarray): The reference array for comparison. + dtype (type | None): Expected type of the arrays. If None, dtype validation is skipped. Raises: - - AssertionError: If prediction_arr or reference_arr are not numpy arrays. - - AssertionError: If the shapes of prediction_arr and reference_arr do not match. - - AssertionError: If the data types of prediction_arr and reference_arr do not match. - - AssertionError: If dtype is provided and the data types of prediction_arr and/or reference_arr - do not match the specified dtype. 
+ AssertionError: If validation fails in any of the following cases: + - Arrays are not numpy arrays. + - Shapes of both arrays are not identical. + - Data types of both arrays do not match. + - Dtype mismatch when specified. Example: >>> _check_array_integrity(np.array([1, 2, 3]), np.array([4, 5, 6]), dtype=int) @@ -214,7 +249,10 @@ def _check_array_integrity( class SemanticPair(_ProcessingPair): - """A Processing pair that contains Semantic Labels""" + """Represents a semantic processing pair with integer-type arrays for label analysis. + + This class is tailored to scenarios where arrays contain semantic labels rather than instance IDs. + """ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> None: super().__init__(prediction_arr, reference_arr, dtype=int_type) @@ -243,9 +281,15 @@ def __init__( class MatchedInstancePair(_ProcessingPairInstanced): - """ - A Processing pair that contain Matched Instance Maps, i.e. each equal label in both maps are a match - Can be of any unsigned (but matching) integer type + """Represents a matched processing pair for instance maps, handling matched and unmatched labels. + + This class tracks both matched instances and any unmatched labels between prediction + and reference arrays. + + Attributes: + missed_reference_labels (list[int]): Reference labels with no matching prediction. + missed_prediction_labels (list[int]): Prediction labels with no matching reference. + matched_instances (list[int]): Labels matched between prediction and reference arrays. """ missed_reference_labels: list[int] @@ -319,6 +363,21 @@ def copy(self): @dataclass class EvaluateInstancePair: + """Represents an evaluation of instance segmentation results, comparing reference and prediction data. + + This class is used to store and evaluate metrics for instance segmentation, tracking the number of instances + and true positives (tp) alongside calculated metrics. 
+ + Attributes: + reference_arr (np.ndarray): Array containing reference instance labels. + prediction_arr (np.ndarray): Array containing predicted instance labels. + num_pred_instances (int): The number of unique instances in the prediction array. + num_ref_instances (int): The number of unique instances in the reference array. + tp (int): The number of true positive matches between predicted and reference instances. + list_metrics (dict[Metric, list[float]]): Dictionary of metric calculations, where each key is a `Metric` + object, and each value is a list of metric scores (floats). + """ + reference_arr: np.ndarray prediction_arr: np.ndarray num_pred_instances: int @@ -328,6 +387,27 @@ class EvaluateInstancePair: class InputType(_Enum_Compare): + """Defines the types of input processing pairs available for evaluation. + + This enumeration provides different processing classes for handling various instance segmentation scenarios, + allowing flexible instantiation of processing pairs based on the desired comparison type. + + Attributes: + SEMANTIC (SemanticPair): Processes semantic labels, intended for cases without instances. + UNMATCHED_INSTANCE (UnmatchedInstancePair): Processes instance maps without requiring label matches. + MATCHED_INSTANCE (MatchedInstancePair): Processes instance maps with label matching between prediction + and reference. + + Methods: + __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + Creates a processing pair based on the specified `InputType`, using the provided prediction + and reference arrays. 
+ + Example: + >>> input_type = InputType.MATCHED_INSTANCE + >>> processing_pair = input_type(prediction_arr, reference_arr) + """ + SEMANTIC = SemanticPair UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair @@ -339,6 +419,15 @@ def __call__( class IntermediateStepsData: + """Manages intermediate data steps for a processing pipeline, storing and retrieving processing states. + + This class enables step-by-step tracking of data transformations during processing. + + Attributes: + original_input (_ProcessingPair | None): The original input data before processing steps. + _intermediatesteps (dict[str, _ProcessingPair]): Dictionary of intermediate processing steps. + """ + def __init__(self, original_input: _ProcessingPair | None): self._original_input = original_input self._intermediatesteps: dict[str, _ProcessingPair] = {} diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 16308c4..a0a32b9 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -7,7 +7,24 @@ class SegmentationClassGroups(SupportsConfig): - # + """Represents a collection of segmentation class groups. + + This class manages groups of labels used in segmentation tasks, ensuring that each label is defined + exactly once across all groups. It supports both list and dictionary formats for group initialization. + + Attributes: + __group_dictionary (dict[str, LabelGroup]): A dictionary mapping group names to their respective LabelGroup instances. + __labels (list[int]): A flat list of unique labels collected from all LabelGroups. + + Args: + groups (list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]]): + A list of `LabelGroup` instances or a dictionary where keys are group names (str) and values are either + `LabelGroup` instances or tuples containing a list of label values and a boolean. + + Raises: + AssertionError: If the same label is assigned to multiple groups. 
+ """ + def __init__( self, groups: list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]], @@ -45,6 +62,18 @@ def __init__( def has_defined_labels_for( self, arr: np.ndarray | list[int], raise_error: bool = False ): + """Checks if the labels in the provided array are defined in the segmentation class groups. + + Args: + arr (np.ndarray | list[int]): The array of labels to check. + raise_error (bool): If True, raises an error when an undefined label is found. Defaults to False. + + Returns: + bool: True if all labels are defined; False otherwise. + + Raises: + AssertionError: If an undefined label is found and raise_error is True. + """ if isinstance(arr, list): arr_labels = arr else: @@ -90,6 +119,14 @@ def _yaml_repr(cls, node): def list_duplicates(seq): + """Identifies duplicates in a sequence. + + Args: + seq (list): The input sequence to check for duplicates. + + Returns: + list: A list of duplicates found in the input sequence. + """ seen = set() seen_add = seen.add # adds all elements it doesn't know yet to seen and all other to seen_twice @@ -99,6 +136,14 @@ def list_duplicates(seq): class _NoSegmentationClassGroups(SegmentationClassGroups): + """Represents a placeholder for segmentation class groups with no defined labels. + + This class indicates that no specific segmentation groups or labels are defined, and any label is valid. + + Attributes: + __group_dictionary (dict[str, LabelGroup]): A dictionary with a single entry representing all labels as a group. 
+ """ + def __init__(self) -> None: self.__group_dictionary = {NO_GROUP_KEY: _LabelGroupAny()} diff --git a/unit_tests/test_metrics.py b/unit_tests/test_metrics.py index f889c10..343cb87 100644 --- a/unit_tests/test_metrics.py +++ b/unit_tests/test_metrics.py @@ -152,6 +152,17 @@ def test_dsc_case_simple_identical_idx(self): ) self.assertEqual(dsc, 1.0) + def test_dsc_case_simple_identical_wrong_idx(self): + + pred_arr, ref_arr = case_simple_identical() + dsc = Metric.DSC( + reference_arr=ref_arr, + prediction_arr=pred_arr, + ref_instance_idx=2, + pred_instance_idx=2, + ) + self.assertEqual(dsc, 0.0) + def test_dsc_case_simple_nooverlap(self): pred_arr, ref_arr = case_simple_nooverlap()