From fb6eeb15d972c89b572d876f5cb11d82f86e4581 Mon Sep 17 00:00:00 2001 From: iback Date: Fri, 25 Oct 2024 12:55:14 +0000 Subject: [PATCH 1/2] added at least GPT docstrings for everything --- panoptica/_functionals.py | 53 +++--- panoptica/metrics/assd.py | 31 +++- panoptica/metrics/metrics.py | 128 ++++++++------ panoptica/panoptica_aggregator.py | 98 +++++++---- panoptica/panoptica_result.py | 125 ++++++++++---- panoptica/panoptica_statistics.py | 44 ++--- panoptica/utils/config.py | 137 +++++++++++++-- panoptica/utils/constants.py | 16 ++ panoptica/utils/edge_case_handling.py | 86 +++++++--- panoptica/utils/instancelabelmap.py | 103 ++++++++++-- panoptica/utils/label_group.py | 85 ++++++---- panoptica/utils/parallel_processing.py | 19 +++ panoptica/utils/processing_pair.py | 223 +++++++++++++++---------- panoptica/utils/segmentation_class.py | 69 ++++++-- unit_tests/test_metrics.py | 6 + 15 files changed, 871 insertions(+), 352 deletions(-) diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index 3b6a6b1..9b19bbb 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -33,11 +33,7 @@ def _calc_overlapping_labels( # instance_pairs = [(reference_arr, prediction_arr, i, j) for i, j in overlapping_indices] # (ref, pred) - return [ - (int(i % (max_ref)), int(i // (max_ref))) - for i in np.unique(overlap_arr) - if i > max_ref - ] + return [(int(i % (max_ref)), int(i // (max_ref))) for i in np.unique(overlap_arr) if i > max_ref] def _calc_matching_metric_of_overlapping_labels( @@ -67,13 +63,8 @@ def _calc_matching_metric_of_overlapping_labels( with Pool() as pool: mm_values = pool.starmap(matching_metric.value, instance_pairs) - mm_pairs = [ - (i, (instance_pairs[idx][2], instance_pairs[idx][3])) - for idx, i in enumerate(mm_values) - ] - mm_pairs = sorted( - mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing - ) + mm_pairs = [(i, (instance_pairs[idx][2], instance_pairs[idx][3])) for idx, i in enumerate(mm_values)] + mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing) return mm_pairs @@ -141,6 +132,23 @@ def _get_paired_crop( reference_arr: np.ndarray, px_pad: int = 2, ): + """Calculates a bounding box based on paired prediction and reference arrays. + + This function combines the prediction and reference arrays, checks if they are identical, + and computes a bounding box around the non-zero regions. If both arrays are completely zero, + a small value is added to ensure the bounding box is valid. + + Args: + prediction_arr (np.ndarray): The predicted segmentation array. + reference_arr (np.ndarray): The ground truth segmentation array. + px_pad (int, optional): Padding to apply around the bounding box. Defaults to 2. + + Returns: + np.ndarray: The bounding box coordinates around the combined non-zero regions. + + Raises: + AssertionError: If the prediction and reference arrays do not have the same shape. + """ assert prediction_arr.shape == reference_arr.shape combined = prediction_arr + reference_arr @@ -150,10 +158,17 @@ def _get_paired_crop( def _round_to_n(value: float | int, n_significant_digits: int = 2): - return ( - value - if value == 0 - else round( - value, -int(math.floor(math.log10(abs(value)))) + (n_significant_digits - 1) - ) - ) + """Rounds a number to a specified number of significant digits. + + This function rounds the given value to the specified number of significant digits. + If the value is zero, it is returned unchanged. + + Args: + value (float | int): The number to be rounded. 
+ n_significant_digits (int, optional): The number of significant digits to round to. + Defaults to 2. + + Returns: + float: The rounded value. + """ + return value if value == 0 else round(value, -int(math.floor(math.log10(abs(value)))) + (n_significant_digits - 1)) diff --git a/panoptica/metrics/assd.py b/panoptica/metrics/assd.py index b98b407..f33a809 100644 --- a/panoptica/metrics/assd.py +++ b/panoptica/metrics/assd.py @@ -85,12 +85,8 @@ def __surface_distances(reference, prediction, voxelspacing=None, connectivity=1 # raise RuntimeError("The second supplied array does not contain any binary object.") # extract only 1-pixel border line of objects - result_border = prediction ^ binary_erosion( - prediction, structure=footprint, iterations=1 - ) - reference_border = reference ^ binary_erosion( - reference, structure=footprint, iterations=1 - ) + result_border = prediction ^ binary_erosion(prediction, structure=footprint, iterations=1) + reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1) # compute average surface distance # Note: scipys distance transform is calculated only inside the borders of the @@ -107,6 +103,29 @@ def _distance_transform_edt( return_distances=True, return_indices=False, ): + """Computes the Euclidean distance transform and/or feature transform of a binary array. + + This function calculates the Euclidean distance transform (EDT) of a binary array, + which gives the distance from each non-zero point to the nearest zero point. It can + also return the feature transform, which provides indices to the nearest non-zero point. + + Args: + input_array (np.ndarray): The input binary array where non-zero values are considered + foreground. + sampling (optional): A sequence or array that specifies the spacing along each dimension. + If provided, scales the distances by the sampling value along each axis. + return_distances (bool, optional): If True, returns the distance transform. Default is True. + return_indices (bool, optional): If True, returns the feature transform with indices to + the nearest foreground points. Default is False. + + Returns: + np.ndarray or tuple[np.ndarray, ...]: If `return_distances` is True, returns the distance + transform as an array. If `return_indices` is True, returns the feature transform. If both + are True, returns a tuple with the distance and feature transforms. + + Raises: + ValueError: If the input array is empty or has unsupported dimensions. + """ # calculate the feature transform # input = np.atleast_1d(np.where(input, 1, 0).astype(np.int8)) # if sampling is not None: diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 5cbce45..2d00638 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -20,7 +20,25 @@ @dataclass class _Metric: - """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array""" + """Represents a metric with a name, direction (increasing or decreasing), and a calculation function. + + This class provides a framework for defining and calculating metrics, which can be used + to evaluate the similarity or performance between reference and prediction arrays. + The metric direction indicates whether higher or lower values are better. + + Attributes: + name (str): Short name of the metric. + long_name (str): Full descriptive name of the metric. 
+ decreasing (bool): If True, lower metric values are better; otherwise, higher values are preferred. + _metric_function (Callable): A callable function that computes the metric + between two input arrays. + + Example: + >>> my_metric = _Metric(name="accuracy", long_name="Accuracy", decreasing=False, _metric_function=accuracy_function) + >>> score = my_metric(reference_array, prediction_array) + >>> print(score) + + """ name: str long_name: str @@ -36,13 +54,25 @@ def __call__( *args, **kwargs, ) -> int | float: + """Calculates the metric between reference and prediction arrays. + + Args: + reference_arr (np.ndarray): The reference array. + prediction_arr (np.ndarray): The prediction array. + ref_instance_idx (int, optional): The instance index to filter in the reference array. + pred_instance_idx (int | list[int], optional): Instance index or indices to filter in + the prediction array. + *args: Additional positional arguments for the metric function. + **kwargs: Additional keyword arguments for the metric function. + + Returns: + int | float: The computed metric value. + """ if ref_instance_idx is not None and pred_instance_idx is not None: reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): pred_instance_idx = [pred_instance_idx] - prediction_arr = np.isin( - prediction_arr.copy(), pred_instance_idx - ) # type:ignore + prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) # type:ignore return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: @@ -60,18 +90,34 @@ def __repr__(self) -> str: return str(self) def __hash__(self) -> int: + """Hash based on metric name, constrained to fit within 8 digits. + + Returns: + int: The hash value of the metric. + """ return abs(hash(self.name)) % (10**8) @property def increasing(self): + """Indicates if higher values of the metric are better. + + Returns: + bool: True if increasing values are preferred, otherwise False. + """ return not self.decreasing - def score_beats_threshold( - self, matching_score: float, matching_threshold: float - ) -> bool: - return (self.increasing and matching_score >= matching_threshold) or ( - self.decreasing and matching_score <= matching_threshold - ) + def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: + """Determines if a matching score meets a specified threshold. + + Args: + matching_score (float): The score to evaluate. + matching_threshold (float): The threshold value to compare against. + + Returns: + bool: True if the score meets the threshold, taking into account the + metric's preferred direction. + """ + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) class DirectValueMeta(EnumMeta): @@ -138,9 +184,7 @@ def __call__( **kwargs, ) - def score_beats_threshold( - self, matching_score: float, matching_threshold: float - ) -> bool: + def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: """Calculates whether a score beats a specified threshold Args: @@ -150,9 +194,7 @@ def score_beats_threshold( Returns: bool: True if the matching_score beats the threshold, False otherwise. 
""" - return (self.increasing and matching_score >= matching_threshold) or ( - self.decreasing and matching_score <= matching_threshold - ) + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) @property def name(self): @@ -206,6 +248,16 @@ def __init__(self, *args: object) -> None: class Evaluation_Metric: + """This represents a metric in the evaluation derived from other metrics or list metrics (no circular dependancies!) + + Args: + name_id (str): code-name of this metric, must be same as the member variable of PanopticResult + calc_func (Callable): the function to calculate this metric based on the PanopticResult object + long_name (str | None, optional): A longer descriptive name for printing/logging purposes. Defaults to None. + was_calculated (bool, optional): Whether this metric has been calculated or not. Defaults to False. + error (bool, optional): If true, means the metric could not have been calculated (because dependancies do not exist or have this flag set to True). Defaults to False. + """ + def __init__( self, name_id: str, @@ -215,15 +267,7 @@ def __init__( was_calculated: bool = False, error: bool = False, ): - """This represents a metric in the evaluation derived from other metrics or list metrics (no circular dependancies!) - Args: - name_id (str): code-name of this metric, must be same as the member variable of PanopticResult - calc_func (Callable): the function to calculate this metric based on the PanopticResult object - long_name (str | None, optional): A longer descriptive name for printing/logging purposes. Defaults to None. - was_calculated (bool, optional): Whether this metric has been calculated or not. Defaults to False. - error (bool, optional): If true, means the metric could not have been calculated (because dependancies do not exist or have this flag set to True). Defaults to False. - """ self.id = name_id self.metric_type = metric_type self._calc_func = calc_func @@ -249,9 +293,7 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # ERROR if self._error: if self._error_obj is None: - self._error_obj = MetricCouldNotBeComputedException( - f"Metric {self.id} requested, but could not be computed" - ) + self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed") raise self._error_obj # Already calculated? 
if self._was_calculated: @@ -259,12 +301,8 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # Calculate it try: - assert ( - not self._was_calculated - ), f"Metric {self.id} was called to compute, but is set to have been already calculated" - assert ( - self._calc_func is not None - ), f"Metric {self.id} was called to compute, but has no calculation function set" + assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated" + assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set" value = self._calc_func(result_obj) except MetricCouldNotBeComputedException as e: value = e @@ -309,32 +347,20 @@ def __init__( else: self.AVG = None if self.ALL is None else np.average(self.ALL) self.SUM = None if self.ALL is None else np.sum(self.ALL) - self.MIN = ( - None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) - ) - self.MAX = ( - None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) - ) - - self.STD = ( - None - if self.ALL is None - else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) - ) + self.MIN = None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) + self.MAX = None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) + + self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) def __getitem__(self, mode: MetricMode | str): if self.error: - raise MetricCouldNotBeComputedException( - f"Metric {self.id} has not been calculated, add it to your eval_metrics" - ) + raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics") if isinstance(mode, MetricMode): mode = mode.name if hasattr(self, mode): return getattr(self, mode) else: - raise MetricCouldNotBeComputedException( - f"List_Metric {self.id} does not contain {mode} member" - ) + raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member") if __name__ == "__main__": diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py index 78e7463..e0d761d 100644 --- a/panoptica/panoptica_aggregator.py +++ b/panoptica/panoptica_aggregator.py @@ -30,9 +30,11 @@ # class Panoptica_Aggregator: - # internal_list_lock = Lock() - # - """Aggregator that calls evaluations and saves the resulting metrics per sample. Can be used to create statistics, ...""" + """Aggregator that manages evaluations and saves resulting metrics per sample. + + This class interfaces with the `Panoptica_Evaluator` to perform evaluations, + store results, and manage file outputs for statistical analysis. + """ def __init__( self, @@ -41,10 +43,18 @@ def __init__( log_times: bool = False, continue_file: bool = True, ): - """ + """Initializes the Panoptica_Aggregator. + Args: - panoptica_evaluator (Panoptica_Evaluator): The Panoptica_Evaluator used for the pipeline. - output_file (Path | None, optional): If given, will stream the sample results into this file. If the file is existent, will append results if not already there. Defaults to None. + panoptica_evaluator (Panoptica_Evaluator): The evaluator used for performing evaluations. + output_file (Path | str): Path to the output file for storing results. If the file exists, + results will be appended. If it doesn't, a new file will be created. + log_times (bool, optional): If True, computation times will be logged. Defaults to False. 
+ continue_file (bool, optional): If True, results will continue from existing entries in the file. + Defaults to True. + + Raises: + AssertionError: If the output directory does not exist or if the file extension is not `.tsv`. """ self.__panoptica_evaluator = panoptica_evaluator self.__class_group_names = panoptica_evaluator.segmentation_class_groups_names @@ -56,9 +66,7 @@ def __init__( if isinstance(output_file, str): output_file = Path(output_file) # uses tsv - assert ( - output_file.parent.exists() - ), f"Directory {str(output_file.parent)} does not exist" + assert output_file.parent.exists(), f"Directory {str(output_file.parent)} does not exist" out_file_path = str(output_file) @@ -72,19 +80,13 @@ def __init__( else: out_file_path += ".tsv" # add extension - out_buffer_file: Path = Path(out_file_path).parent.joinpath( - "panoptica_aggregator_tmp.tsv" - ) + out_buffer_file: Path = Path(out_file_path).parent.joinpath("panoptica_aggregator_tmp.tsv") self.__output_buffer_file = out_buffer_file Path(out_file_path).parent.mkdir(parents=True, exist_ok=True) self.__output_file = out_file_path - header = ["subject_name"] + [ - f"{g}-{m}" - for g in self.__class_group_names - for m in self.__evaluation_metrics - ] + header = ["subject_name"] + [f"{g}-{m}" for g in self.__class_group_names for m in self.__evaluation_metrics] header_hash = hash("+".join(header)) if not output_file.exists(): @@ -98,9 +100,7 @@ def __init__( continue_file = True else: # TODO should also hash panoptica_evaluator just to make sure! and then save into header of file - assert header_hash == hash( - "+".join(header_list) - ), "Hash of header not the same! You are using a different setup!" + assert header_hash == hash("+".join(header_list)), "Hash of header not the same! You are using a different setup!" if out_buffer_file.exists(): os.remove(out_buffer_file) @@ -115,10 +115,16 @@ def __init__( atexit.register(self.__exist_handler) def __exist_handler(self): + """Handles cleanup upon program exit by removing the temporary output buffer file.""" if self.__output_buffer_file is not None and self.__output_buffer_file.exists(): os.remove(str(self.__output_buffer_file)) def make_statistic(self) -> Panoptica_Statistic: + """Generates statistics from the aggregated evaluation results. + + Returns: + Panoptica_Statistic: The statistics object containing the results. + """ with filelock: obj = Panoptica_Statistic.from_file(self.__output_file) return obj @@ -129,14 +135,16 @@ def evaluate( reference_arr: np.ndarray, subject_name: str, ): - """Evaluates one case + """Evaluates a single case using the provided prediction and reference arrays. Args: - prediction_arr (np.ndarray): Prediction array - reference_arr (np.ndarray): reference array - subject_name (str | None, optional): Unique name of the sample. If none, will give it a name based on count. Defaults to None. - skip_already_existent (bool): If true, will skip subjects which were already evaluated instead of crashing. Defaults to False. - verbose (bool | None, optional): Verbose. Defaults to None. + prediction_arr (np.ndarray): The array containing the predicted segmentation. + reference_arr (np.ndarray): The array containing the ground truth segmentation. + subject_name (str): A unique name for the sample being evaluated. If none is provided, + a name will be generated based on the count. + + Raises: + ValueError: If the subject name has already been evaluated or is in process. 
""" # Read tmp file to see which sample names are blocked with inevalfilelock: @@ -164,6 +172,12 @@ def evaluate( self._save_one_subject(subject_name, res) def _save_one_subject(self, subject_name, result_grouped): + """Saves the evaluation results for a single subject. + + Args: + subject_name (str): The name of the subject whose results are being saved. + result_grouped (dict): A dictionary of grouped results from the evaluation. + """ with filelock: # content = [subject_name] @@ -186,9 +200,19 @@ def panoptica_evaluator(self): def _read_first_row(file: str | Path): + """Reads the first row of a TSV file. + + NOT THREAD SAFE BY ITSELF! + + Args: + file (str | Path): The path to the file from which to read the first row. + + Returns: + list: The first row of the file as a list of strings. + """ if isinstance(file, Path): file = str(file) - # NOT THREAD SAFE BY ITSELF! + # with open(str(file), "r", encoding="utf8", newline="") as tsvfile: rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n") @@ -202,7 +226,19 @@ def _read_first_row(file: str | Path): def _load_first_column_entries(file: str | Path): - # NOT THREAD SAFE BY ITSELF! + """Loads the entries from the first column of a TSV file. + + NOT THREAD SAFE BY ITSELF! + + Args: + file (str | Path): The path to the file from which to load entries. + + Returns: + list: A list of entries from the first column of the file. + + Raises: + AssertionError: If the file contains duplicate entries. + """ if isinstance(file, Path): file = str(file) with open(str(file), "r", encoding="utf8", newline="") as tsvfile: @@ -221,6 +257,12 @@ def _load_first_column_entries(file: str | Path): def _write_content(file: str | Path, content: list[list[str]]): + """Writes content to a TSV file. + + Args: + file (str | Path): The path to the file where content will be written. + content (list[list[str]]): A list of lists containing the rows of data to write. + """ if isinstance(file, Path): file = str(file) # NOT THREAD SAFE BY ITSELF! diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index 0db17c2..12341ec 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -255,25 +255,19 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[m] = Evaluation_List_Metric( - m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result - ) + self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result) # even if not available, set the global vars default_value = None was_calculated = False if m in self._global_metrics and arrays_present: - default_value = self._calc_global_bin_metric( - m, pred_binary, ref_binary, do_binarize=False - ) + default_value = self._calc_global_bin_metric(m, pred_binary, ref_binary, do_binarize=False) was_calculated = True self._add_metric( f"global_bin_{m.name.lower()}", MetricType.GLOBAL, - lambda x: MetricCouldNotBeComputedException( - f"Global Metric {m} not set" - ), + lambda x: MetricCouldNotBeComputedException(f"Global Metric {m} not set"), long_name="Global Binary " + m.value.long_name, default_value=default_value, was_calculated=was_calculated, @@ -286,6 +280,21 @@ def _calc_global_bin_metric( reference_arr, do_binarize: bool = True, ): + """ + Calculates a global binary metric based on predictions and references. + + Args: + metric (Metric): The metric to compute. + prediction_arr: The predicted values. + reference_arr: The ground truth values. 
+ do_binarize (bool): Whether to binarize the input arrays. Defaults to True. + + Returns: + The calculated metric value. + + Raises: + MetricCouldNotBeComputedException: If the specified metric is not set. + """ if metric not in self._global_metrics: raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") @@ -301,9 +310,7 @@ def _calc_global_bin_metric( prediction_empty = pred_binary.sum() == 0 reference_empty = ref_binary.sum() == 0 if prediction_empty or reference_empty: - is_edgecase, result = self._edge_case_handler.handle_zero_tp( - metric, 0, int(prediction_empty), int(reference_empty) - ) + is_edgecase, result = self._edge_case_handler.handle_zero_tp(metric, 0, int(prediction_empty), int(reference_empty)) if is_edgecase: return result @@ -321,6 +328,20 @@ def _add_metric( default_value=None, was_calculated: bool = False, ): + """ + Adds a new metric to the evaluation metrics. + + Args: + name_id (str): The unique identifier for the metric. + metric_type (MetricType): The type of the metric. + calc_func (Callable | None): The function to calculate the metric. + long_name (str | None): A longer, descriptive name for the metric. + default_value: The default value for the metric. + was_calculated (bool): Indicates if the metric has been calculated. + + Returns: + The default value of the metric. + """ setattr(self, name_id, default_value) # assert hasattr(self, name_id), f"added metric {name_id} but it is not a member variable of this class" if calc_func is None: @@ -338,7 +359,8 @@ def _add_metric( return default_value def calculate_all(self, print_errors: bool = False): - """Calculates all possible metrics that can be derived + """ + Calculates all possible metrics that can be derived. Args: print_errors (bool, optional): If true, will print every metric that could not be computed and its reason. Defaults to False. @@ -356,6 +378,16 @@ def calculate_all(self, print_errors: bool = False): print(f"Metric {k}: {v}") def _calc(self, k, v): + """ + Attempts to get the value of a metric and captures any exceptions. + + Args: + k: The metric key. + v: The metric value. + + Returns: + A tuple indicating success or failure and the corresponding value or exception. + """ try: v = getattr(self, k) return False, v @@ -389,25 +421,51 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: - return { - k: getattr(self, v.id) - for k, v in self._evaluation_metrics.items() - if (v._error == False and v._was_calculated) - } + """ + Converts the metrics to a dictionary format. + + Returns: + A dictionary containing metric names and their values. + """ + return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} @property def evaluation_metrics(self): return self._evaluation_metrics def get_list_metric(self, metric: Metric, mode: MetricMode): + """ + Retrieves a list of metrics based on the given metric type and mode. + + Args: + metric (Metric): The metric to retrieve. + mode (MetricMode): The mode of the metric. + + Returns: + The corresponding list of metrics. + + Raises: + MetricCouldNotBeComputedException: If the metric cannot be found. + """ if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException( - f"{metric} could not be found, have you set it in eval_metrics during evaluation?" 
- ) + raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") def _calc_metric(self, metric_name: str, supress_error: bool = False): + """ + Calculates a specific metric by its name. + + Args: + metric_name (str): The name of the metric to calculate. + supress_error (bool): If true, suppresses errors during calculation. + + Returns: + The calculated metric value or raises an exception if it cannot be computed. + + Raises: + MetricCouldNotBeComputedException: If the metric cannot be found. + """ if metric_name in self._evaluation_metrics: try: value = self._evaluation_metrics[metric_name](self) @@ -421,11 +479,21 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException( - f"could not find metric with name {metric_name}" - ) + raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") def __getattribute__(self, __name: str) -> Any: + """ + Retrieves an attribute, with special handling for evaluation metrics. + + Args: + __name (str): The name of the attribute to retrieve. + + Returns: + The attribute value. + + Raises: + MetricCouldNotBeComputedException: If the requested metric could not be computed. + """ attr = None try: attr = object.__getattribute__(self, __name) @@ -438,15 +506,10 @@ def __getattribute__(self, __name: str) -> Any: raise e if __name == "_evaluation_metrics": return attr - if ( - object.__getattribute__(self, "_evaluation_metrics") is not None - and __name in self._evaluation_metrics.keys() - ): + if object.__getattribute__(self, "_evaluation_metrics") is not None and __name in self._evaluation_metrics.keys(): if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException( - f"Requested metric {__name} that could not be computed" - ) + raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index c651fb1..d3f4303 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -43,9 +43,7 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert ( - header[0] == "subject_name" - ), "First column is not subject_names, something wrong with the file?" + assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" 
keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -82,19 +80,13 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert ( - group in self.__groupnames - ), f"group {group} not existent, only got groups {self.__groupnames}" + assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert ( - metric in self.__metricnames - ), f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert ( - subjectname in self.__subj_names - ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric) -> list[float]: """Returns the list of values for given group and metric @@ -125,10 +117,7 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return { - g: {m: self.get(g, m)[sidx] for m in self.__metricnames} - for g in self.__groupnames - } + return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} def get_across_groups(self, metric): """Given metric, gives list of all values (even across groups!) Treat with care! @@ -145,10 +134,7 @@ def get_across_groups(self, metric): return values def get_summary_dict(self): - return { - g: {m: self.get_summary(g, m) for m in self.__metricnames} - for g in self.__groupnames - } + return {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} def get_summary(self, group, metric): # TODO maybe more here? range, stuff like that @@ -174,6 +160,7 @@ def get_summary_figure( self, metric: str, horizontal: bool = True, + sort: bool = True, # title overwrite? 
): """Returns a figure object that shows the given metric for each group and its std @@ -191,6 +178,7 @@ def get_summary_figure( data=data_plot, orientation=orientation, score=metric, + sort=sort, ) # groupwise or in total @@ -218,9 +206,7 @@ def make_curve_over_setups( alternate_groupnames = [alternate_groupnames] # for setupname, stat in statistics_dict.items(): - assert ( - metric in stat.metricnames - ), f"metric {metric} not in statistic obj {setupname}" + assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -296,14 +282,10 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(graph_name).mean() df_by_spec_count = dict(df_by_spec_count[score].items()) - df_data["mean"] = df_data[graph_name].apply( - lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) - ) + df_data["mean"] = df_data[graph_name].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) df_data = df_data.sort_values(by="mean") if orientation == "v": - fig = px.strip( - df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation - ) + fig = px.strip(df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( @@ -314,9 +296,7 @@ def plot_box( ) ) else: - fig = px.strip( - df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation - ) + fig = px.strip(df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index a3d5157..30d332c 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -7,10 +7,24 @@ def _register_helper_classes(yaml: YAML): + """Registers globally supported helper classes to a YAML instance. + + Args: + yaml (YAML): The YAML instance to register helper classes to. + """ [yaml.register_class(s) for s in supported_helper_classes] def _load_yaml(file: str | Path, registered_class=None): + """Loads a YAML file into a Python dictionary or object, with optional class registration. + + Args: + file (str | Path): Path to the YAML file. + registered_class (optional): Optional class to register with the YAML parser. + + Returns: + dict | object: Parsed content from the YAML file. + """ if isinstance(file, str): file = Path(file) yaml = YAML(typ="safe") @@ -24,6 +38,13 @@ def _load_yaml(file: str | Path, registered_class=None): def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class=None): + """Saves a Python dictionary or object to a YAML file, with optional class registration. + + Args: + data_dict (dict | object): Data to save. + out_file (str | Path): Output file path. + registered_class (optional): Class type to register with YAML if saving an object. + """ if isinstance(out_file, str): out_file = Path(out_file) @@ -47,13 +68,26 @@ def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class= # Universal Functions ######### def _register_class_to_yaml(cls): + """Registers a class to the global supported helper classes for YAML serialization. + + Args: + cls: The class to register. 
+ """ global supported_helper_classes if cls not in supported_helper_classes: supported_helper_classes.append(cls) def _load_from_config(cls, path: str | Path): - # cls._register_permanently() + """Loads an instance of a class from a YAML configuration file. + + Args: + cls: The class type to instantiate. + path (str | Path): Path to the YAML configuration file. + + Returns: + An instance of the specified class, loaded from configuration. + """ if isinstance(path, str): path = Path(path) assert path.exists(), f"load_from_config: {path} does not exist" @@ -63,72 +97,153 @@ def _load_from_config(cls, path: str | Path): def _load_from_config_name(cls, name: str): + """Loads an instance of a class from a configuration file identified by name. + + Args: + cls: The class type to instantiate. + name (str): The name used to find the configuration file. + + Returns: + An instance of the specified class. + """ path = config_by_name(name) assert path.exists(), f"load_from_config: {path} does not exist" return _load_from_config(cls, path) def _save_to_config(obj, path: str | Path): + """Saves an instance of a class to a YAML configuration file. + + Args: + obj: The object to save. + path (str | Path): The file path to save the configuration. + """ if isinstance(path, str): path = Path(path) _save_yaml(obj, path, registered_class=type(obj)) def _save_to_config_by_name(obj, name: str): + """Saves an instance of a class to a configuration file by name. + + Args: + obj: The object to save. + name (str): The name used to determine the configuration file path. + """ dir, name = config_dir_by_name(name) _save_to_config(obj, dir.joinpath(name)) class SupportsConfig: - """Metaclass that allows a class to save and load objects by yaml configs""" + """Base class that provides methods for loading and saving instances as YAML configurations. + + This class should be inherited by classes that wish to have load and save functionality for YAML + configurations, with class registration to enable custom serialization and deserialization. + + Methods: + load_from_config(cls, path): Loads a class instance from a YAML file. + load_from_config_name(cls, name): Loads a class instance from a configuration file identified by name. + save_to_config(path): Saves the instance to a YAML file. + save_to_config_by_name(name): Saves the instance to a configuration file identified by name. + to_yaml(cls, representer, node): YAML serialization method (requires _yaml_repr). + from_yaml(cls, constructor, node): YAML deserialization method. + """ def __init__(self) -> None: + """Prevents instantiation of SupportsConfig as it is intended to be a metaclass.""" raise NotImplementedError(f"Tried to instantiate abstract class {type(self)}") def __init_subclass__(cls, **kwargs): - # Registers all subclasses of this + """Registers subclasses of SupportsConfig to enable YAML support.""" super().__init_subclass__(**kwargs) cls._register_permanently() @classmethod def _register_permanently(cls): + """Registers the class to globally supported helper classes.""" _register_class_to_yaml(cls) @classmethod def load_from_config(cls, path: str | Path): + """Loads an instance of the class from a YAML file. + + Args: + path (str | Path): The file path to load the configuration. + + Returns: + An instance of the class. 
+ """ obj = _load_from_config(cls, path) - assert isinstance( - obj, cls - ), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" + assert isinstance(obj, cls), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" return obj @classmethod def load_from_config_name(cls, name: str): + """Loads an instance of the class from a configuration file identified by name. + + Args: + name (str): The name used to find the configuration file. + + Returns: + An instance of the class. + """ obj = _load_from_config_name(cls, name) assert isinstance(obj, cls) return obj def save_to_config(self, path: str | Path): + """Saves the instance to a YAML configuration file. + + Args: + path (str | Path): The file path to save the configuration. + """ _save_to_config(self, path) def save_to_config_by_name(self, name: str): + """Saves the instance to a configuration file identified by name. + + Args: + name (str): The name used to determine the configuration file path. + """ _save_to_config_by_name(self, name) @classmethod def to_yaml(cls, representer, node): - # cls._register_permanently() - assert hasattr( - cls, "_yaml_repr" - ), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" + """Serializes the class to YAML format. + + Args: + representer: YAML representer instance. + node: The object instance to serialize. + + Returns: + YAML node: YAML-compatible node representation of the object. + """ + assert hasattr(cls, "_yaml_repr"), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node)) @classmethod def from_yaml(cls, constructor, node): - # cls._register_permanently() + """Deserializes a YAML node to an instance of the class. + + Args: + constructor: YAML constructor instance. + node: YAML node to deserialize. + + Returns: + An instance of the class with attributes populated from YAML data. + """ data = constructor.construct_mapping(node, deep=True) return cls(**data) @classmethod @abstractmethod def _yaml_repr(cls, node) -> dict: + """Abstract method for representing the class in YAML. + + Args: + node: The object instance to represent in YAML. + + Returns: + dict: A dictionary representation of the class. + """ pass # return {"groups": node.__group_dictionary} diff --git a/panoptica/utils/constants.py b/panoptica/utils/constants.py index ecb7f63..dceb70d 100644 --- a/panoptica/utils/constants.py +++ b/panoptica/utils/constants.py @@ -10,6 +10,22 @@ class _Enum_Compare(Enum): + """An extended Enum class that supports additional comparison and YAML configuration functionality. + + This class enhances standard `Enum` capabilities, allowing comparisons with other enums or strings by + name and adding support for YAML serialization and deserialization methods. + + Methods: + __eq__(__value): Checks equality with another Enum or string. + __str__(): Returns a string representation of the Enum instance. + __repr__(): Returns a string representation for debugging. + load_from_config(cls, path): Loads an Enum instance from a configuration file. + load_from_config_name(cls, name): Loads an Enum instance from a configuration file identified by name. + save_to_config(path): Saves the Enum instance to a configuration file. + to_yaml(cls, representer, node): Serializes the Enum to YAML. + from_yaml(cls, constructor, node): Deserializes YAML data into an Enum instance. 
+ """ + def __eq__(self, __value: object) -> bool: if isinstance(__value, Enum): namecheck = self.name == __value.name diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index e7dd5d1..c9c0b0f 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -5,6 +5,23 @@ class EdgeCaseResult(_Enum_Compare): + """Enumeration of edge case values used for handling specific metric situations. + + This enum defines several common edge case values for handling zero-true-positive (zero-TP) + situations in various metrics. The values include infinity, NaN, zero, one, and None. + + Attributes: + INF: Represents infinity (`np.inf`). + NAN: Represents not-a-number (`np.nan`). + ZERO: Represents zero (0.0). + ONE: Represents one (1.0). + NONE: Represents a None value. + + Methods: + value: Returns the value associated with the edge case. + __call__(): Returns the numeric or None representation of the enum member. + """ + INF = auto() # np.inf NAN = auto() # np.nan ZERO = auto() # 0.0 @@ -29,6 +46,15 @@ def __call__(self): class EdgeCaseZeroTP(_Enum_Compare): + """Enum defining scenarios that could produce zero true positives (zero-TP) in metrics. + + Attributes: + NO_INSTANCES: No instances in both the prediction and reference. + EMPTY_PRED: The prediction is empty. + EMPTY_REF: The reference is empty. + NORMAL: A typical scenario with non-zero instances. + """ + NO_INSTANCES = auto() EMPTY_PRED = auto() EMPTY_REF = auto() @@ -39,6 +65,21 @@ def __hash__(self) -> int: class MetricZeroTPEdgeCaseHandling(SupportsConfig): + """Handles zero-TP edge cases for metrics, mapping different zero-TP scenarios to specific results. + + Attributes: + default_result (EdgeCaseResult | None): Default result if specific edge cases are not provided. + no_instances_result (EdgeCaseResult | None): Result when no instances are present. + empty_prediction_result (EdgeCaseResult | None): Result when prediction is empty. + empty_reference_result (EdgeCaseResult | None): Result when reference is empty. + normal (EdgeCaseResult | None): Result when a normal zero-TP scenario occurs. + + Methods: + __call__(tp, num_pred_instances, num_ref_instances): Determines if an edge case is detected and returns its result. + __eq__(value): Compares this handling object to another. + __str__(): String representation of edge cases. + _yaml_repr(cls, node): YAML representation for the edge case. 
+ """ def __init__( self, @@ -57,26 +98,12 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( - empty_prediction_result - if empty_prediction_result is not None - else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( - empty_reference_result - if empty_reference_result is not None - else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( - no_instances_result if no_instances_result is not None else default_result - ) - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( - normal if normal is not None else default_result - ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result - def __call__( - self, tp: int, num_pred_instances, num_ref_instances - ) -> tuple[bool, float | None]: + def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -119,6 +146,19 @@ def _yaml_repr(cls, node) -> dict: class EdgeCaseHandler(SupportsConfig): + """Manages edge cases across multiple metrics, including standard deviation handling for empty lists. + + Attributes: + listmetric_zeroTP_handling (dict): Dictionary mapping metrics to their zero-TP edge case handling. + empty_list_std (EdgeCaseResult): Default edge case for handling standard deviation of empty lists. + + Methods: + handle_zero_tp(metric, tp, num_pred_instances, num_ref_instances): Checks if an edge case exists and returns its result. + listmetric_zeroTP_handling: Returns the edge case handling dictionary. + get_metric_zero_tp_handle(metric): Returns the zero-TP handler for a specific metric. + handle_empty_list_std(): Handles standard deviation of empty lists. + _yaml_repr(cls, node): YAML representation of the handler. + """ def __init__( self, @@ -147,9 +187,7 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[ - Metric, MetricZeroTPEdgeCaseHandling - ] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -176,9 +214,7 @@ def handle_zero_tp( if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError( - f"Metric {metric} encountered zero TP, but no edge handling available" - ) + raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") return self.__listmetric_zeroTP_handling[metric]( tp=tp, diff --git a/panoptica/utils/instancelabelmap.py b/panoptica/utils/instancelabelmap.py index 16fd33c..a0caa45 100644 --- a/panoptica/utils/instancelabelmap.py +++ b/panoptica/utils/instancelabelmap.py @@ -3,19 +3,46 @@ # Many-to-One Mapping class InstanceLabelMap(object): - # Mapping ((prediction_label, ...), (reference_label, ...)) + """Creates a mapping between prediction labels and reference labels in a many-to-one relationship. 
+ + This class allows mapping multiple prediction labels to a single reference label. + It includes methods for adding new mappings, checking containment, retrieving + predictions mapped to a reference, and exporting the mapping as a dictionary. + + Attributes: + labelmap (dict[int, int]): Dictionary storing the prediction-to-reference label mappings. + + Methods: + add_labelmap_entry(pred_labels, ref_label): Adds a new entry mapping prediction labels to a reference label. + get_pred_labels_matched_to_ref(ref_label): Retrieves prediction labels mapped to a given reference label. + contains_pred(pred_label): Checks if a prediction label exists in the map. + contains_ref(ref_label): Checks if a reference label exists in the map. + contains_and(pred_label, ref_label): Checks if both a prediction and a reference label are in the map. + contains_or(pred_label, ref_label): Checks if either a prediction or reference label is in the map. + get_one_to_one_dictionary(): Returns the labelmap dictionary for a one-to-one view. + """ + labelmap: dict[int, int] def __init__(self) -> None: self.labelmap = {} def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): + """Adds an entry that maps prediction labels to a single reference label. + + Args: + pred_labels (list[int] | int): List of prediction labels or a single prediction label. + ref_label (int): Reference label to map to. + + Raises: + AssertionError: If `ref_label` is not an integer. + AssertionError: If any `pred_labels` are not integers. + Exception: If a prediction label is already mapped to a different reference label. + """ if not isinstance(pred_labels, list): pred_labels = [pred_labels] assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" - assert np.all( - [isinstance(r, int) for r in pred_labels] - ), "add_labelmap_entry: got no int as pred_label" + assert np.all([isinstance(r, int) for r in pred_labels]), "add_labelmap_entry: got no int as pred_label" for p in pred_labels: if p in self.labelmap and self.labelmap[p] != ref_label: raise Exception( @@ -24,38 +51,79 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): self.labelmap[p] = ref_label def get_pred_labels_matched_to_ref(self, ref_label: int): + """Retrieves all prediction labels that map to a specified reference label. + + Args: + ref_label (int): The reference label to search. + + Returns: + list[int]: List of prediction labels mapped to `ref_label`. + """ return [k for k, v in self.labelmap.items() if v == ref_label] def contains_pred(self, pred_label: int): + """Checks if a prediction label exists in the map. + + Args: + pred_label (int): The prediction label to search. + + Returns: + bool: True if `pred_label` is in `labelmap`, otherwise False. + """ return pred_label in self.labelmap def contains_ref(self, ref_label: int): + """Checks if a reference label exists in the map. + + Args: + ref_label (int): The reference label to search. + + Returns: + bool: True if `ref_label` is in `labelmap` values, otherwise False. + """ return ref_label in self.labelmap.values() - def contains_and( - self, pred_label: int | None = None, ref_label: int | None = None - ) -> bool: + def contains_and(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + """Checks if both a prediction and a reference label are in the map. + + Args: + pred_label (int | None): The prediction label to check. + ref_label (int | None): The reference label to check. 
+ + Returns: + bool: True if both `pred_label` and `ref_label` are in the map; otherwise, False. + """ pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in and ref_in - def contains_or( - self, pred_label: int | None = None, ref_label: int | None = None - ) -> bool: + def contains_or(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + """Checks if either a prediction or reference label is in the map. + + Args: + pred_label (int | None): The prediction label to check. + ref_label (int | None): The reference label to check. + + Returns: + bool: True if either `pred_label` or `ref_label` are in the map; otherwise, False. + """ pred_in = True if pred_label is None else pred_label in self.labelmap ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in or ref_in def get_one_to_one_dictionary(self): + """Returns a copy of the labelmap dictionary for a one-to-one view. + + Returns: + dict[int, int]: The prediction-to-reference label mapping. + """ return self.labelmap def __str__(self) -> str: return str( list( [ - str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) - + " -> " - + str(v) + str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + " -> " + str(v) for v in set(self.labelmap.values()) ] ) @@ -66,6 +134,15 @@ def __repr__(self) -> str: # Make all variables read-only! def __setattr__(self, attr, value): + """Overrides attribute setting to make attributes read-only after initialization. + + Args: + attr (str): Attribute name. + value (Any): Attribute value. + + Raises: + Exception: If trying to alter an existing attribute. + """ if hasattr(self, attr): raise Exception("Attempting to alter read-only value") diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py index 34e0a08..8f1ac72 100644 --- a/panoptica/utils/label_group.py +++ b/panoptica/utils/label_group.py @@ -5,45 +5,53 @@ class LabelGroup(SupportsConfig): - """Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other""" + """Defines a group of labels that semantically belong together for segmentation purposes. + + Groups of labels define label sets that can be matched with each other. + For example, labels might represent different parts of a segmented object, and only those within the group are eligible for matching. + + Attributes: + value_labels (list[int]): List of integer labels representing segmentation group labels. + single_instance (bool): If True, the group represents a single instance without matching threshold consideration. + """ def __init__( self, value_labels: list[int] | int, single_instance: bool = False, ) -> None: - """Defines a group of labels that semantically belong to each other + """Initializes a LabelGroup with specified labels and single instance setting. Args: - value_labels (list[int]): Actually labels in the prediction and reference mask in this group. Defines the labels that can be matched to each other - single_instance (bool, optional): If true, will not use the matching_threshold as there is only one instance (large organ, ...). Defaults to False. + value_labels (list[int] | int): Labels in the prediction and reference mask for this group. + single_instance (bool, optional): If True, ignores matching threshold as only one instance exists. Defaults to False. 
+ + Raises: + AssertionError: If `value_labels` is empty or if labels are not positive integers. + AssertionError: If `single_instance` is True but more than one label is provided. """ if isinstance(value_labels, int): value_labels = [value_labels] value_labels = list(set(value_labels)) - assert ( - len(value_labels) >= 1 - ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" + assert len(value_labels) >= 1, f"You tried to define a LabelGroup without any specified labels, got {value_labels}" self.__value_labels = value_labels - assert np.all( - [v > 0 for v in self.__value_labels] - ), f"Given value labels are not >0, got {value_labels}" + assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}" self.__single_instance = single_instance if self.__single_instance: - assert ( - len(value_labels) == 1 - ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" + assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" LabelGroup._register_permanently() @property def value_labels(self) -> list[int]: + """List of integer labels for this segmentation group.""" return self.__value_labels @property def single_instance(self) -> bool: + """Indicates if this group is treated as a single instance.""" return self.__single_instance def extract_label( @@ -51,14 +59,14 @@ def extract_label( array: np.ndarray, set_to_binary: bool = False, ): - """Extracts the labels of this class + """Extracts an array of the labels specific to this segmentation group. Args: - array (np.ndarray): Array to extract the segmentation group labels from - set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. + array (np.ndarray): The array to filter for segmentation group labels. + set_to_binary (bool, optional): If True, outputs a binary array. Defaults to False. Returns: - np.ndarray: Array containing only the labels of this segmentation group + np.ndarray: An array with only the labels of this segmentation group. """ array = array.copy() array[np.isin(array, self.value_labels, invert=True)] = 0 @@ -70,6 +78,14 @@ def __call__( self, array: np.ndarray, ) -> np.ndarray: + """Extracts labels from an array for this group when the instance is called. + + Args: + array (np.ndarray): Array to filter for segmentation group labels. + + Returns: + np.ndarray: Array containing only the labels for this segmentation group. + """ return self.extract_label(array, set_to_binary=False) def __str__(self) -> str: @@ -87,23 +103,28 @@ def _yaml_repr(cls, node): class LabelMergeGroup(LabelGroup): - def __init__( - self, value_labels: list[int] | int, single_instance: bool = False - ) -> None: + """Defines a group of labels that will be merged into a single label when extracted. + + Inherits from LabelGroup and sets extracted labels to binary format. + + Methods: + __call__(array): Extracts the label group as a binary array. + """ + + def __init__(self, value_labels: list[int] | int, single_instance: bool = False) -> None: super().__init__(value_labels, single_instance) def __call__( self, array: np.ndarray, ) -> np.ndarray: - """Extracts the labels of this class + """Extracts the labels of this group as a binary array. Args: - array (np.ndarray): Array to extract the segmentation group labels from - set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. 
+ array (np.ndarray): Array to filter for segmentation group labels. Returns: - np.ndarray: Array containing only the labels of this segmentation group + np.ndarray: Binary array representing presence or absence of group labels. """ return self.extract_label(array, set_to_binary=True) @@ -112,6 +133,14 @@ def __str__(self) -> str: class _LabelGroupAny(LabelGroup): + """Represents a group that includes all labels in the array with no specific segmentation constraints. + + Used to represent a group that does not restrict labels. + + Methods: + __call__(array, set_to_binary): Returns the unfiltered array. + """ + def __init__(self) -> None: pass @@ -128,14 +157,14 @@ def __call__( array: np.ndarray, set_to_binary: bool = False, ) -> np.ndarray: - """Extracts the labels of this class + """Returns the original array, unfiltered. Args: - array (np.ndarray): Array to extract the segmentation group labels from - set_to_binary (bool, optional): If true, will output a binary array. Defaults to False. + array (np.ndarray): The original array to return. + set_to_binary (bool, optional): Ignored in this implementation. Returns: - np.ndarray: Array containing only the labels of this segmentation group + np.ndarray: The original, unmodified array. """ array = array.copy() return array diff --git a/panoptica/utils/parallel_processing.py b/panoptica/utils/parallel_processing.py index 50e3691..b7a7bac 100644 --- a/panoptica/utils/parallel_processing.py +++ b/panoptica/utils/parallel_processing.py @@ -4,6 +4,20 @@ class NoDaemonProcess(Process): + """A subclass of `multiprocessing.Process` that overrides daemon behavior to always be non-daemonic. + + Useful for creating a process that allows child processes to spawn their own children, + as daemonic processes in Python cannot create further subprocesses. + + Attributes: + group (None): Reserved for future extension when using process groups. + target (Callable[..., object] | None): The callable object to be invoked by the process. + name (str | None): The name of the process, for identification. + args (tuple): Arguments to pass to the `target` function. + kwargs (dict): Keyword arguments to pass to the `target` function. + daemon (bool | None): Indicates if the process is daemonic (overridden to always be False). + """ + def __init__( self, group: None = None, @@ -33,4 +47,9 @@ def _set_daemon(self, value): # We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool # because the latter is only a wrapper function, not a proper class. class NonDaemonicPool(multiprocessing.pool.Pool): + """A version of `multiprocessing.pool.Pool` using non-daemonic processes, allowing child processes to spawn their own children. + + This class creates a pool of worker processes using `NoDaemonProcess` for situations where nested child processes are needed. + """ + Process = NoDaemonProcess diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 6ef4640..1514ef8 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -13,9 +13,17 @@ class _ProcessingPair(ABC): - """ - Represents a general processing pair consisting of a reference array and a prediction array. Type of array can be arbitrary (integer recommended) - Every member is read-only! + """Represents a pair of processing arrays, typically prediction and reference arrays. + + This base class provides core functionality for processing and comparing prediction + and reference data arrays. 
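To illustrate why the non-daemonic variants documented above exist, here is a minimal sketch (not part of the patch): a worker of a regular multiprocessing.Pool is daemonic and may not start child processes, while NonDaemonicPool lifts that restriction.

from multiprocessing import Pool

from panoptica.utils.parallel_processing import NonDaemonicPool


def _square(x: int) -> int:
    return x * x


def _nested_job(n: int) -> list[int]:
    # opening a pool inside a worker only succeeds because the outer
    # pool's workers are NoDaemonProcess instances (daemon is always False)
    with Pool(2) as inner_pool:
        return inner_pool.map(_square, range(n))


if __name__ == "__main__":
    with NonDaemonicPool(2) as outer_pool:
        print(outer_pool.map(_nested_job, [2, 3]))  # [[0, 1], [0, 1, 4]]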
Each instance contains two arrays and supports cropping and + data integrity checks. + + Attributes: + n_dim (int): The number of dimensions in the reference array. + crop (tuple[slice, ...] | None): The crop region applied to both arrays, if any. + is_cropped (bool): Indicates whether the arrays have been cropped. + uncropped_shape (tuple[int, ...]): The original shape of the arrays before cropping. """ _prediction_arr: np.ndarray @@ -25,32 +33,31 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] n_dim: int - def __init__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None - ) -> None: - """Initializes a general Processing Pair + def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: + """Initializes the processing pair with prediction and reference arrays. Args: - prediction_arr (np.ndarray): Numpy array containig the prediction labels - reference_arr (np.ndarray): Numpy array containig the reference labels - dtype (type | None): Datatype that is asserted. None for no assertion + prediction_arr (np.ndarray): Numpy array of prediction labels. + reference_arr (np.ndarray): Numpy array of reference labels. + dtype (type | None): The expected datatype of arrays. If None, no datatype check is performed. """ _check_array_integrity(prediction_arr, reference_arr, dtype=dtype) self._prediction_arr = prediction_arr self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple( - _unique_without_zeros(reference_arr) - ) # type:ignore - self._pred_labels: tuple[int, ...] = tuple( - _unique_without_zeros(prediction_arr) - ) # type:ignore + self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore + self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore self.crop: tuple[slice, ...] = None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] = reference_arr.shape def crop_data(self, verbose: bool = False): + """Crops prediction and reference arrays to non-zero regions. + + Args: + verbose (bool, optional): If True, prints cropping details. Defaults to False. + """ if self.is_cropped: return if self.crop is None: @@ -62,41 +69,35 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - ( - print( - f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" - ) - if verbose - else None - ) + (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) self.is_cropped = True def uncrop_data(self, verbose: bool = False): + """Restores the arrays to their original, uncropped shape. + + Args: + verbose (bool, optional): If True, prints uncropping details. Defaults to False. 
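A sketch of the crop/uncrop round trip described above, using the concrete SemanticPair subclass defined further down in this file (illustrative only, not part of the patch; integer-typed inputs are assumed, as required by the dtype check):

import numpy as np
from panoptica.utils.processing_pair import SemanticPair

prediction = np.zeros((10, 10), dtype=int)
reference = np.zeros((10, 10), dtype=int)
prediction[4:6, 4:6] = 1
reference[4:7, 4:6] = 1

pair = SemanticPair(prediction, reference)
pair.uncropped_shape            # (10, 10)
pair.crop_data(verbose=True)    # prints "-- Cropped from (10, 10) to ..."
pair.prediction_arr.shape       # padded bounding box around the non-zero region, smaller than (10, 10)
pair.uncrop_data(verbose=True)  # restores the original (10, 10) arrays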
+ """ if self.is_cropped == False: return - assert ( - self.uncropped_shape is not None - ), "Calling uncrop_data() without having cropped first" + assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - ( - print( - f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" - ) - if verbose - else None - ) + (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) self._reference_arr = reference_arr self.is_cropped = False def set_dtype(self, type): - assert np.issubdtype( - type, int_type - ), "set_dtype: tried to set dtype to something other than integers" + """Sets the data type for both prediction and reference arrays. + + Args: + dtype (type): Expected integer type for the arrays. + """ + assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -136,8 +137,13 @@ def copy(self): class _ProcessingPairInstanced(_ProcessingPair): - """ - A ProcessingPair that contains instances, additionally has number of instances available + """Represents a processing pair with labeled instances, including unique label counts. + + This subclass tracks additional details about the number of unique instances in each array. + + Attributes: + n_prediction_instance (int): Number of unique prediction instances. + n_reference_instance (int): Number of unique reference instances. """ n_prediction_instance: int @@ -151,7 +157,15 @@ def __init__( n_prediction_instance: int | None = None, n_reference_instance: int | None = None, ) -> None: - # reduce to lowest uint + """Initializes a processing pair for instances. + + Args: + prediction_arr (np.ndarray): Array of predicted instance labels. + reference_arr (np.ndarray): Array of reference instance labels. + dtype (type | None): Expected data type of the arrays. + n_prediction_instance (int | None, optional): Pre-calculated number of prediction instances. + n_reference_instance (int | None, optional): Pre-calculated number of reference instances. + """ super().__init__(prediction_arr, reference_arr, dtype) if n_prediction_instance is None: self.n_prediction_instance = _count_unique_without_zeros(prediction_arr) @@ -175,23 +189,20 @@ def copy(self): ) # type:ignore -def _check_array_integrity( - prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None -): - """ - Check the integrity of two numpy arrays. +def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): + """Validates integrity between two arrays, checking shape, dtype, and consistency with `dtype`. - Parameters: - - prediction_arr (np.ndarray): The array to be checked. - - reference_arr (np.ndarray): The reference array for comparison. - - dtype (type | None): The expected data type for both arrays. Defaults to None. + Args: + prediction_arr (np.ndarray): The array to be validated. + reference_arr (np.ndarray): The reference array for comparison. + dtype (type | None): Expected type of the arrays. If None, dtype validation is skipped. Raises: - - AssertionError: If prediction_arr or reference_arr are not numpy arrays. 
- - AssertionError: If the shapes of prediction_arr and reference_arr do not match. - - AssertionError: If the data types of prediction_arr and reference_arr do not match. - - AssertionError: If dtype is provided and the data types of prediction_arr and/or reference_arr - do not match the specified dtype. + AssertionError: If validation fails in any of the following cases: + - Arrays are not numpy arrays. + - Shapes of both arrays are not identical. + - Data types of both arrays do not match. + - Dtype mismatch when specified. Example: >>> _check_array_integrity(np.array([1, 2, 3]), np.array([4, 5, 6]), dtype=int) @@ -199,12 +210,8 @@ def _check_array_integrity( assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert ( - prediction_arr.shape == reference_arr.shape - ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert ( - prediction_arr.dtype == reference_arr.dtype - ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype) @@ -214,7 +221,10 @@ def _check_array_integrity( class SemanticPair(_ProcessingPair): - """A Processing pair that contains Semantic Labels""" + """Represents a semantic processing pair with integer-type arrays for label analysis. + + This class is tailored to scenarios where arrays contain semantic labels rather than instance IDs. + """ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> None: super().__init__(prediction_arr, reference_arr, dtype=int_type) @@ -243,9 +253,15 @@ def __init__( class MatchedInstancePair(_ProcessingPairInstanced): - """ - A Processing pair that contain Matched Instance Maps, i.e. each equal label in both maps are a match - Can be of any unsigned (but matching) integer type + """Represents a matched processing pair for instance maps, handling matched and unmatched labels. + + This class tracks both matched instances and any unmatched labels between prediction + and reference arrays. + + Attributes: + missed_reference_labels (list[int]): Reference labels with no matching prediction. + missed_prediction_labels (list[int]): Prediction labels with no matching reference. + matched_instances (list[int]): Labels matched between prediction and reference arrays. """ missed_reference_labels: list[int] @@ -287,15 +303,11 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list( - [i for i in self._ref_labels if i not in self._pred_labels] - ) + missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list( - [i for i in self._pred_labels if i not in self._ref_labels] - ) + missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) self.missed_prediction_labels = missed_prediction_labels @property @@ -319,6 +331,21 @@ def copy(self): @dataclass class EvaluateInstancePair: + """Represents an evaluation of instance segmentation results, comparing reference and prediction data. 
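The doctest above only shows the passing case; a slightly longer sketch of the failure modes (illustrative only, not part of the patch):

import numpy as np
from panoptica.utils.processing_pair import _check_array_integrity

_check_array_integrity(np.array([1, 2, 3]), np.array([4, 5, 6]), dtype=int)    # passes silently

_check_array_integrity(np.array([1, 2]), np.array([1, 2, 3]))                  # AssertionError: shape mismatch
_check_array_integrity(np.array([1, 2]), np.array([1.0, 2.0]))                 # AssertionError: dtype mismatch
_check_array_integrity(np.array([1.0, 2.0]), np.array([3.0, 4.0]), dtype=int)  # AssertionError: wrong dtype requested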
+ + This class is used to store and evaluate metrics for instance segmentation, tracking the number of instances + and true positives (tp) alongside calculated metrics. + + Attributes: + reference_arr (np.ndarray): Array containing reference instance labels. + prediction_arr (np.ndarray): Array containing predicted instance labels. + num_pred_instances (int): The number of unique instances in the prediction array. + num_ref_instances (int): The number of unique instances in the reference array. + tp (int): The number of true positive matches between predicted and reference instances. + list_metrics (dict[Metric, list[float]]): Dictionary of metric calculations, where each key is a `Metric` + object, and each value is a list of metric scores (floats). + """ + reference_arr: np.ndarray prediction_arr: np.ndarray num_pred_instances: int @@ -328,24 +355,50 @@ class EvaluateInstancePair: class InputType(_Enum_Compare): + """Defines the types of input processing pairs available for evaluation. + + This enumeration provides different processing classes for handling various instance segmentation scenarios, + allowing flexible instantiation of processing pairs based on the desired comparison type. + + Attributes: + SEMANTIC (SemanticPair): Processes semantic labels, intended for cases without instances. + UNMATCHED_INSTANCE (UnmatchedInstancePair): Processes instance maps without requiring label matches. + MATCHED_INSTANCE (MatchedInstancePair): Processes instance maps with label matching between prediction + and reference. + + Methods: + __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + Creates a processing pair based on the specified `InputType`, using the provided prediction + and reference arrays. + + Example: + >>> input_type = InputType.MATCHED_INSTANCE + >>> processing_pair = input_type(prediction_arr, reference_arr) + """ + SEMANTIC = SemanticPair UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray - ) -> _ProcessingPair: + def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) class IntermediateStepsData: + """Manages intermediate data steps for a processing pipeline, storing and retrieving processing states. + + This class enables step-by-step tracking of data transformations during processing. + + Attributes: + original_input (_ProcessingPair | None): The original input data before processing steps. + _intermediatesteps (dict[str, _ProcessingPair]): Dictionary of intermediate processing steps. 
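As an orientation for readers, a hypothetical instantiation of the dataclass described above (illustrative only, not part of the patch; the import paths and the concrete values are assumptions):

import numpy as np
from panoptica.metrics import Metric  # assumed import path
from panoptica.utils.processing_pair import EvaluateInstancePair

evaluation = EvaluateInstancePair(
    reference_arr=np.array([0, 1, 1, 2]),
    prediction_arr=np.array([0, 1, 1, 0]),
    num_ref_instances=2,
    num_pred_instances=1,
    tp=1,                              # one prediction/reference pair was matched
    list_metrics={Metric.DSC: [1.0]},  # one score per matched pair, keyed by Metric
)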
+ """ + def __init__(self, original_input: _ProcessingPair | None): self._original_input = original_input self._intermediatesteps: dict[str, _ProcessingPair] = {} - def add_intermediate_arr_data( - self, processing_pair: _ProcessingPair, inputtype: InputType - ): + def add_intermediate_arr_data(self, processing_pair: _ProcessingPair, inputtype: InputType): type_name = inputtype.name self.add_intermediate_data(type_name, processing_pair) @@ -355,36 +408,26 @@ def add_intermediate_data(self, key, value): @property def original_prediction_arr(self): - assert ( - self._original_input is not None - ), "Original prediction_arr is None, there are no intermediate steps" + assert self._original_input is not None, "Original prediction_arr is None, there are no intermediate steps" return self._original_input.prediction_arr @property def original_reference_arr(self): - assert ( - self._original_input is not None - ), "Original reference_arr is None, there are no intermediate steps" + assert self._original_input is not None, "Original reference_arr is None, there are no intermediate steps" return self._original_input.reference_arr def prediction_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance( - procpair, _ProcessingPair - ), f"step {type_name} is not a processing pair, error" + assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" return procpair.prediction_arr def reference_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance( - procpair, _ProcessingPair - ), f"step {type_name} is not a processing pair, error" + assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" return procpair.reference_arr def __getitem__(self, key): - assert ( - key in self._intermediatesteps - ), f"key {key} not in intermediate steps, maybe the step was skipped?" + assert key in self._intermediatesteps, f"key {key} not in intermediate steps, maybe the step was skipped?" return self._intermediatesteps[key] diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 16308c4..85901c0 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -7,7 +7,24 @@ class SegmentationClassGroups(SupportsConfig): - # + """Represents a collection of segmentation class groups. + + This class manages groups of labels used in segmentation tasks, ensuring that each label is defined + exactly once across all groups. It supports both list and dictionary formats for group initialization. + + Attributes: + __group_dictionary (dict[str, LabelGroup]): A dictionary mapping group names to their respective LabelGroup instances. + __labels (list[int]): A flat list of unique labels collected from all LabelGroups. + + Args: + groups (list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]]): + A list of `LabelGroup` instances or a dictionary where keys are group names (str) and values are either + `LabelGroup` instances or tuples containing a list of label values and a boolean. + + Raises: + AssertionError: If the same label is assigned to multiple groups. 
+ """ + def __init__( self, groups: list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]], @@ -17,9 +34,7 @@ def __init__( # maps name of group to the group itself if isinstance(groups, list): - self.__group_dictionary = { - f"group_{idx}": g for idx, g in enumerate(groups) - } + self.__group_dictionary = {f"group_{idx}": g for idx, g in enumerate(groups)} elif isinstance(groups, dict): # transform dict into list of LabelGroups for i, g in groups.items(): @@ -30,11 +45,7 @@ def __init__( self.__group_dictionary[name_lower] = LabelGroup(g[0], g[1]) # needs to check that each label is accounted for exactly ONCE - labels = [ - value_label - for lg in self.__group_dictionary.values() - for value_label in lg.value_labels - ] + labels = [value_label for lg in self.__group_dictionary.values() for value_label in lg.value_labels] duplicates = list_duplicates(labels) if len(duplicates) > 0: print( @@ -42,9 +53,19 @@ def __init__( ) self.__labels = labels - def has_defined_labels_for( - self, arr: np.ndarray | list[int], raise_error: bool = False - ): + def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): + """Checks if the labels in the provided array are defined in the segmentation class groups. + + Args: + arr (np.ndarray | list[int]): The array of labels to check. + raise_error (bool): If True, raises an error when an undefined label is found. Defaults to False. + + Returns: + bool: True if all labels are defined; False otherwise. + + Raises: + AssertionError: If an undefined label is found and raise_error is True. + """ if isinstance(arr, list): arr_labels = arr else: @@ -90,6 +111,14 @@ def _yaml_repr(cls, node): def list_duplicates(seq): + """Identifies duplicates in a sequence. + + Args: + seq (list): The input sequence to check for duplicates. + + Returns: + list: A list of duplicates found in the input sequence. + """ seen = set() seen_add = seen.add # adds all elements it doesn't know yet to seen and all other to seen_twice @@ -99,12 +128,18 @@ def list_duplicates(seq): class _NoSegmentationClassGroups(SegmentationClassGroups): + """Represents a placeholder for segmentation class groups with no defined labels. + + This class indicates that no specific segmentation groups or labels are defined, and any label is valid. + + Attributes: + __group_dictionary (dict[str, LabelGroup]): A dictionary with a single entry representing all labels as a group. 
+ """ + def __init__(self) -> None: self.__group_dictionary = {NO_GROUP_KEY: _LabelGroupAny()} - def has_defined_labels_for( - self, arr: np.ndarray | list[int], raise_error: bool = False - ): + def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): return True def __str__(self) -> str: @@ -125,9 +160,7 @@ def keys(self) -> list[str]: @property def labels(self): - raise Exception( - "_NoSegmentationClassGroups has no explicit definition of labels" - ) + raise Exception("_NoSegmentationClassGroups has no explicit definition of labels") @classmethod def _yaml_repr(cls, node): diff --git a/unit_tests/test_metrics.py b/unit_tests/test_metrics.py index e26203c..9948826 100644 --- a/unit_tests/test_metrics.py +++ b/unit_tests/test_metrics.py @@ -142,6 +142,12 @@ def test_dsc_case_simple_identical_idx(self): dsc = Metric.DSC(reference_arr=ref_arr, prediction_arr=pred_arr, ref_instance_idx=1, pred_instance_idx=1) self.assertEqual(dsc, 1.0) + def test_dsc_case_simple_identical_wrong_idx(self): + + pred_arr, ref_arr = case_simple_identical() + dsc = Metric.DSC(reference_arr=ref_arr, prediction_arr=pred_arr, ref_instance_idx=2, pred_instance_idx=2) + self.assertEqual(dsc, 0.0) + def test_dsc_case_simple_nooverlap(self): pred_arr, ref_arr = case_simple_nooverlap() From dbbb54b833b876f4afb41dee72f55aee2c8d0a4c Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Fri, 25 Oct 2024 12:56:40 +0000 Subject: [PATCH 2/2] Autoformat with black --- panoptica/_functionals.py | 23 ++++++-- panoptica/metrics/assd.py | 8 ++- panoptica/metrics/metrics.py | 56 +++++++++++++----- panoptica/panoptica_aggregator.py | 18 ++++-- panoptica/panoptica_result.py | 39 ++++++++++--- panoptica/panoptica_statistics.py | 42 ++++++++++---- panoptica/utils/config.py | 8 ++- panoptica/utils/edge_case_handling.py | 32 +++++++--- panoptica/utils/instancelabelmap.py | 16 +++-- panoptica/utils/label_group.py | 16 +++-- panoptica/utils/processing_pair.py | 84 +++++++++++++++++++++------ panoptica/utils/segmentation_class.py | 22 +++++-- unit_tests/test_metrics.py | 7 ++- 13 files changed, 286 insertions(+), 85 deletions(-) diff --git a/panoptica/_functionals.py b/panoptica/_functionals.py index 9b19bbb..4c8f10c 100644 --- a/panoptica/_functionals.py +++ b/panoptica/_functionals.py @@ -33,7 +33,11 @@ def _calc_overlapping_labels( # instance_pairs = [(reference_arr, prediction_arr, i, j) for i, j in overlapping_indices] # (ref, pred) - return [(int(i % (max_ref)), int(i // (max_ref))) for i in np.unique(overlap_arr) if i > max_ref] + return [ + (int(i % (max_ref)), int(i // (max_ref))) + for i in np.unique(overlap_arr) + if i > max_ref + ] def _calc_matching_metric_of_overlapping_labels( @@ -63,8 +67,13 @@ def _calc_matching_metric_of_overlapping_labels( with Pool() as pool: mm_values = pool.starmap(matching_metric.value, instance_pairs) - mm_pairs = [(i, (instance_pairs[idx][2], instance_pairs[idx][3])) for idx, i in enumerate(mm_values)] - mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing) + mm_pairs = [ + (i, (instance_pairs[idx][2], instance_pairs[idx][3])) + for idx, i in enumerate(mm_values) + ] + mm_pairs = sorted( + mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing + ) return mm_pairs @@ -171,4 +180,10 @@ def _round_to_n(value: float | int, n_significant_digits: int = 2): Returns: float: The rounded value. 
""" - return value if value == 0 else round(value, -int(math.floor(math.log10(abs(value)))) + (n_significant_digits - 1)) + return ( + value + if value == 0 + else round( + value, -int(math.floor(math.log10(abs(value)))) + (n_significant_digits - 1) + ) + ) diff --git a/panoptica/metrics/assd.py b/panoptica/metrics/assd.py index f33a809..41d6937 100644 --- a/panoptica/metrics/assd.py +++ b/panoptica/metrics/assd.py @@ -85,8 +85,12 @@ def __surface_distances(reference, prediction, voxelspacing=None, connectivity=1 # raise RuntimeError("The second supplied array does not contain any binary object.") # extract only 1-pixel border line of objects - result_border = prediction ^ binary_erosion(prediction, structure=footprint, iterations=1) - reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1) + result_border = prediction ^ binary_erosion( + prediction, structure=footprint, iterations=1 + ) + reference_border = reference ^ binary_erosion( + reference, structure=footprint, iterations=1 + ) # compute average surface distance # Note: scipys distance transform is calculated only inside the borders of the diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 2d00638..7e30c53 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -72,7 +72,9 @@ def __call__( reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): pred_instance_idx = [pred_instance_idx] - prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) # type:ignore + prediction_arr = np.isin( + prediction_arr.copy(), pred_instance_idx + ) # type:ignore return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: @@ -106,7 +108,9 @@ def increasing(self): """ return not self.decreasing - def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: + def score_beats_threshold( + self, matching_score: float, matching_threshold: float + ) -> bool: """Determines if a matching score meets a specified threshold. Args: @@ -117,7 +121,9 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float bool: True if the score meets the threshold, taking into account the metric's preferred direction. """ - return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) + return (self.increasing and matching_score >= matching_threshold) or ( + self.decreasing and matching_score <= matching_threshold + ) class DirectValueMeta(EnumMeta): @@ -184,7 +190,9 @@ def __call__( **kwargs, ) - def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: + def score_beats_threshold( + self, matching_score: float, matching_threshold: float + ) -> bool: """Calculates whether a score beats a specified threshold Args: @@ -194,7 +202,9 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float Returns: bool: True if the matching_score beats the threshold, False otherwise. 
""" - return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) + return (self.increasing and matching_score >= matching_threshold) or ( + self.decreasing and matching_score <= matching_threshold + ) @property def name(self): @@ -293,7 +303,9 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # ERROR if self._error: if self._error_obj is None: - self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed") + self._error_obj = MetricCouldNotBeComputedException( + f"Metric {self.id} requested, but could not be computed" + ) raise self._error_obj # Already calculated? if self._was_calculated: @@ -301,8 +313,12 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # Calculate it try: - assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated" - assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set" + assert ( + not self._was_calculated + ), f"Metric {self.id} was called to compute, but is set to have been already calculated" + assert ( + self._calc_func is not None + ), f"Metric {self.id} was called to compute, but has no calculation function set" value = self._calc_func(result_obj) except MetricCouldNotBeComputedException as e: value = e @@ -347,20 +363,32 @@ def __init__( else: self.AVG = None if self.ALL is None else np.average(self.ALL) self.SUM = None if self.ALL is None else np.sum(self.ALL) - self.MIN = None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) - self.MAX = None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) - - self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) + self.MIN = ( + None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL) + ) + self.MAX = ( + None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL) + ) + + self.STD = ( + None + if self.ALL is None + else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) + ) def __getitem__(self, mode: MetricMode | str): if self.error: - raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics") + raise MetricCouldNotBeComputedException( + f"Metric {self.id} has not been calculated, add it to your eval_metrics" + ) if isinstance(mode, MetricMode): mode = mode.name if hasattr(self, mode): return getattr(self, mode) else: - raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member") + raise MetricCouldNotBeComputedException( + f"List_Metric {self.id} does not contain {mode} member" + ) if __name__ == "__main__": diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py index e0d761d..0c0c527 100644 --- a/panoptica/panoptica_aggregator.py +++ b/panoptica/panoptica_aggregator.py @@ -66,7 +66,9 @@ def __init__( if isinstance(output_file, str): output_file = Path(output_file) # uses tsv - assert output_file.parent.exists(), f"Directory {str(output_file.parent)} does not exist" + assert ( + output_file.parent.exists() + ), f"Directory {str(output_file.parent)} does not exist" out_file_path = str(output_file) @@ -80,13 +82,19 @@ def __init__( else: out_file_path += ".tsv" # add extension - out_buffer_file: Path = Path(out_file_path).parent.joinpath("panoptica_aggregator_tmp.tsv") + out_buffer_file: Path = Path(out_file_path).parent.joinpath( + 
"panoptica_aggregator_tmp.tsv" + ) self.__output_buffer_file = out_buffer_file Path(out_file_path).parent.mkdir(parents=True, exist_ok=True) self.__output_file = out_file_path - header = ["subject_name"] + [f"{g}-{m}" for g in self.__class_group_names for m in self.__evaluation_metrics] + header = ["subject_name"] + [ + f"{g}-{m}" + for g in self.__class_group_names + for m in self.__evaluation_metrics + ] header_hash = hash("+".join(header)) if not output_file.exists(): @@ -100,7 +108,9 @@ def __init__( continue_file = True else: # TODO should also hash panoptica_evaluator just to make sure! and then save into header of file - assert header_hash == hash("+".join(header_list)), "Hash of header not the same! You are using a different setup!" + assert header_hash == hash( + "+".join(header_list) + ), "Hash of header not the same! You are using a different setup!" if out_buffer_file.exists(): os.remove(out_buffer_file) diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index 12341ec..da9c884 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -255,19 +255,25 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result) + self._list_metrics[m] = Evaluation_List_Metric( + m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result + ) # even if not available, set the global vars default_value = None was_calculated = False if m in self._global_metrics and arrays_present: - default_value = self._calc_global_bin_metric(m, pred_binary, ref_binary, do_binarize=False) + default_value = self._calc_global_bin_metric( + m, pred_binary, ref_binary, do_binarize=False + ) was_calculated = True self._add_metric( f"global_bin_{m.name.lower()}", MetricType.GLOBAL, - lambda x: MetricCouldNotBeComputedException(f"Global Metric {m} not set"), + lambda x: MetricCouldNotBeComputedException( + f"Global Metric {m} not set" + ), long_name="Global Binary " + m.value.long_name, default_value=default_value, was_calculated=was_calculated, @@ -310,7 +316,9 @@ def _calc_global_bin_metric( prediction_empty = pred_binary.sum() == 0 reference_empty = ref_binary.sum() == 0 if prediction_empty or reference_empty: - is_edgecase, result = self._edge_case_handler.handle_zero_tp(metric, 0, int(prediction_empty), int(reference_empty)) + is_edgecase, result = self._edge_case_handler.handle_zero_tp( + metric, 0, int(prediction_empty), int(reference_empty) + ) if is_edgecase: return result @@ -427,7 +435,11 @@ def to_dict(self) -> dict: Returns: A dictionary containing metric names and their values. """ - return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} + return { + k: getattr(self, v.id) + for k, v in self._evaluation_metrics.items() + if (v._error == False and v._was_calculated) + } @property def evaluation_metrics(self): @@ -450,7 +462,9 @@ def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") + raise MetricCouldNotBeComputedException( + f"{metric} could not be found, have you set it in eval_metrics during evaluation?" 
+ ) def _calc_metric(self, metric_name: str, supress_error: bool = False): """ @@ -479,7 +493,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") + raise MetricCouldNotBeComputedException( + f"could not find metric with name {metric_name}" + ) def __getattribute__(self, __name: str) -> Any: """ @@ -506,10 +522,15 @@ def __getattribute__(self, __name: str) -> Any: raise e if __name == "_evaluation_metrics": return attr - if object.__getattribute__(self, "_evaluation_metrics") is not None and __name in self._evaluation_metrics.keys(): + if ( + object.__getattribute__(self, "_evaluation_metrics") is not None + and __name in self._evaluation_metrics.keys() + ): if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") + raise MetricCouldNotBeComputedException( + f"Requested metric {__name} that could not be computed" + ) elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index d3f4303..96489b3 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -43,7 +43,9 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" + assert ( + header[0] == "subject_name" + ), "First column is not subject_names, something wrong with the file?" keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -80,13 +82,19 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" + assert ( + group in self.__groupnames + ), f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert ( + metric in self.__metricnames + ), f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert ( + subjectname in self.__subj_names + ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric) -> list[float]: """Returns the list of values for given group and metric @@ -117,7 +125,10 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get(g, m)[sidx] for m in self.__metricnames} + for g in self.__groupnames + } def get_across_groups(self, metric): """Given metric, gives list of all values (even across groups!) Treat with care! 
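For orientation, a reading-back sketch for the aggregated .tsv files handled above (illustrative only, not part of the patch; the file name, group name and metric column are hypothetical):

from panoptica.panoptica_statistics import Panoptica_Statistic

stats = Panoptica_Statistic.from_file("panoptica_results.tsv")  # hypothetical aggregator output

stats.get("group_0", "global_bin_dsc")          # per-subject values for one group/metric column
stats.get_summary("group_0", "global_bin_dsc")  # summary statistics for that column
stats.get_one_subject("subject_01")             # {group: {metric: value}} for a single subject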
@@ -134,7 +145,10 @@ def get_across_groups(self, metric): return values def get_summary_dict(self): - return {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get_summary(g, m) for m in self.__metricnames} + for g in self.__groupnames + } def get_summary(self, group, metric): # TODO maybe more here? range, stuff like that @@ -206,7 +220,9 @@ def make_curve_over_setups( alternate_groupnames = [alternate_groupnames] # for setupname, stat in statistics_dict.items(): - assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" + assert ( + metric in stat.metricnames + ), f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -282,10 +298,14 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(graph_name).mean() df_by_spec_count = dict(df_by_spec_count[score].items()) - df_data["mean"] = df_data[graph_name].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) + df_data["mean"] = df_data[graph_name].apply( + lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) + ) df_data = df_data.sort_values(by="mean") if orientation == "v": - fig = px.strip(df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( @@ -296,7 +316,9 @@ def plot_box( ) ) else: - fig = px.strip(df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index 30d332c..40c511c 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -174,7 +174,9 @@ def load_from_config(cls, path: str | Path): An instance of the class. """ obj = _load_from_config(cls, path) - assert isinstance(obj, cls), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" + assert isinstance( + obj, cls + ), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}" return obj @classmethod @@ -218,7 +220,9 @@ def to_yaml(cls, representer, node): Returns: YAML node: YAML-compatible node representation of the object. """ - assert hasattr(cls, "_yaml_repr"), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" + assert hasattr( + cls, "_yaml_repr" + ), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined" return representer.represent_mapping("!" 
+ cls.__name__, cls._yaml_repr(node)) @classmethod diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py index c9c0b0f..35267bd 100644 --- a/panoptica/utils/edge_case_handling.py +++ b/panoptica/utils/edge_case_handling.py @@ -98,12 +98,26 @@ def __init__( self._default_result = default_result self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {} - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result - self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = ( + empty_prediction_result + if empty_prediction_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = ( + empty_reference_result + if empty_reference_result is not None + else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = ( + no_instances_result if no_instances_result is not None else default_result + ) + self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = ( + normal if normal is not None else default_result + ) - def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]: + def __call__( + self, tp: int, num_pred_instances, num_ref_instances + ) -> tuple[bool, float | None]: if tp != 0: return False, EdgeCaseResult.NONE.value # @@ -187,7 +201,9 @@ def __init__( }, empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN, ) -> None: - self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling + self.__listmetric_zeroTP_handling: dict[ + Metric, MetricZeroTPEdgeCaseHandling + ] = listmetric_zeroTP_handling self.__empty_list_std: EdgeCaseResult = empty_list_std def handle_zero_tp( @@ -214,7 +230,9 @@ def handle_zero_tp( if tp != 0: return False, EdgeCaseResult.NONE.value if metric not in self.__listmetric_zeroTP_handling: - raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available") + raise NotImplementedError( + f"Metric {metric} encountered zero TP, but no edge handling available" + ) return self.__listmetric_zeroTP_handling[metric]( tp=tp, diff --git a/panoptica/utils/instancelabelmap.py b/panoptica/utils/instancelabelmap.py index a0caa45..f0c4b8c 100644 --- a/panoptica/utils/instancelabelmap.py +++ b/panoptica/utils/instancelabelmap.py @@ -42,7 +42,9 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int): if not isinstance(pred_labels, list): pred_labels = [pred_labels] assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label" - assert np.all([isinstance(r, int) for r in pred_labels]), "add_labelmap_entry: got no int as pred_label" + assert np.all( + [isinstance(r, int) for r in pred_labels] + ), "add_labelmap_entry: got no int as pred_label" for p in pred_labels: if p in self.labelmap and self.labelmap[p] != ref_label: raise Exception( @@ -83,7 +85,9 @@ def contains_ref(self, ref_label: int): """ return ref_label in self.labelmap.values() - def contains_and(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + def contains_and( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: """Checks if both 
a prediction and a reference label are in the map. Args: @@ -97,7 +101,9 @@ def contains_and(self, pred_label: int | None = None, ref_label: int | None = No ref_in = True if ref_label is None else ref_label in self.labelmap.values() return pred_in and ref_in - def contains_or(self, pred_label: int | None = None, ref_label: int | None = None) -> bool: + def contains_or( + self, pred_label: int | None = None, ref_label: int | None = None + ) -> bool: """Checks if either a prediction or reference label is in the map. Args: @@ -123,7 +129,9 @@ def __str__(self) -> str: return str( list( [ - str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + " -> " + str(v) + str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + + " -> " + + str(v) for v in set(self.labelmap.values()) ] ) diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py index 8f1ac72..0008450 100644 --- a/panoptica/utils/label_group.py +++ b/panoptica/utils/label_group.py @@ -35,12 +35,18 @@ def __init__( value_labels = list(set(value_labels)) - assert len(value_labels) >= 1, f"You tried to define a LabelGroup without any specified labels, got {value_labels}" + assert ( + len(value_labels) >= 1 + ), f"You tried to define a LabelGroup without any specified labels, got {value_labels}" self.__value_labels = value_labels - assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}" + assert np.all( + [v > 0 for v in self.__value_labels] + ), f"Given value labels are not >0, got {value_labels}" self.__single_instance = single_instance if self.__single_instance: - assert len(value_labels) == 1, f"single_instance set to True, but got more than one label for this group, got {value_labels}" + assert ( + len(value_labels) == 1 + ), f"single_instance set to True, but got more than one label for this group, got {value_labels}" LabelGroup._register_permanently() @@ -111,7 +117,9 @@ class LabelMergeGroup(LabelGroup): __call__(array): Extracts the label group as a binary array. """ - def __init__(self, value_labels: list[int] | int, single_instance: bool = False) -> None: + def __init__( + self, value_labels: list[int] | int, single_instance: bool = False + ) -> None: super().__init__(value_labels, single_instance) def __call__( diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 1514ef8..f6901b7 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -33,7 +33,9 @@ class _ProcessingPair(ABC): _pred_labels: tuple[int, ...] n_dim: int - def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None: + def __init__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None + ) -> None: """Initializes the processing pair with prediction and reference arrays. Args: @@ -46,8 +48,12 @@ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: self._reference_arr = reference_arr self.dtype = dtype self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore - self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore + self._ref_labels: tuple[int, ...] = tuple( + _unique_without_zeros(reference_arr) + ) # type:ignore + self._pred_labels: tuple[int, ...] = tuple( + _unique_without_zeros(prediction_arr) + ) # type:ignore self.crop: tuple[slice, ...] 
= None self.is_cropped: bool = False self.uncropped_shape: tuple[int, ...] = reference_arr.shape @@ -69,7 +75,13 @@ def crop_data(self, verbose: bool = False): self._prediction_arr = self._prediction_arr[self.crop] self._reference_arr = self._reference_arr[self.crop] - (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None) + ( + print( + f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" + ) + if verbose + else None + ) self.is_cropped = True def uncrop_data(self, verbose: bool = False): @@ -80,14 +92,22 @@ def uncrop_data(self, verbose: bool = False): """ if self.is_cropped == False: return - assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first" + assert ( + self.uncropped_shape is not None + ), "Calling uncrop_data() without having cropped first" prediction_arr = np.zeros(self.uncropped_shape) prediction_arr[self.crop] = self._prediction_arr self._prediction_arr = prediction_arr reference_arr = np.zeros(self.uncropped_shape) reference_arr[self.crop] = self._reference_arr - (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None) + ( + print( + f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" + ) + if verbose + else None + ) self._reference_arr = reference_arr self.is_cropped = False @@ -97,7 +117,9 @@ def set_dtype(self, type): Args: dtype (type): Expected integer type for the arrays. """ - assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers" + assert np.issubdtype( + type, int_type + ), "set_dtype: tried to set dtype to something other than integers" self._prediction_arr = self._prediction_arr.astype(type) self._reference_arr = self._reference_arr.astype(type) @@ -189,7 +211,9 @@ def copy(self): ) # type:ignore -def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None): +def _check_array_integrity( + prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None +): """Validates integrity between two arrays, checking shape, dtype, and consistency with `dtype`. 
Args: @@ -210,8 +234,12 @@ def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray assert isinstance(prediction_arr, np.ndarray) and isinstance( reference_arr, np.ndarray ), "prediction and/or reference are not numpy arrays" - assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" - assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" + assert ( + prediction_arr.shape == reference_arr.shape + ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + assert ( + prediction_arr.dtype == reference_arr.dtype + ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( np.issubdtype(prediction_arr.dtype, dtype) @@ -303,11 +331,15 @@ def __init__( self.matched_instances = matched_instances if missed_reference_labels is None: - missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels]) + missed_reference_labels = list( + [i for i in self._ref_labels if i not in self._pred_labels] + ) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: - missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels]) + missed_prediction_labels = list( + [i for i in self._pred_labels if i not in self._ref_labels] + ) self.missed_prediction_labels = missed_prediction_labels @property @@ -380,7 +412,9 @@ class InputType(_Enum_Compare): UNMATCHED_INSTANCE = UnmatchedInstancePair MATCHED_INSTANCE = MatchedInstancePair - def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair: + def __call__( + self, prediction_arr: np.ndarray, reference_arr: np.ndarray + ) -> _ProcessingPair: return self.value(prediction_arr, reference_arr) @@ -398,7 +432,9 @@ def __init__(self, original_input: _ProcessingPair | None): self._original_input = original_input self._intermediatesteps: dict[str, _ProcessingPair] = {} - def add_intermediate_arr_data(self, processing_pair: _ProcessingPair, inputtype: InputType): + def add_intermediate_arr_data( + self, processing_pair: _ProcessingPair, inputtype: InputType + ): type_name = inputtype.name self.add_intermediate_data(type_name, processing_pair) @@ -408,26 +444,36 @@ def add_intermediate_data(self, key, value): @property def original_prediction_arr(self): - assert self._original_input is not None, "Original prediction_arr is None, there are no intermediate steps" + assert ( + self._original_input is not None + ), "Original prediction_arr is None, there are no intermediate steps" return self._original_input.prediction_arr @property def original_reference_arr(self): - assert self._original_input is not None, "Original reference_arr is None, there are no intermediate steps" + assert ( + self._original_input is not None + ), "Original reference_arr is None, there are no intermediate steps" return self._original_input.reference_arr def prediction_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error" + assert isinstance( + procpair, _ProcessingPair + ), f"step {type_name} is not a processing pair, error" return procpair.prediction_arr def reference_arr(self, inputtype: InputType): type_name = inputtype.name procpair = self[type_name] - assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, 
error" + assert isinstance( + procpair, _ProcessingPair + ), f"step {type_name} is not a processing pair, error" return procpair.reference_arr def __getitem__(self, key): - assert key in self._intermediatesteps, f"key {key} not in intermediate steps, maybe the step was skipped?" + assert ( + key in self._intermediatesteps + ), f"key {key} not in intermediate steps, maybe the step was skipped?" return self._intermediatesteps[key] diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index 85901c0..a0a32b9 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -34,7 +34,9 @@ def __init__( # maps name of group to the group itself if isinstance(groups, list): - self.__group_dictionary = {f"group_{idx}": g for idx, g in enumerate(groups)} + self.__group_dictionary = { + f"group_{idx}": g for idx, g in enumerate(groups) + } elif isinstance(groups, dict): # transform dict into list of LabelGroups for i, g in groups.items(): @@ -45,7 +47,11 @@ def __init__( self.__group_dictionary[name_lower] = LabelGroup(g[0], g[1]) # needs to check that each label is accounted for exactly ONCE - labels = [value_label for lg in self.__group_dictionary.values() for value_label in lg.value_labels] + labels = [ + value_label + for lg in self.__group_dictionary.values() + for value_label in lg.value_labels + ] duplicates = list_duplicates(labels) if len(duplicates) > 0: print( @@ -53,7 +59,9 @@ def __init__( ) self.__labels = labels - def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): + def has_defined_labels_for( + self, arr: np.ndarray | list[int], raise_error: bool = False + ): """Checks if the labels in the provided array are defined in the segmentation class groups. Args: @@ -139,7 +147,9 @@ class _NoSegmentationClassGroups(SegmentationClassGroups): def __init__(self) -> None: self.__group_dictionary = {NO_GROUP_KEY: _LabelGroupAny()} - def has_defined_labels_for(self, arr: np.ndarray | list[int], raise_error: bool = False): + def has_defined_labels_for( + self, arr: np.ndarray | list[int], raise_error: bool = False + ): return True def __str__(self) -> str: @@ -160,7 +170,9 @@ def keys(self) -> list[str]: @property def labels(self): - raise Exception("_NoSegmentationClassGroups has no explicit definition of labels") + raise Exception( + "_NoSegmentationClassGroups has no explicit definition of labels" + ) @classmethod def _yaml_repr(cls, node): diff --git a/unit_tests/test_metrics.py b/unit_tests/test_metrics.py index a6ef90b..343cb87 100644 --- a/unit_tests/test_metrics.py +++ b/unit_tests/test_metrics.py @@ -155,7 +155,12 @@ def test_dsc_case_simple_identical_idx(self): def test_dsc_case_simple_identical_wrong_idx(self): pred_arr, ref_arr = case_simple_identical() - dsc = Metric.DSC(reference_arr=ref_arr, prediction_arr=pred_arr, ref_instance_idx=2, pred_instance_idx=2) + dsc = Metric.DSC( + reference_arr=ref_arr, + prediction_arr=pred_arr, + ref_instance_idx=2, + pred_instance_idx=2, + ) self.assertEqual(dsc, 0.0) def test_dsc_case_simple_nooverlap(self):