diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py
index 4273fd5..3b32e63 100644
--- a/panoptica/instance_matcher.py
+++ b/panoptica/instance_matcher.py
@@ -80,7 +80,9 @@ def match_instances(
     return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap)
 
 
-def map_instance_labels(processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap) -> MatchedInstancePair:
+def map_instance_labels(
+    processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap
+) -> MatchedInstancePair:
     """
     Map instance labels based on the provided labelmap and create a MatchedInstancePair.
 
@@ -192,13 +194,20 @@ def _match_instances(
             unmatched_instance_pair.prediction_arr,
             unmatched_instance_pair.reference_arr,
         )
-        mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric)
+        mm_pairs = _calc_matching_metric_of_overlapping_labels(
+            pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric
+        )
 
         # Loop through matched instances to compute PQ components
         for matching_score, (ref_label, pred_label) in mm_pairs:
-            if labelmap.contains_or(pred_label, ref_label) and not self.allow_many_to_one:
+            if (
+                labelmap.contains_or(pred_label, ref_label)
+                and not self.allow_many_to_one
+            ):
                 continue  # -> doesnt make speed difference
-            if self.matching_metric.score_beats_threshold(matching_score, self.matching_threshold):
+            if self.matching_metric.score_beats_threshold(
+                matching_score, self.matching_threshold
+            ):
                 # Match found, increment true positive count and collect IoU and Dice values
                 labelmap.add_labelmap_entry(pred_label, ref_label)
                 # map label ref_idx to pred_idx
@@ -275,11 +284,15 @@ def _match_instances(
                 continue
             if labelmap.contains_ref(ref_label):
                 pred_labels_ = labelmap.get_pred_labels_matched_to_ref(ref_label)
-                new_score = self.new_combination_score(pred_labels_, pred_label, ref_label, unmatched_instance_pair)
+                new_score = self.new_combination_score(
+                    pred_labels_, pred_label, ref_label, unmatched_instance_pair
+                )
                 if new_score > score_ref[ref_label]:
                     labelmap.add_labelmap_entry(pred_label, ref_label)
                     score_ref[ref_label] = new_score
-            elif self.matching_metric.score_beats_threshold(matching_score, self.matching_threshold):
+            elif self.matching_metric.score_beats_threshold(
+                matching_score, self.matching_threshold
+            ):
                 # Match found, increment true positive count and collect IoU and Dice values
                 labelmap.add_labelmap_entry(pred_label, ref_label)
                 score_ref[ref_label] = matching_score
diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py
index a3d13bc..960e6af 100644
--- a/panoptica/metrics/metrics.py
+++ b/panoptica/metrics/metrics.py
@@ -37,7 +37,9 @@ def __call__(
         reference_arr = reference_arr.copy() == ref_instance_idx
         if isinstance(pred_instance_idx, int):
             pred_instance_idx = [pred_instance_idx]
-        prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx)  # type:ignore
+        prediction_arr = np.isin(
+            prediction_arr.copy(), pred_instance_idx
+        )  # type:ignore
         return self._metric_function(reference_arr, prediction_arr, *args, **kwargs)
 
     def __eq__(self, __value: object) -> bool:
@@ -61,8 +63,12 @@ def __hash__(self) -> int:
     def increasing(self):
         return not self.decreasing
 
-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )
 
 
 class DirectValueMeta(EnumMeta):
@@ -117,7 +123,9 @@ def __call__(
             **kwargs,
         )
 
-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
         """Calculates whether a score beats a specified threshold
 
         Args:
@@ -127,7 +135,9 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float
         Returns:
             bool: True if the matching_score beats the threshold, False otherwise.
         """
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )
 
     @property
     def name(self):
@@ -222,7 +232,9 @@ def __call__(self, result_obj: "PanopticaResult") -> Any:
         # ERROR
         if self._error:
             if self._error_obj is None:
-                self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed")
+                self._error_obj = MetricCouldNotBeComputedException(
+                    f"Metric {self.id} requested, but could not be computed"
+                )
             raise self._error_obj
         # Already calculated?
         if self._was_calculated:
@@ -230,8 +242,12 @@ def __call__(self, result_obj: "PanopticaResult") -> Any:
 
         # Calculate it
         try:
-            assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated"
-            assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set"
+            assert (
+                not self._was_calculated
+            ), f"Metric {self.id} was called to compute, but is set to have been already calculated"
+            assert (
+                self._calc_func is not None
+            ), f"Metric {self.id} was called to compute, but has no calculation function set"
             value = self._calc_func(result_obj)
         except MetricCouldNotBeComputedException as e:
             value = e
@@ -274,17 +290,25 @@ def __init__(
         else:
             self.AVG = None if self.ALL is None else np.average(self.ALL)
             self.SUM = None if self.ALL is None else np.sum(self.ALL)
-            self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
+            self.STD = (
+                None
+                if self.ALL is None
+                else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
+            )
 
     def __getitem__(self, mode: MetricMode | str):
         if self.error:
-            raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics")
+            raise MetricCouldNotBeComputedException(
+                f"Metric {self.id} has not been calculated, add it to your eval_metrics"
+            )
         if isinstance(mode, MetricMode):
            mode = mode.name
         if hasattr(self, mode):
             return getattr(self, mode)
         else:
-            raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member")
+            raise MetricCouldNotBeComputedException(
+                f"List_Metric {self.id} does not contain {mode} member"
+            )
 
 
 if __name__ == "__main__":
diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py
index 744f61d..1daeda9 100644
--- a/panoptica/panoptic_evaluator.py
+++ b/panoptica/panoptic_evaluator.py
@@ -20,7 +20,9 @@ class Panoptic_Evaluator:
 
     def __init__(
         self,
-        expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair,
+        expected_input: (
+            Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair]
+        ) = MatchedInstancePair,
         instance_approximator: InstanceApproximator | None = None,
         instance_matcher: InstanceMatchingAlgorithm | None = None,
         edge_case_handler: EdgeCaseHandler | None = None,
@@ -46,9 +48,13 @@ def __init__(
         self.__decision_metric = decision_metric
         self.__decision_threshold = decision_threshold
 
-        self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        self.__edge_case_handler = (
+            edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        )
         if self.__decision_metric is not None:
-            assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it"
+            assert (
+                self.__decision_threshold is not None
+            ), "decision metric set but no decision threshold for it"
 
         # self.__log_times = log_times
         self.__verbose = verbose
@@ -57,11 +63,15 @@ def __init__(
     @measure_time
     def evaluate(
         self,
-        processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
+        processing_pair: (
+            SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult
+        ),
         result_all: bool = True,
         verbose: bool | None = None,
     ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
-        assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}"
+        assert (
+            type(processing_pair) == self.__expected_input
+        ), f"input not of expected type {self.__expected_input}"
         return panoptic_evaluate(
             processing_pair=processing_pair,
             edge_case_handler=self.__edge_case_handler,
@@ -77,7 +87,9 @@ def evaluate(
 
 
 def panoptic_evaluate(
-    processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
+    processing_pair: (
+        SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult
+    ),
     instance_approximator: InstanceApproximator | None = None,
     instance_matcher: InstanceMatchingAlgorithm | None = None,
     eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
@@ -130,7 +142,9 @@ def panoptic_evaluate(
     processing_pair.crop_data()
 
     if isinstance(processing_pair, SemanticPair):
-        assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator"
+        assert (
+            instance_approximator is not None
+        ), "Got SemanticPair but not InstanceApproximator"
         print("-- Got SemanticPair, will approximate instances")
         start = perf_counter()
         processing_pair = instance_approximator.approximate_instances(processing_pair)
@@ -148,7 +162,9 @@ def panoptic_evaluate(
 
     if isinstance(processing_pair, UnmatchedInstancePair):
         print("-- Got UnmatchedInstancePair, will match instances")
-        assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
+        assert (
+            instance_matcher is not None
+        ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
         start = perf_counter()
         processing_pair = instance_matcher.match_instances(
             processing_pair,
diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py
index 2e3076f..0abbba9 100644
--- a/panoptica/panoptic_result.py
+++ b/panoptica/panoptic_result.py
@@ -244,7 +244,9 @@ def __init__(
                 num_pred_instances=self.num_pred_instances,
                 num_ref_instances=self.num_ref_instances,
             )
-            self._list_metrics[k] = Evaluation_List_Metric(k, empty_list_std, v, is_edge_case, edge_case_result)
+            self._list_metrics[k] = Evaluation_List_Metric(
+                k, empty_list_std, v, is_edge_case, edge_case_result
+            )
 
     def _add_metric(
         self,
@@ -313,13 +315,19 @@ def __str__(self) -> str:
         return text
 
     def to_dict(self) -> dict:
-        return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)}
+        return {
+            k: getattr(self, v.id)
+            for k, v in self._evaluation_metrics.items()
+            if (v._error == False and v._was_calculated)
+        }
 
     def get_list_metric(self, metric: Metric, mode: MetricMode):
         if metric in self._list_metrics:
             return self._list_metrics[metric][mode]
         else:
-            raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?")
+            raise MetricCouldNotBeComputedException(
+                f"{metric} could not be found, have you set it in eval_metrics during evaluation?"
+            )
 
     def _calc_metric(self, metric_name: str, supress_error: bool = False):
         if metric_name in self._evaluation_metrics:
@@ -335,7 +343,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False):
             self._evaluation_metrics[metric_name]._was_calculated = True
             return value
         else:
-            raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}")
+            raise MetricCouldNotBeComputedException(
+                f"could not find metric with name {metric_name}"
+            )
 
     def __getattribute__(self, __name: str) -> Any:
         attr = None
@@ -348,7 +358,9 @@ def __getattribute__(self, __name: str) -> Any:
                 raise e
         if attr is None:
             if self._evaluation_metrics[__name]._error:
-                raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed")
+                raise MetricCouldNotBeComputedException(
+                    f"Requested metric {__name} that could not be computed"
+                )
             elif not self._evaluation_metrics[__name]._was_calculated:
                 value = self._calc_metric(__name)
                 setattr(self, __name, value)
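
Note (not part of the patch): every hunk above is a formatting-only re-wrap, so no statement changes semantics. For readers verifying that the re-wrapped comparison in panoptica/metrics/metrics.py still means the same thing, the sketch below restates the score_beats_threshold logic as a standalone function; the function name and the example values are illustrative only and are not taken from the library.

# Standalone restatement of the threshold check seen in metrics.py:
# "increasing" metrics (higher is better, e.g. Dice/IoU) must reach the
# threshold from below, "decreasing" metrics (lower is better, e.g. ASSD)
# must reach it from above.
def beats_threshold(
    matching_score: float, matching_threshold: float, decreasing: bool
) -> bool:
    increasing = not decreasing
    return (increasing and matching_score >= matching_threshold) or (
        decreasing and matching_score <= matching_threshold
    )

assert beats_threshold(0.7, 0.5, decreasing=False)      # IoU-style score clears the threshold
assert beats_threshold(1.2, 2.0, decreasing=True)       # distance-style score is small enough
assert not beats_threshold(0.3, 0.5, decreasing=False)  # below threshold for an increasing metric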