diff --git a/examples/example_spine_instance.py b/examples/example_spine_instance.py index 354d5f2..25eac85 100644 --- a/examples/example_spine_instance.py +++ b/examples/example_spine_instance.py @@ -32,7 +32,7 @@ def main(): with cProfile.Profile() as pr: results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False) - for groupname, (result, intermediate_steps_data) in results.items(): + for groupname, result in results.items(): print() print("### Group", groupname) print(result) diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py index 4fbe74e..a038386 100644 --- a/examples/example_spine_instance_config.py +++ b/examples/example_spine_instance_config.py @@ -18,7 +18,7 @@ def main(): with cProfile.Profile() as pr: results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False) - for groupname, (result, intermediate_steps_data) in results.items(): + for groupname, result in results.items(): print() print("### Group", groupname) print(result) diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 0ec88f6..ae427ee 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -27,13 +27,13 @@ def main(): with cProfile.Profile() as pr: - result, intermediate_steps_data = evaluator.evaluate( - prediction_mask, reference_mask - )["ungrouped"] + result = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"] # To print the results, just call print print(result) + intermediate_steps_data = result.intermediate_steps_data + assert intermediate_steps_data is not None # To get the different intermediate arrays, just use the second returned object intermediate_steps_data.original_prediction_arr # Input prediction array, untouched intermediate_steps_data.original_reference_arr # Input reference array, untouched diff --git a/panoptica/instance_approximator.py b/panoptica/instance_approximator.py index aff01cb..5bb9733 100644 --- a/panoptica/instance_approximator.py +++ b/panoptica/instance_approximator.py @@ -4,7 +4,8 @@ from panoptica.utils.constants import CCABackend from panoptica._functionals import _connected_components -from panoptica.utils.numpy_utils import _get_smallest_fitting_uint + +# from panoptica.utils.numpy_utils import _get_smallest_fitting_uint from panoptica.utils.processing_pair import ( MatchedInstancePair, SemanticPair, @@ -80,7 +81,7 @@ def approximate_instances( AssertionError: If there are negative values in the semantic maps, which is not allowed. """ # Check validity - pred_labels, ref_labels = semantic_pair._pred_labels, semantic_pair._ref_labels + pred_labels, ref_labels = semantic_pair.pred_labels, semantic_pair.ref_labels pred_label_range = ( (np.min(pred_labels), np.max(pred_labels)) if len(pred_labels) > 0 @@ -95,10 +96,10 @@ def approximate_instances( min_value >= 0 ), "There are negative values in the semantic maps. This is not allowed!" 
# Set dtype to smalles fitting uint - max_value = max(np.max(pred_label_range[1]), np.max(ref_label_range[1])) - dtype = _get_smallest_fitting_uint(max_value) - semantic_pair.set_dtype(dtype) - print(f"-- Set dtype to {dtype}") if verbose else None + # max_value = max(np.max(pred_label_range[1]), np.max(ref_label_range[1])) + # dtype = _get_smallest_fitting_uint(max_value) + # semantic_pair.set_dtype(dtype) + # print(f"-- Set dtype to {dtype}") if verbose else None # Call algorithm instance_pair = self._approximate_instances(semantic_pair, **kwargs) @@ -153,26 +154,22 @@ def _approximate_instances( ) assert cca_backend is not None - empty_prediction = len(semantic_pair._pred_labels) == 0 - empty_reference = len(semantic_pair._ref_labels) == 0 + empty_prediction = len(semantic_pair.pred_labels) == 0 + empty_reference = len(semantic_pair.ref_labels) == 0 prediction_arr, n_prediction_instance = ( - _connected_components(semantic_pair._prediction_arr, cca_backend) + _connected_components(semantic_pair.prediction_arr, cca_backend) if not empty_prediction - else (semantic_pair._prediction_arr, 0) + else (semantic_pair.prediction_arr, 0) ) reference_arr, n_reference_instance = ( - _connected_components(semantic_pair._reference_arr, cca_backend) + _connected_components(semantic_pair.reference_arr, cca_backend) if not empty_reference - else (semantic_pair._reference_arr, 0) - ) - - dtype = _get_smallest_fitting_uint( - max(prediction_arr.max(), reference_arr.max()) + else (semantic_pair.reference_arr, 0) ) return UnmatchedInstancePair( - prediction_arr=prediction_arr.astype(dtype), - reference_arr=reference_arr.astype(dtype), + prediction_arr=prediction_arr, + reference_arr=reference_arr, n_prediction_instance=n_prediction_instance, n_reference_instance=n_reference_instance, ) diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 25534c0..bb3d069 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -131,7 +131,7 @@ def map_instance_labels( # Build a MatchedInstancePair out of the newly derived data matched_instance_pair = MatchedInstancePair( prediction_arr=prediction_arr_relabeled, - reference_arr=processing_pair._reference_arr, + reference_arr=processing_pair.reference_arr, ) return matched_instance_pair diff --git a/panoptica/panoptica_aggregator.py b/panoptica/panoptica_aggregator.py index 0c0c527..e5782c6 100644 --- a/panoptica/panoptica_aggregator.py +++ b/panoptica/panoptica_aggregator.py @@ -192,7 +192,7 @@ def _save_one_subject(self, subject_name, result_grouped): # content = [subject_name] for groupname in self.__class_group_names: - result: PanopticaResult = result_grouped[groupname][0] + result: PanopticaResult = result_grouped[groupname] result_dict = result.to_dict() if result.computation_time is not None: result_dict[COMPUTATION_TIME_KEY] = result.computation_time diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py index 9730501..6849f82 100644 --- a/panoptica/panoptica_evaluator.py +++ b/panoptica/panoptica_evaluator.py @@ -3,7 +3,7 @@ from panoptica.instance_approximator import InstanceApproximator from panoptica.instance_evaluator import evaluate_matched_instance from panoptica.instance_matcher import InstanceMatchingAlgorithm -from panoptica.metrics import Metric, _Metric +from panoptica.metrics import Metric from panoptica.panoptica_result import PanopticaResult from panoptica.utils.timing import measure_time from panoptica.utils import EdgeCaseHandler @@ -12,7 +12,6 @@ MatchedInstancePair, 
    SemanticPair,
     UnmatchedInstancePair,
-    _ProcessingPair,
     InputType,
     EvaluateInstancePair,
     IntermediateStepsData,
@@ -54,14 +53,18 @@ def __init__(
             expected_input (type, optional): Expected DataPair Input Type. Defaults to InputType.MATCHED_INSTANCE (which is type(MatchedInstancePair)).
             instance_approximator (InstanceApproximator | None, optional): Determines which instance approximator is used if necessary. Defaults to None.
             instance_matcher (InstanceMatchingAlgorithm | None, optional): Determines which instance matching algorithm is used if necessary. Defaults to None.
-            iou_threshold (float, optional): Iou Threshold for evaluation. Defaults to 0.5.
+
             edge_case_handler (edge_case_handler, optional): EdgeCaseHandler to be used. If none, will create the default one
-            segmentation_class_groups (SegmentationClassGroups, optional): If not none, will evaluate per class group defined, instead of over all at the same time.
+            segmentation_class_groups (SegmentationClassGroups, optional): If not none, will evaluate per defined class group instead of over all labels at the same time. A class group is a collection of labels that are considered to belong to the same class / structure.
+
             instance_metrics (list[Metric]): List of all metrics that should be calculated between all instances
             global_metrics (list[Metric]): List of all metrics that should be calculated on the global binary masks
+
             decision_metric: (Metric | None, optional): This metric is the final decision point between True Positive and False Positive. Can be left away if the matching algorithm is used (it will match by a metric and threshold already)
             decision_threshold: (float | None, optional): Threshold for the decision_metric
-            log_times (bool): If true, will printout the times for the different phases of the pipeline.
+
+            save_group_times (bool): If true, will save the computation time of each sample and put that into the result object.
+            log_times (bool): If true, will print the times for the different phases of the pipeline.
             verbose (bool): If true, will spit out more details than you want.
""" self.__expected_input = expected_input @@ -117,7 +120,7 @@ def evaluate( save_group_times: bool | None = None, log_times: bool | None = None, verbose: bool | None = None, - ) -> dict[str, tuple[PanopticaResult, IntermediateStepsData]]: + ) -> dict[str, PanopticaResult]: processing_pair = self.__expected_input(prediction_arr, reference_arr) assert isinstance( processing_pair, self.__expected_input.value @@ -130,7 +133,7 @@ def evaluate( processing_pair.reference_arr, raise_error=True ) - result_grouped: dict[str, tuple[PanopticaResult, IntermediateStepsData]] = {} + result_grouped: dict[str, PanopticaResult] = {} for group_name, label_group in self.__segmentation_class_groups.items(): result_grouped[group_name] = self._evaluate_group( group_name, @@ -144,7 +147,7 @@ def evaluate( ), log_times=log_times, verbose=verbose, - )[1:] + ) return result_grouped @property @@ -166,7 +169,7 @@ def resulting_metric_keys(self) -> list[str]: dummy_input = MatchedInstancePair( np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8) ) - _, res, _ = self._evaluate_group( + res = self._evaluate_group( group_name="", label_group=LabelGroup(1, single_instance=False), processing_pair=dummy_input, @@ -188,7 +191,7 @@ def _evaluate_group( verbose: bool | None = None, log_times: bool | None = None, save_group_times: bool = False, - ): + ) -> PanopticaResult: assert isinstance(label_group, LabelGroup) if self.__save_group_times: start_time = perf_counter() @@ -208,7 +211,7 @@ def _evaluate_group( ) decision_threshold = 0.0 - result, intermediate_steps_data = panoptic_evaluate( + result = panoptic_evaluate( input_pair=processing_pair_grouped, edge_case_handler=self.__edge_case_handler, instance_approximator=self.__instance_approximator, @@ -225,7 +228,7 @@ def _evaluate_group( if save_group_times: duration = perf_counter() - start_time result.computation_time = duration - return group_name, result, intermediate_steps_data + return result def panoptic_evaluate( @@ -242,7 +245,7 @@ def panoptic_evaluate( verbose=False, verbose_calc=False, **kwargs, -) -> tuple[PanopticaResult, IntermediateStepsData]: +) -> PanopticaResult: """ Perform panoptic evaluation on the given processing pair. 
@@ -364,13 +367,14 @@ def panoptic_evaluate(
             list_metrics=processing_pair.list_metrics,
             global_metrics=global_metrics,
             edge_case_handler=edge_case_handler,
+            intermediate_steps_data=intermediate_steps_data,
         )
 
     if isinstance(processing_pair, PanopticaResult):
         processing_pair._global_metrics = global_metrics
         if result_all:
             processing_pair.calculate_all(print_errors=verbose_calc)
-        return processing_pair, intermediate_steps_data
+        return processing_pair
 
     raise RuntimeError("End of panoptic pipeline reached without results")
diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py
index da9c884..901b8d6 100644
--- a/panoptica/panoptica_result.py
+++ b/panoptica/panoptica_result.py
@@ -13,6 +13,7 @@
     MetricType,
 )
 from panoptica.utils import EdgeCaseHandler
+from panoptica.utils.processing_pair import IntermediateStepsData
 
 
 class PanopticaResult(object):
@@ -27,6 +28,7 @@ def __init__(
         list_metrics: dict[Metric, list[float]],
         edge_case_handler: EdgeCaseHandler,
         global_metrics: list[Metric] = [],
+        intermediate_steps_data: IntermediateStepsData | None = None,
         computation_time: float | None = None,
     ):
         """Result object for Panoptica, contains all calculatable metrics
@@ -45,6 +47,7 @@ def __init__(
         empty_list_std = self._edge_case_handler.handle_empty_list_std().value
         self._global_metrics: list[Metric] = global_metrics
         self.computation_time = computation_time
+        self.intermediate_steps_data = intermediate_steps_data
 
         ######################
         # Evaluation Metrics #
         ######################
diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py
index 9a35a49..165fd0b 100644
--- a/panoptica/panoptica_statistics.py
+++ b/panoptica/panoptica_statistics.py
@@ -5,7 +5,6 @@
 
 try:
     import pandas as pd
-    import matplotlib.pyplot as plt
     import plotly.express as px
     import plotly.graph_objects as go
 except Exception as e:
@@ -278,11 +277,14 @@ def get_summary_figure(
         manual_metric_range=manual_metric_range,
     )
 
-    # groupwise or in total
-    # Mean over instances
-    # mean over subjects
-    # give below/above percentile of metric (the names)
-    # make auc curve as plot
+
+def make_autc_plots(
+    statistics_dict: dict[str | int | float, Panoptica_Statistic],
+    metric: str,
+    groups: list[str] | str | None = None,
+    alternate_groupnames: list[str] | str | None = None,
+):
+    raise NotImplementedError("AUTC plots are currently in the works")
 
 
 def make_curve_over_setups(
@@ -290,11 +292,15 @@
     metric: str,
     groups: list[str] | str | None = None,
     alternate_groupnames: list[str] | str | None = None,
-    fig: None = None,
-    plot_dotsize: int | None = None,
-    plot_lines: bool = True,
-    plot_std: bool = False,
+    fig: go.Figure | None = None,
+    plot_as_barchart=True,
+    plot_std: bool = True,
+    figure_title: str = "",
+    width: int = 850,
+    height: int = 1200,
+    manual_metric_range: None | tuple[float, float] = None,
 ):
+    # TODO make this flexible: should the second grouping be the groups or the metrics?
    if groups is None:
         groups = list(statistics_dict.values())[0].groupnames
     #
@@ -302,6 +308,10 @@
         groups = [groups]
     if isinstance(alternate_groupnames, str):
         alternate_groupnames = [alternate_groupnames]
+
+    assert (
+        plot_as_barchart or len(groups) == 1
+    ), "When plotting without bar charts, you cannot plot more than one group at the same time"
     #
     for setupname, stat in statistics_dict.items():
         assert (
@@ -319,47 +329,58 @@
     if convert_x_to_digit:
         X = [float(s) for s in setupnames]
     else:
-        X = range(len(setupnames))
+        X = setupnames
 
     if fig is None:
-        fig = plt.figure()
-
-    if not convert_x_to_digit:
-        plt.xticks(X, setupnames)
+        fig = go.Figure()
 
-    plt.ylabel("Average " + metric)
-    plt.grid("major")
     # Y values are average metric values in that group and metric
     for idx, g in enumerate(groups):
         Y = [
             ValueSummary(stat.get(g, metric, remove_nones=True)).avg
             for stat in statistics_dict.values()
         ]
-        Ystd = [
-            ValueSummary(stat.get(g, metric, remove_nones=True)).std
-            for stat in statistics_dict.values()
-        ]
-        if plot_lines:
-            p = plt.plot(
-                X,
-                Y,
-                label=g if alternate_groupnames is None else alternate_groupnames[idx],
-            )
+        name = g if alternate_groupnames is None else alternate_groupnames[idx]
 
-            if plot_std:
-                plt.fill_between(
-                    X,
-                    np.subtract(Y, Ystd),
-                    np.add(Y, Ystd),
-                    alpha=0.25,
-                    edgecolor=p[-1].get_color(),
-                )
+        if plot_std:
+            Ystd = [
+                ValueSummary(stat.get(g, metric, remove_nones=True)).std
+                for stat in statistics_dict.values()
+            ]
+        else:
+            Ystd = None
 
-        if plot_dotsize is not None:
-            plt.scatter(X, Y, s=plot_dotsize)
+        if plot_as_barchart:
+            fig.add_trace(
+                go.Bar(name=name, x=X, y=Y, error_y=dict(type="data", array=Ystd))
+            )
+        else:
+            # lineplot
+            fig.add_trace(
+                go.Scatter(
+                    x=X,
+                    y=Y,
+                    mode="lines+markers",
+                    name=name,
+                    error_y=dict(type="data", array=Ystd),
+                )
+            )
 
-    plt.legend()
+    fig.update_layout(
+        autosize=False,
+        barmode="group",
+        width=width,
+        height=height,
+        showlegend=True,
+        yaxis_title=metric,
+        xaxis_title="Different setups and groups",
+        font={"family": "Arial"},
+        title=figure_title,
+    )
+    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="gray")
+    if manual_metric_range is not None:
+        fig.update_yaxes(range=[manual_metric_range[0], manual_metric_range[1]])
     return fig
diff --git a/panoptica/utils/label_group.py b/panoptica/utils/label_group.py
index 0008450..6eea08a 100644
--- a/panoptica/utils/label_group.py
+++ b/panoptica/utils/label_group.py
@@ -111,7 +111,8 @@ def _yaml_repr(cls, node):
 class LabelMergeGroup(LabelGroup):
     """Defines a group of labels that will be merged into a single label when extracted.
 
-    Inherits from LabelGroup and sets extracted labels to binary format.
+    Inherits from LabelGroup and sets extracted labels to a binary format.
+    This is useful for region evaluation (e.g. BRATS), where you want to merge multiple labels into one before evaluation.
 
     Methods:
         __call__(array): Extracts the label group as a binary array.
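For reference, a minimal usage sketch of the reworked plotly-based make_curve_over_setups changed above. The file paths, setup names, and the metric key "sq_dsc" are placeholders, and loading the statistics via Panoptica_Statistic.from_file is an assumption, not part of this patch.

    # Compare the per-group average of one metric across two hypothetical setups.
    from panoptica.panoptica_statistics import Panoptica_Statistic, make_curve_over_setups

    statistics_dict = {
        "baseline": Panoptica_Statistic.from_file("baseline.tsv"),    # placeholder path
        "new_model": Panoptica_Statistic.from_file("new_model.tsv"),  # placeholder path
    }

    fig = make_curve_over_setups(
        statistics_dict,
        metric="sq_dsc",         # assumed metric key present in the saved statistics
        plot_as_barchart=True,   # grouped bars per setup; set False for a single-group line plot
        plot_std=True,           # adds error bars from the per-group standard deviation
        figure_title="DSC across setups",
    )
    fig.show()

Note that with plot_as_barchart=False the function now asserts that only a single group is plotted at a time.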
diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py index 1e24062..aec5a06 100644 --- a/panoptica/utils/processing_pair.py +++ b/panoptica/utils/processing_pair.py @@ -7,6 +7,7 @@ from panoptica.utils.constants import _Enum_Compare from dataclasses import dataclass from panoptica.metrics import Metric +from panoptica.utils.numpy_utils import _get_smallest_fitting_uint uint_type: type = np.unsignedinteger int_type: type = np.integer @@ -26,16 +27,7 @@ class _ProcessingPair(ABC): uncropped_shape (tuple[int, ...]): The original shape of the arrays before cropping. """ - _prediction_arr: np.ndarray - _reference_arr: np.ndarray - # unique labels without zero - _ref_labels: tuple[int, ...] - _pred_labels: tuple[int, ...] - n_dim: int - - def __init__( - self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None - ) -> None: + def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> None: """Initializes the processing pair with prediction and reference arrays. Args: @@ -43,20 +35,25 @@ def __init__( reference_arr (np.ndarray): Numpy array of reference labels. dtype (type | None): The expected datatype of arrays. If None, no datatype check is performed. """ - _check_array_integrity(prediction_arr, reference_arr, dtype=dtype) - self._prediction_arr = prediction_arr - self._reference_arr = reference_arr - self.dtype = dtype - self.n_dim = reference_arr.ndim - self._ref_labels: tuple[int, ...] = tuple( + self.__prediction_arr: np.ndarray = prediction_arr + self.__reference_arr: np.ndarray = reference_arr + _check_array_integrity( + self.__prediction_arr, self.__reference_arr, dtype=int_type + ) + max_value = max(prediction_arr.max(), reference_arr.max()) + dtype = _get_smallest_fitting_uint(max_value) + self.set_dtype(dtype) + self.__dtype = dtype + self.__n_dim: int = reference_arr.ndim + self.__ref_labels: tuple[int, ...] = tuple( _unique_without_zeros(reference_arr) ) # type:ignore - self._pred_labels: tuple[int, ...] = tuple( + self.__pred_labels: tuple[int, ...] = tuple( _unique_without_zeros(prediction_arr) ) # type:ignore - self.crop: tuple[slice, ...] = None - self.is_cropped: bool = False - self.uncropped_shape: tuple[int, ...] = reference_arr.shape + self.__crop: tuple[slice, ...] = None + self.__is_cropped: bool = False + self.__uncropped_shape: tuple[int, ...] = reference_arr.shape def crop_data(self, verbose: bool = False): """Crops prediction and reference arrays to non-zero regions. @@ -64,25 +61,25 @@ def crop_data(self, verbose: bool = False): Args: verbose (bool, optional): If True, prints cropping details. Defaults to False. 
""" - if self.is_cropped: + if self.__is_cropped: return - if self.crop is None: - self.uncropped_shape = self._prediction_arr.shape - self.crop = _get_paired_crop( - self._prediction_arr, - self._reference_arr, + if self.__crop is None: + self.__uncropped_shape = self.__prediction_arr.shape + self.__crop = _get_paired_crop( + self.__prediction_arr, + self.__reference_arr, ) - self._prediction_arr = self._prediction_arr[self.crop] - self._reference_arr = self._reference_arr[self.crop] + self.__prediction_arr = self.__prediction_arr[self.__crop] + self.__reference_arr = self.__reference_arr[self.__crop] ( print( - f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}" + f"-- Cropped from {self.__uncropped_shape} to {self.__prediction_arr.shape}" ) if verbose else None ) - self.is_cropped = True + self.__is_cropped = True def uncrop_data(self, verbose: bool = False): """Restores the arrays to their original, uncropped shape. @@ -90,26 +87,26 @@ def uncrop_data(self, verbose: bool = False): Args: verbose (bool, optional): If True, prints uncropping details. Defaults to False. """ - if self.is_cropped == False: + if self.__is_cropped == False: return assert ( - self.uncropped_shape is not None + self.__uncropped_shape is not None ), "Calling uncrop_data() without having cropped first" - prediction_arr = np.zeros(self.uncropped_shape) - prediction_arr[self.crop] = self._prediction_arr - self._prediction_arr = prediction_arr + prediction_arr = np.zeros(self.__uncropped_shape) + prediction_arr[self.__crop] = self.__prediction_arr + self.__prediction_arr = prediction_arr - reference_arr = np.zeros(self.uncropped_shape) - reference_arr[self.crop] = self._reference_arr + reference_arr = np.zeros(self.__uncropped_shape) + reference_arr[self.__crop] = self.__reference_arr ( print( - f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}" + f"-- Uncropped from {self.__reference_arr.shape} to {self.__uncropped_shape}" ) if verbose else None ) - self._reference_arr = reference_arr - self.is_cropped = False + self.__reference_arr = reference_arr + self.__is_cropped = False def set_dtype(self, type): """Sets the data type for both prediction and reference arrays. @@ -120,43 +117,38 @@ def set_dtype(self, type): assert np.issubdtype( type, int_type ), "set_dtype: tried to set dtype to something other than integers" - self._prediction_arr = self._prediction_arr.astype(type) - self._reference_arr = self._reference_arr.astype(type) + self.__prediction_arr = self.__prediction_arr.astype(type) + self.__reference_arr = self.__reference_arr.astype(type) @property def prediction_arr(self): - return self._prediction_arr + return self.__prediction_arr @property def reference_arr(self): - return self._reference_arr + return self.__reference_arr @property def pred_labels(self): - return self._pred_labels + return self.__pred_labels @property def ref_labels(self): - return self._ref_labels + return self.__ref_labels + + @property + def n_dim(self): + return self.__n_dim def copy(self): """ Creates an exact copy of this object """ return type(self)( - prediction_arr=self._prediction_arr, - reference_arr=self._reference_arr, + prediction_arr=self.__prediction_arr, + reference_arr=self.__reference_arr, ) # type:ignore - # Make all variables read-only! 
- # def __setattr__(self, attr, value): - # if hasattr(self, attr): - # raise Exception("Attempting to alter read-only value") - - -# -# self.__dict__[attr] = value - class _ProcessingPairInstanced(_ProcessingPair): """Represents a processing pair with labeled instances, including unique label counts. @@ -175,7 +167,6 @@ def __init__( self, prediction_arr: np.ndarray, reference_arr: np.ndarray, - dtype: type | None, n_prediction_instance: int | None = None, n_reference_instance: int | None = None, ) -> None: @@ -188,7 +179,7 @@ def __init__( n_prediction_instance (int | None, optional): Pre-calculated number of prediction instances. n_reference_instance (int | None, optional): Pre-calculated number of reference instances. """ - super().__init__(prediction_arr, reference_arr, dtype) + super().__init__(prediction_arr, reference_arr) if n_prediction_instance is None: self.n_prediction_instance = _count_unique_without_zeros(prediction_arr) @@ -204,8 +195,8 @@ def copy(self): Creates an exact copy of this object """ return type(self)( - prediction_arr=self._prediction_arr, - reference_arr=self._reference_arr, + prediction_arr=self.prediction_arr, + reference_arr=self.reference_arr, n_prediction_instance=self.n_prediction_instance, n_reference_instance=self.n_reference_instance, ) # type:ignore @@ -237,6 +228,14 @@ def _check_array_integrity( assert ( prediction_arr.shape == reference_arr.shape ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}" + + min_value = min(prediction_arr.min(), reference_arr.min()) + assert ( + min_value >= 0 + ), "There are negative values in the semantic maps. This is not allowed!" + + # if prediction_arr.dtype != reference_arr.dtype: + # print(f"Dtype is equal in prediction and reference, got {prediction_arr.dtype},{reference_arr.dtype}. 
Intended?") # assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}" if dtype is not None: assert ( @@ -253,7 +252,7 @@ class SemanticPair(_ProcessingPair): """ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> None: - super().__init__(prediction_arr, reference_arr, dtype=int_type) + super().__init__(prediction_arr, reference_arr) class UnmatchedInstancePair(_ProcessingPairInstanced): @@ -272,7 +271,6 @@ def __init__( super().__init__( prediction_arr, reference_arr, - uint_type, n_prediction_instance, n_reference_instance, ) # type:ignore @@ -320,23 +318,22 @@ def __init__( super().__init__( prediction_arr, reference_arr, - uint_type, n_prediction_instance, n_reference_instance, ) # type:ignore if matched_instances is None: - matched_instances = [i for i in self._pred_labels if i in self._ref_labels] + matched_instances = [i for i in self.pred_labels if i in self.ref_labels] self.matched_instances = matched_instances if missed_reference_labels is None: missed_reference_labels = list( - [i for i in self._ref_labels if i not in self._pred_labels] + [i for i in self.ref_labels if i not in self.pred_labels] ) self.missed_reference_labels = missed_reference_labels if missed_prediction_labels is None: missed_prediction_labels = list( - [i for i in self._pred_labels if i not in self._ref_labels] + [i for i in self.pred_labels if i not in self.ref_labels] ) self.missed_prediction_labels = missed_prediction_labels @@ -349,8 +346,8 @@ def copy(self): Creates an exact copy of this object """ return type(self)( - prediction_arr=self._prediction_arr.copy(), - reference_arr=self._reference_arr.copy(), + prediction_arr=self.prediction_arr.copy(), + reference_arr=self.reference_arr.copy(), n_prediction_instance=self.n_prediction_instance, n_reference_instance=self.n_reference_instance, missed_reference_labels=self.missed_reference_labels, diff --git a/panoptica/utils/segmentation_class.py b/panoptica/utils/segmentation_class.py index a0a32b9..191fd5d 100644 --- a/panoptica/utils/segmentation_class.py +++ b/panoptica/utils/segmentation_class.py @@ -9,8 +9,9 @@ class SegmentationClassGroups(SupportsConfig): """Represents a collection of segmentation class groups. - This class manages groups of labels used in segmentation tasks, ensuring that each label is defined - exactly once across all groups. It supports both list and dictionary formats for group initialization. + This class manages groups of labels used in segmentation tasks. It supports both list and dictionary formats for group initialization. + SegmentationClassGroups are a collection of LabelGroups with names. So it maps a group name (str) to a LabelGroup. + A LabelGroup defines a collection of labels that belong to the same structure / region of interest. Attributes: __group_dictionary (dict[str, LabelGroup]): A dictionary mapping group names to their respective LabelGroup instances. @@ -20,9 +21,6 @@ class SegmentationClassGroups(SupportsConfig): groups (list[LabelGroup] | dict[str, LabelGroup | tuple[list[int] | int, bool]]): A list of `LabelGroup` instances or a dictionary where keys are group names (str) and values are either `LabelGroup` instances or tuples containing a list of label values and a boolean. - - Raises: - AssertionError: If the same label is assigned to multiple groups. 
""" def __init__( diff --git a/pyproject.toml b/pyproject.toml index 51863bd..3c638cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,6 @@ coverage = ">=7.0.1" pytest-mock = "^3.6.0" pandas = "^2.1.0" plotly = "^5.16.1" -matplotlib = "^3.7.3" joblib = "^1.3.2" future = "^0.18.3" flake8 = ">=4.0.1" diff --git a/unit_tests/test_datatype.py b/unit_tests/test_datatype.py index e938581..536dbf0 100644 --- a/unit_tests/test_datatype.py +++ b/unit_tests/test_datatype.py @@ -4,6 +4,7 @@ # coverage html import os import unittest +import numpy as np from panoptica.metrics import ( Metric, @@ -16,6 +17,7 @@ EdgeCaseHandler, MetricZeroTPEdgeCaseHandling, ) +from panoptica import InputType class Test_EdgeCaseHandler(unittest.TestCase): @@ -47,7 +49,7 @@ def test_edgecasehandler_simple(self): # print(t) -class Test_Datatypes(unittest.TestCase): +class Test_Enums(unittest.TestCase): def setUp(self) -> None: os.environ["PANOPTICA_CITATION_REMINDER"] = "False" return super().setUp() @@ -114,3 +116,45 @@ def test_listmetric_emptylist(self): for mode in MetricMode: with self.assertRaises(MetricCouldNotBeComputedException): lmetric[mode] + + +class Test_ProcessingPair(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_semanticpair(self): + ddtypes = [ + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ] + dtype_combinations = [(a, b) for a in ddtypes for b in ddtypes] + for da, db in dtype_combinations: + a = np.zeros([50, 50], dtype=da) + b = a.copy().astype(db) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + + for it in InputType: + # SemanticPair accepts everything + # For Unmatched and MatchedInstancePair, the numpys must be uints! 
+ it(a, b) + + c = -a + d = -b + + if c.min() < 0: + with self.assertRaises(AssertionError): + it(c, b) + if d.min() < 0: + with self.assertRaises(AssertionError): + it(a, d) + if c.min() < 0 or d.min() < 0: + with self.assertRaises(AssertionError): + it(c, d) diff --git a/unit_tests/test_panoptic_evaluator.py b/unit_tests/test_panoptic_evaluator.py index 1038723..dea5c13 100644 --- a/unit_tests/test_panoptic_evaluator.py +++ b/unit_tests/test_panoptic_evaluator.py @@ -63,7 +63,7 @@ def test_simple_evaluation(self): instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -83,7 +83,7 @@ def test_simple_evaluation_instance_multiclass(self): instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertAlmostEqual(result.global_bin_dsc, 0.8571428571428571) self.assertEqual(result.tp, 1) @@ -104,7 +104,7 @@ def test_simple_evaluation_DSC(self): instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -124,7 +124,7 @@ def test_simple_evaluation_DSC_partial(self): instance_metrics=[Metric.DSC], ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -150,7 +150,7 @@ def test_simple_evaluation_ASSD(self): ), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -172,7 +172,7 @@ def test_simple_evaluation_ASSD_negative(self): ), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -192,7 +192,7 @@ def test_pred_empty(self): instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 0) @@ -214,7 +214,7 @@ def test_no_TP_but_overlap(self): instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -237,7 +237,7 @@ def test_ref_empty(self): instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 1) @@ -258,7 +258,7 @@ def test_both_empty(self): instance_matcher=NaiveThresholdMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 0) self.assertEqual(result.fp, 0) @@ -285,22 +285,18 @@ def test_dtype_evaluation(self): a[20:40, 10:20] = 1 b[20:35, 10:20] = 2 - if da != db: - self.assertRaises(AssertionError, SemanticPair, b, a) - else: + evaluator = Panoptica_Evaluator( + expected_input=InputType.SEMANTIC, + 
instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=NaiveThresholdMatching(), + ) - evaluator = Panoptica_Evaluator( - expected_input=InputType.SEMANTIC, - instance_approximator=ConnectedComponentsInstanceApproximator(), - instance_matcher=NaiveThresholdMatching(), - ) - - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] - print(result) - self.assertEqual(result.tp, 1) - self.assertEqual(result.fp, 0) - self.assertEqual(result.sq, 0.75) - self.assertEqual(result.pq, 0.75) + result = evaluator.evaluate(b, a)["ungrouped"] + print(result) + self.assertEqual(result.tp, 1) + self.assertEqual(result.fp, 0) + self.assertEqual(result.sq, 0.75) + self.assertEqual(result.pq, 0.75) def test_simple_evaluation_maximize_matcher(self): a = np.zeros([50, 50], dtype=np.uint16) @@ -314,7 +310,7 @@ def test_simple_evaluation_maximize_matcher(self): instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -334,7 +330,7 @@ def test_simple_evaluation_maximize_matcher_overlaptwo(self): instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -356,7 +352,7 @@ def test_simple_evaluation_maximize_matcher_overlap(self): instance_matcher=MaximizeMergeMatching(), ) - result, debug_data = evaluator.evaluate(b, a)["ungrouped"] + result = evaluator.evaluate(b, a)["ungrouped"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 1) @@ -378,7 +374,7 @@ def test_single_instance_mode(self): segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}), ) - result, debug_data = evaluator.evaluate(b, a)["organ"] + result = evaluator.evaluate(b, a)["organ"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0) @@ -398,7 +394,7 @@ def test_single_instance_mode_nooverlap(self): segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}), ) - result, debug_data = evaluator.evaluate(b, a)["organ"] + result = evaluator.evaluate(b, a)["organ"] print(result) self.assertEqual(result.tp, 1) self.assertEqual(result.fp, 0)
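Taken together, these changes simplify the public API: evaluator.evaluate() now returns a plain dict mapping group name to PanopticaResult, and the intermediate arrays hang off the result object instead of being returned as a second tuple element. A minimal sketch of the updated call pattern, using toy placeholder masks; the import paths assume the classes are exported at the package top level, as the unit tests above suggest.

    import numpy as np
    from panoptica import (
        ConnectedComponentsInstanceApproximator,
        InputType,
        NaiveThresholdMatching,
        Panoptica_Evaluator,
    )

    # Toy semantic masks (placeholders, mirroring the unit tests above)
    reference_mask = np.zeros((50, 50), dtype=np.uint8)
    prediction_mask = np.zeros((50, 50), dtype=np.uint8)
    reference_mask[20:40, 10:20] = 1
    prediction_mask[20:35, 10:20] = 1

    evaluator = Panoptica_Evaluator(
        expected_input=InputType.SEMANTIC,
        instance_approximator=ConnectedComponentsInstanceApproximator(),
        instance_matcher=NaiveThresholdMatching(),
    )

    # evaluate() now maps group name -> PanopticaResult (no tuple unpacking anymore)
    result = evaluator.evaluate(prediction_mask, reference_mask, verbose=False)["ungrouped"]
    print(result)

    # Intermediate arrays are attached to the result instead of being returned separately
    steps = result.intermediate_steps_data
    if steps is not None:
        print(steps.original_prediction_arr.shape)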