From 5808e462b2f5821f9c47dae49b847f6fc61301b7 Mon Sep 17 00:00:00 2001
From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com>
Date: Wed, 7 Aug 2024 15:42:48 +0000
Subject: [PATCH] Autoformat with black

---
 examples/example_spine_instance_config.py |  4 +-
 examples/example_spine_semantic.py        |  4 +-
 panoptica/instance_evaluator.py           | 18 ++++--
 panoptica/metrics/metrics.py              | 70 +++++++++++++++++------
 panoptica/panoptica_evaluator.py          | 43 ++++++++++----
 panoptica/panoptica_result.py             | 26 +++++++--
 panoptica/utils/edge_case_handling.py     | 36 +++++++++---
 panoptica/utils/processing_pair.py        | 60 ++++++++++++++-----
 unit_tests/test_panoptic_result.py        |  4 +-
 9 files changed, 205 insertions(+), 60 deletions(-)

diff --git a/examples/example_spine_instance_config.py b/examples/example_spine_instance_config.py
index 2855add..56c61e4 100644
--- a/examples/example_spine_instance_config.py
+++ b/examples/example_spine_instance_config.py
@@ -10,7 +10,9 @@
 reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")
 prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")
 
-evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance")
+evaluator = Panoptica_Evaluator.load_from_config_name(
+    "panoptica_evaluator_unmatched_instance"
+)
 
 
 with cProfile.Profile() as pr:
diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py
index 2793592..a2e5a32 100644
--- a/examples/example_spine_semantic.py
+++ b/examples/example_spine_semantic.py
@@ -25,7 +25,9 @@ with cProfile.Profile() as pr:
 
 
     if __name__ == "__main__":
-        result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
+        result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[
+            "ungrouped"
+        ]
         print(result)
 
     pr.dump_stats(directory + "/semantic_example.log")
diff --git a/panoptica/instance_evaluator.py b/panoptica/instance_evaluator.py
index 9b784e7..303abf0 100644
--- a/panoptica/instance_evaluator.py
+++ b/panoptica/instance_evaluator.py
@@ -28,7 +28,9 @@ def evaluate_matched_instance(
     >>> result = map_instance_labels(unmatched_instance_pair, labelmap)
     """
     if decision_metric is not None:
-        assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics"
+        assert decision_metric.name in [
+            v.name for v in eval_metrics
+        ], "decision metric not contained in eval_metrics"
         assert decision_threshold is not None, "decision metric set but no threshold"
     # Initialize variables for True Positives (tp)
     tp = len(matched_instance_pair.matched_instances)
@@ -40,13 +42,21 @@
     )
     ref_matched_labels = matched_instance_pair.matched_instances
 
-    instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels]
+    instance_pairs = [
+        (reference_arr, prediction_arr, ref_idx, eval_metrics)
+        for ref_idx in ref_matched_labels
+    ]
     with Pool() as pool:
-        metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs)
+        metric_dicts: list[dict[Metric, float]] = pool.starmap(
+            _evaluate_instance, instance_pairs
+        )
 
     for metric_dict in metric_dicts:
         if decision_metric is None or (
-            decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold)
+            decision_threshold is not None
+            and decision_metric.score_beats_threshold(
+                metric_dict[decision_metric], decision_threshold
+            )
         ):
             for k, v in metric_dict.items():
                 score_dict[k].append(v)
diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py
index 4f02e48..5cbce45 100644
--- a/panoptica/metrics/metrics.py
+++ b/panoptica/metrics/metrics.py
@@ -40,7 +40,9 @@ def __call__(
         reference_arr = reference_arr.copy() == ref_instance_idx
         if isinstance(pred_instance_idx, int):
             pred_instance_idx = [pred_instance_idx]
-        prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx)  # type:ignore
+        prediction_arr = np.isin(
+            prediction_arr.copy(), pred_instance_idx
+        )  # type:ignore
         return self._metric_function(reference_arr, prediction_arr, *args, **kwargs)
 
     def __eq__(self, __value: object) -> bool:
@@ -64,8 +66,12 @@ def __hash__(self) -> int:
     def increasing(self):
         return not self.decreasing
 
-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )
 
 
 class DirectValueMeta(EnumMeta):
@@ -88,9 +94,19 @@ class Metric(_Enum_Compare):
 
     DSC = _Metric("DSC", "Dice", False, _compute_instance_volumetric_dice)
     IOU = _Metric("IOU", "Intersection over Union", False, _compute_instance_iou)
-    ASSD = _Metric("ASSD", "Average Symmetric Surface Distance", True, _compute_instance_average_symmetric_surface_distance)
+    ASSD = _Metric(
+        "ASSD",
+        "Average Symmetric Surface Distance",
+        True,
+        _compute_instance_average_symmetric_surface_distance,
+    )
     clDSC = _Metric("clDSC", "Centerline Dice", False, _compute_centerline_dice)
-    RVD = _Metric("RVD", "Relative Volume Difference", True, _compute_instance_relative_volume_difference)
+    RVD = _Metric(
+        "RVD",
+        "Relative Volume Difference",
+        True,
+        _compute_instance_relative_volume_difference,
+    )
     # ST = _Metric("ST", False, _compute_instance_segmentation_tendency)
 
     def __call__(
@@ -122,7 +138,9 @@ def __call__(
             **kwargs,
         )
 
-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
         """Calculates whether a score beats a specified threshold
 
         Args:
@@ -132,7 +150,9 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float
         Returns:
             bool: True if the matching_score beats the threshold, False otherwise.
         """
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )
 
     @property
     def name(self):
@@ -229,7 +249,9 @@ def __call__(self, result_obj: "PanopticaResult") -> Any:
         # ERROR
         if self._error:
             if self._error_obj is None:
-                self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed")
+                self._error_obj = MetricCouldNotBeComputedException(
+                    f"Metric {self.id} requested, but could not be computed"
+                )
             raise self._error_obj
         # Already calculated?
         if self._was_calculated:
@@ -237,8 +259,12 @@ def __call__(self, result_obj: "PanopticaResult") -> Any:
 
         # Calculate it
         try:
-            assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated"
-            assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set"
+            assert (
+                not self._was_calculated
+            ), f"Metric {self.id} was called to compute, but is set to have been already calculated"
+            assert (
+                self._calc_func is not None
+            ), f"Metric {self.id} was called to compute, but has no calculation function set"
             value = self._calc_func(result_obj)
         except MetricCouldNotBeComputedException as e:
             value = e
@@ -283,20 +309,32 @@ def __init__(
         else:
             self.AVG = None if self.ALL is None else np.average(self.ALL)
             self.SUM = None if self.ALL is None else np.sum(self.ALL)
-            self.MIN = None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL)
-            self.MAX = None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL)
-
-            self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
+            self.MIN = (
+                None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL)
+            )
+            self.MAX = (
+                None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL)
+            )
+
+            self.STD = (
+                None
+                if self.ALL is None
+                else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
+            )
 
     def __getitem__(self, mode: MetricMode | str):
         if self.error:
-            raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics")
+            raise MetricCouldNotBeComputedException(
+                f"Metric {self.id} has not been calculated, add it to your eval_metrics"
+            )
         if isinstance(mode, MetricMode):
            mode = mode.name
         if hasattr(self, mode):
             return getattr(self, mode)
         else:
-            raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member")
+            raise MetricCouldNotBeComputedException(
+                f"List_Metric {self.id} does not contain {mode} member"
+            )
 
 
 if __name__ == "__main__":
diff --git a/panoptica/panoptica_evaluator.py b/panoptica/panoptica_evaluator.py
index 2ffdebf..f0522f8 100644
--- a/panoptica/panoptica_evaluator.py
+++ b/panoptica/panoptica_evaluator.py
@@ -31,7 +31,12 @@ def __init__(
         instance_matcher: InstanceMatchingAlgorithm | None = None,
         edge_case_handler: EdgeCaseHandler | None = None,
         segmentation_class_groups: SegmentationClassGroups | None = None,
-        instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD],
+        instance_metrics: list[Metric] = [
+            Metric.DSC,
+            Metric.IOU,
+            Metric.ASSD,
+            Metric.RVD,
+        ],
         global_metrics: list[Metric] = [Metric.DSC],
         decision_metric: Metric | None = None,
         decision_threshold: float | None = None,
@@ -65,9 +70,13 @@ def __init__(
 
         self.__segmentation_class_groups = segmentation_class_groups
 
-        self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        self.__edge_case_handler = (
+            edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        )
         if self.__decision_metric is not None:
-            assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it"
+            assert (
+                self.__decision_threshold is not None
+            ), "decision metric set but no decision threshold for it"
 
         # self.__log_times = log_times
         self.__verbose = verbose
@@ -98,7 +107,9 @@ def evaluate(
         verbose: bool | None = None,
     ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]:
         processing_pair = self.__expected_input(prediction_arr, reference_arr)
-        assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}"
+        assert isinstance(
+            processing_pair, self.__expected_input.value
+        ), f"input not of expected type {self.__expected_input}"
 
         if self.__segmentation_class_groups is None:
             return {
@@ -118,8 +129,12 @@ def evaluate(
                 )
             }
 
-        self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True)
-        self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True)
+        self.__segmentation_class_groups.has_defined_labels_for(
+            processing_pair.prediction_arr, raise_error=True
+        )
+        self.__segmentation_class_groups.has_defined_labels_for(
+            processing_pair.reference_arr, raise_error=True
+        )
 
         result_grouped = {}
         for group_name, label_group in self.__segmentation_class_groups.items():
@@ -131,7 +146,9 @@ def evaluate(
             single_instance_mode = label_group.single_instance
             processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped)  # type: ignore
             decision_threshold = self.__decision_threshold
-            if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair):
+            if single_instance_mode and not isinstance(
+                processing_pair, MatchedInstancePair
+            ):
                 processing_pair_grouped = MatchedInstancePair(
                     prediction_arr=processing_pair_grouped.prediction_arr,
                     reference_arr=processing_pair_grouped.reference_arr,
@@ -156,7 +173,9 @@ def evaluate(
 
 
 def panoptic_evaluate(
-    processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
+    processing_pair: (
+        SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult
+    ),
     instance_approximator: InstanceApproximator | None = None,
     instance_matcher: InstanceMatchingAlgorithm | None = None,
     instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
@@ -213,7 +232,9 @@ def panoptic_evaluate(
     processing_pair.crop_data()
 
     if isinstance(processing_pair, SemanticPair):
-        assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator"
+        assert (
+            instance_approximator is not None
+        ), "Got SemanticPair but not InstanceApproximator"
         if verbose:
             print("-- Got SemanticPair, will approximate instances")
         start = perf_counter()
@@ -234,7 +255,9 @@ def panoptic_evaluate(
     if isinstance(processing_pair, UnmatchedInstancePair):
         if verbose:
             print("-- Got UnmatchedInstancePair, will match instances")
-        assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
+        assert (
+            instance_matcher is not None
+        ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
         start = perf_counter()
         processing_pair = instance_matcher.match_instances(
             processing_pair,
diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py
index e87a33b..ef631f9 100644
--- a/panoptica/panoptica_result.py
+++ b/panoptica/panoptica_result.py
@@ -249,7 +249,9 @@ def __init__(
                 num_pred_instances=self.num_pred_instances,
                 num_ref_instances=self.num_ref_instances,
             )
-            self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result)
+            self._list_metrics[m] = Evaluation_List_Metric(
+                m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result
+            )
             # even if not available, set the global vars
             self._add_metric(
                 f"global_bin_{m.name.lower()}",
@@ -327,13 +329,19 @@ def __str__(self) -> str:
         return text
 
     def to_dict(self) -> dict:
-        return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)}
+        return {
+            k: getattr(self, v.id)
+            for k, v in self._evaluation_metrics.items()
+            if (v._error == False and v._was_calculated)
+        }
 
     def get_list_metric(self, metric: Metric, mode: MetricMode):
         if metric in self._list_metrics:
             return self._list_metrics[metric][mode]
         else:
-            raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?")
+            raise MetricCouldNotBeComputedException(
+                f"{metric} could not be found, have you set it in eval_metrics during evaluation?"
+            )
 
     def _calc_metric(self, metric_name: str, supress_error: bool = False):
         if metric_name in self._evaluation_metrics:
@@ -349,7 +357,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False):
             self._evaluation_metrics[metric_name]._was_calculated = True
             return value
         else:
-            raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}")
+            raise MetricCouldNotBeComputedException(
+                f"could not find metric with name {metric_name}"
+            )
 
     def __getattribute__(self, __name: str) -> Any:
         attr = None
@@ -362,7 +372,9 @@ def __getattribute__(self, __name: str) -> Any:
                 raise e
         if attr is None:
             if self._evaluation_metrics[__name]._error:
-                raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed")
+                raise MetricCouldNotBeComputedException(
+                    f"Requested metric {__name} that could not be computed"
+                )
             elif not self._evaluation_metrics[__name]._was_calculated:
                 value = self._calc_metric(__name)
                 setattr(self, __name, value)
@@ -488,7 +500,9 @@ def function_template(res: PanopticaResult):
     if metric not in res._global_metrics:
         raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set")
     if res.tp == 0:
-        is_edgecase, result = res._edge_case_handler.handle_zero_tp(metric, res.tp, res.num_pred_instances, res.num_ref_instances)
+        is_edgecase, result = res._edge_case_handler.handle_zero_tp(
+            metric, res.tp, res.num_pred_instances, res.num_ref_instances
+        )
         if is_edgecase:
             return result
     pred_binary = res._prediction_arr.copy()
diff --git a/panoptica/utils/edge_case_handling.py b/panoptica/utils/edge_case_handling.py
index d6a0e02..b88f49e 100644
--- a/panoptica/utils/edge_case_handling.py
+++ b/panoptica/utils/edge_case_handling.py
@@ -41,12 +41,26 @@ def __init__(
         self._default_result = default_result
 
         self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {}
-        self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result
-        self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result
-        self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result
-        self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result
+        self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = (
+            empty_prediction_result
+            if empty_prediction_result is not None
+            else default_result
+        )
+        self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = (
+            empty_reference_result
+            if empty_reference_result is not None
+            else default_result
+        )
+        self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = (
+            no_instances_result if no_instances_result is not None else default_result
+        )
+        self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = (
+            normal if normal is not None else default_result
+        )
 
-    def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]:
+    def __call__(
+        self, tp: int, num_pred_instances, num_ref_instances
+    ) -> tuple[bool, float | None]:
         if tp != 0:
             return False, EdgeCaseResult.NONE.value
         #
@@ -117,7 +131,9 @@ def __init__(
         },
         empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN,
     ) -> None:
-        self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling
+        self.__listmetric_zeroTP_handling: dict[
+            Metric, MetricZeroTPEdgeCaseHandling
+        ] = listmetric_zeroTP_handling
         self.__empty_list_std: EdgeCaseResult = empty_list_std
 
     def handle_zero_tp(
@@ -144,7 +160,9 @@ def handle_zero_tp(
         if tp != 0:
             return False, EdgeCaseResult.NONE.value
         if metric not in self.__listmetric_zeroTP_handling:
-            raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available")
+            raise NotImplementedError(
+                f"Metric {metric} encountered zero TP, but no edge handling available"
+            )
 
         return self.__listmetric_zeroTP_handling[metric](
             tp=tp,
@@ -181,7 +199,9 @@ def _yaml_repr(cls, node) -> dict:
     print()
 
     # print(handler.get_metric_zero_tp_handle(ListMetric.IOU))
-    r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1)
+    r = handler.handle_zero_tp(
+        Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1
+    )
     print(r)
 
     iou_test = MetricZeroTPEdgeCaseHandling(
diff --git a/panoptica/utils/processing_pair.py b/panoptica/utils/processing_pair.py
index b89d5ad..e0afa86 100644
--- a/panoptica/utils/processing_pair.py
+++ b/panoptica/utils/processing_pair.py
@@ -25,7 +25,9 @@ class _ProcessingPair(ABC):
     _pred_labels: tuple[int, ...]
     n_dim: int
 
-    def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None:
+    def __init__(
+        self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None
+    ) -> None:
         """Initializes a general Processing Pair
 
         Args:
@@ -38,8 +40,12 @@ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype:
         self._reference_arr = reference_arr
         self.dtype = dtype
         self.n_dim = reference_arr.ndim
-        self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr))  # type:ignore
-        self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr))  # type:ignore
+        self._ref_labels: tuple[int, ...] = tuple(
+            _unique_without_zeros(reference_arr)
+        )  # type:ignore
+        self._pred_labels: tuple[int, ...] = tuple(
+            _unique_without_zeros(prediction_arr)
+        )  # type:ignore
         self.crop: tuple[slice, ...] = None
         self.is_cropped: bool = False
         self.uncropped_shape: tuple[int, ...] = reference_arr.shape
@@ -56,25 +62,41 @@ def crop_data(self, verbose: bool = False):
 
         self._prediction_arr = self._prediction_arr[self.crop]
         self._reference_arr = self._reference_arr[self.crop]
-        (print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None)
+        (
+            print(
+                f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}"
+            )
+            if verbose
+            else None
+        )
         self.is_cropped = True
 
     def uncrop_data(self, verbose: bool = False):
         if self.is_cropped == False:
             return
-        assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first"
+        assert (
+            self.uncropped_shape is not None
+        ), "Calling uncrop_data() without having cropped first"
         prediction_arr = np.zeros(self.uncropped_shape)
         prediction_arr[self.crop] = self._prediction_arr
         self._prediction_arr = prediction_arr
 
         reference_arr = np.zeros(self.uncropped_shape)
         reference_arr[self.crop] = self._reference_arr
-        (print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}") if verbose else None)
+        (
+            print(
+                f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}"
+            )
+            if verbose
+            else None
+        )
         self._reference_arr = reference_arr
         self.is_cropped = False
 
     def set_dtype(self, type):
-        assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers"
+        assert np.issubdtype(
+            type, int_type
+        ), "set_dtype: tried to set dtype to something other than integers"
         self._prediction_arr = self._prediction_arr.astype(type)
         self._reference_arr = self._reference_arr.astype(type)
 
@@ -153,7 +175,9 @@ def copy(self):
         )  # type:ignore
 
 
-def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None):
+def _check_array_integrity(
+    prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None
+):
     """
     Check the integrity of two numpy arrays.
 
@@ -175,8 +199,12 @@ def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray
     assert isinstance(prediction_arr, np.ndarray) and isinstance(
         reference_arr, np.ndarray
     ), "prediction and/or reference are not numpy arrays"
-    assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}"
-    assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}"
+    assert (
+        prediction_arr.shape == reference_arr.shape
+    ), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}"
+    assert (
+        prediction_arr.dtype == reference_arr.dtype
+    ), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}"
     if dtype is not None:
         assert (
             np.issubdtype(prediction_arr.dtype, dtype)
@@ -259,11 +287,15 @@ def __init__(
         self.matched_instances = matched_instances
 
         if missed_reference_labels is None:
-            missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels])
+            missed_reference_labels = list(
+                [i for i in self._ref_labels if i not in self._pred_labels]
+            )
         self.missed_reference_labels = missed_reference_labels
 
         if missed_prediction_labels is None:
-            missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels])
+            missed_prediction_labels = list(
+                [i for i in self._pred_labels if i not in self._ref_labels]
+            )
         self.missed_prediction_labels = missed_prediction_labels
 
     @property
@@ -300,5 +332,7 @@ class InputType(_Enum_Compare):
     UNMATCHED_INSTANCE = UnmatchedInstancePair
     MATCHED_INSTANCE = MatchedInstancePair
 
-    def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair:
+    def __call__(
+        self, prediction_arr: np.ndarray, reference_arr: np.ndarray
+    ) -> _ProcessingPair:
         return self.value(prediction_arr, reference_arr)
diff --git a/unit_tests/test_panoptic_result.py b/unit_tests/test_panoptic_result.py
index 883c604..200d084 100644
--- a/unit_tests/test_panoptic_result.py
+++ b/unit_tests/test_panoptic_result.py
@@ -86,7 +86,9 @@ def test_existing_metrics(self):
 
         def powerset(iterable):
             s = list(iterable)
-            return list(chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)))
+            return list(
+                chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+            )
 
         power_set = powerset([Metric.DSC, Metric.IOU, Metric.ASSD])
         for m in power_set[1:]: