From dbb5c3e3ce8caca9fe7b64d3ac2d55cf558ed4cf Mon Sep 17 00:00:00 2001 From: iback Date: Tue, 13 Feb 2024 09:14:29 +0000 Subject: [PATCH 1/4] added precision and recall, added NO_PRINT metric type so they don't get printed by default --- panoptica/metrics/metrics.py | 47 +++++++++------------------------ panoptica/panoptic_evaluator.py | 34 +++++++----------------- panoptica/panoptic_result.py | 46 ++++++++++++++++++++------------ 3 files changed, 50 insertions(+), 77 deletions(-) diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index 57432d7..a3d13bc 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -37,9 +37,7 @@ def __call__( reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): pred_instance_idx = [pred_instance_idx] - prediction_arr = np.isin( - prediction_arr.copy(), pred_instance_idx - ) # type:ignore + prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) # type:ignore return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: @@ -63,12 +61,8 @@ def __hash__(self) -> int: def increasing(self): return not self.decreasing - def score_beats_threshold( - self, matching_score: float, matching_threshold: float - ) -> bool: - return (self.increasing and matching_score >= matching_threshold) or ( - self.decreasing and matching_score <= matching_threshold - ) + def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) class DirectValueMeta(EnumMeta): @@ -123,9 +117,7 @@ def __call__( **kwargs, ) - def score_beats_threshold( - self, matching_score: float, matching_threshold: float - ) -> bool: + def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: """Calculates whether a score beats a specified threshold Args: @@ -135,9 +127,7 @@ def score_beats_threshold( Returns: bool: True if the matching_score beats the threshold, False otherwise. """ - return (self.increasing and matching_score >= matching_threshold) or ( - self.decreasing and matching_score <= matching_threshold - ) + return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) @property def name(self): @@ -175,6 +165,7 @@ class MetricType(_Enum_Compare): _Enum_Compare (_type_): _description_ """ + NO_PRINT = auto() MATCHING = auto() GLOBAL = auto() INSTANCE = auto() @@ -231,9 +222,7 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # ERROR if self._error: if self._error_obj is None: - self._error_obj = MetricCouldNotBeComputedException( - f"Metric {self.id} requested, but could not be computed" - ) + self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed") raise self._error_obj # Already calculated? 
if self._was_calculated: @@ -241,12 +230,8 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # Calculate it try: - assert ( - not self._was_calculated - ), f"Metric {self.id} was called to compute, but is set to have been already calculated" - assert ( - self._calc_func is not None - ), f"Metric {self.id} was called to compute, but has no calculation function set" + assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated" + assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set" value = self._calc_func(result_obj) except MetricCouldNotBeComputedException as e: value = e @@ -289,25 +274,17 @@ def __init__( else: self.AVG = None if self.ALL is None else np.average(self.ALL) self.SUM = None if self.ALL is None else np.sum(self.ALL) - self.STD = ( - None - if self.ALL is None - else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) - ) + self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) def __getitem__(self, mode: MetricMode | str): if self.error: - raise MetricCouldNotBeComputedException( - f"Metric {self.id} has not been calculated, add it to your eval_metrics" - ) + raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics") if isinstance(mode, MetricMode): mode = mode.name if hasattr(self, mode): return getattr(self, mode) else: - raise MetricCouldNotBeComputedException( - f"List_Metric {self.id} does not contain {mode} member" - ) + raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member") if __name__ == "__main__": diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py index 0119fbc..744f61d 100644 --- a/panoptica/panoptic_evaluator.py +++ b/panoptica/panoptic_evaluator.py @@ -20,9 +20,7 @@ class Panoptic_Evaluator: def __init__( self, - expected_input: ( - Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] - ) = MatchedInstancePair, + expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, @@ -48,13 +46,9 @@ def __init__( self.__decision_metric = decision_metric self.__decision_threshold = decision_threshold - self.__edge_case_handler = ( - edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() - ) + self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() if self.__decision_metric is not None: - assert ( - self.__decision_threshold is not None - ), "decision metric set but no decision threshold for it" + assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -63,15 +57,11 @@ def __init__( @measure_time def evaluate( self, - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, result_all: bool = True, verbose: bool | None = None, ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]: - assert ( - type(processing_pair) == self.__expected_input - ), f"input not of expected type {self.__expected_input}" + assert 
type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}" return panoptic_evaluate( processing_pair=processing_pair, edge_case_handler=self.__edge_case_handler, @@ -87,9 +77,7 @@ def evaluate( def panoptic_evaluate( - processing_pair: ( - SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult - ), + processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -142,11 +130,8 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert ( - instance_approximator is not None - ), "Got SemanticPair but not InstanceApproximator" + assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" print("-- Got SemanticPair, will approximate instances") - processing_pair = instance_approximator.approximate_instances(processing_pair) start = perf_counter() processing_pair = instance_approximator.approximate_instances(processing_pair) if log_times: @@ -163,9 +148,7 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): print("-- Got UnmatchedInstancePair, will match instances") - assert ( - instance_matcher is not None - ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, @@ -185,6 +168,7 @@ def panoptic_evaluate( if isinstance(processing_pair, MatchedInstancePair): print("-- Got MatchedInstancePair, will evaluate instances") + start = perf_counter() processing_pair = evaluate_matched_instance( processing_pair, eval_metrics=eval_metrics, diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py index 2b4b79b..4498c5f 100644 --- a/panoptica/panoptic_result.py +++ b/panoptica/panoptic_result.py @@ -94,6 +94,20 @@ def __init__( fn, long_name="False Negatives", ) + self.prec: int + self._add_metric( + "prec", + MetricType.NO_PRINT, + prec, + long_name="Precision (positive predictive value)", + ) + self.rec: int + self._add_metric( + "rec", + MetricType.NO_PRINT, + rec, + long_name="Recall (sensitivity)", + ) self.rq: float self._add_metric( "rq", @@ -221,9 +235,7 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[k] = Evaluation_List_Metric( - k, empty_list_std, v, is_edge_case, edge_case_result - ) + self._list_metrics[k] = Evaluation_List_Metric(k, empty_list_std, v, is_edge_case, edge_case_result) def _add_metric( self, @@ -270,6 +282,8 @@ def calculate_all(self, print_errors: bool = False): def __str__(self) -> str: text = "" for metric_type in MetricType: + if metric_type == MetricType.NO_PRINT: + continue text += f"\n+++ {metric_type.name} +++\n" for k, v in self._evaluation_metrics.items(): if v.metric_type != metric_type: @@ -290,19 +304,13 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: - return { - k: getattr(self, v.id) - for k, v in self._evaluation_metrics.items() - if (v._error == False and v._was_calculated) - } + return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} def get_list_metric(self, metric: Metric, mode: MetricMode): if 
metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException( - f"{metric} could not be found, have you set it in eval_metrics during evaluation?" - ) + raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -318,9 +326,7 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException( - f"could not find metric with name {metric_name}" - ) + raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") def __getattribute__(self, __name: str) -> Any: attr = None @@ -333,9 +339,7 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException( - f"Requested metric {__name} that could not be computed" - ) + raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value) @@ -360,6 +364,14 @@ def fn(res: PanopticaResult): return res.num_ref_instances - res.tp +def prec(res: PanopticaResult): + return res.tp / (res.tp + res.fp) + + +def rec(res: PanopticaResult): + return res.tp / (res.tp + res.fn) + + def rq(res: PanopticaResult): """ Calculate the Recognition Quality (RQ) based on TP, FP, and FN. From fe7c3b283668c330b56e9110a6df72dff6daca21 Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 18 Apr 2024 09:35:24 +0000 Subject: [PATCH 2/4] formatting --- panoptica/instance_matcher.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 3b32e63..4273fd5 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -80,9 +80,7 @@ def match_instances( return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap) -def map_instance_labels( - processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap -) -> MatchedInstancePair: +def map_instance_labels(processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap) -> MatchedInstancePair: """ Map instance labels based on the provided labelmap and create a MatchedInstancePair. 
@@ -194,20 +192,13 @@ def _match_instances( unmatched_instance_pair.prediction_arr, unmatched_instance_pair.reference_arr, ) - mm_pairs = _calc_matching_metric_of_overlapping_labels( - pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric - ) + mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric) # Loop through matched instances to compute PQ components for matching_score, (ref_label, pred_label) in mm_pairs: - if ( - labelmap.contains_or(pred_label, ref_label) - and not self.allow_many_to_one - ): + if labelmap.contains_or(pred_label, ref_label) and not self.allow_many_to_one: continue # -> doesnt make speed difference - if self.matching_metric.score_beats_threshold( - matching_score, self.matching_threshold - ): + if self.matching_metric.score_beats_threshold(matching_score, self.matching_threshold): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) # map label ref_idx to pred_idx @@ -284,15 +275,11 @@ def _match_instances( continue if labelmap.contains_ref(ref_label): pred_labels_ = labelmap.get_pred_labels_matched_to_ref(ref_label) - new_score = self.new_combination_score( - pred_labels_, pred_label, ref_label, unmatched_instance_pair - ) + new_score = self.new_combination_score(pred_labels_, pred_label, ref_label, unmatched_instance_pair) if new_score > score_ref[ref_label]: labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = new_score - elif self.matching_metric.score_beats_threshold( - matching_score, self.matching_threshold - ): + elif self.matching_metric.score_beats_threshold(matching_score, self.matching_threshold): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = matching_score From 35beb13cfb20e6592a343977dd403618f9f99c2d Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 18 Apr 2024 09:41:17 +0000 Subject: [PATCH 3/4] added global assd --- examples/example_spine_semantic.py | 1 + panoptica/panoptic_result.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/examples/example_spine_semantic.py b/examples/example_spine_semantic.py index 8bf993d..57e6a04 100644 --- a/examples/example_spine_semantic.py +++ b/examples/example_spine_semantic.py @@ -22,6 +22,7 @@ expected_input=SemanticPair, instance_approximator=ConnectedComponentsInstanceApproximator(), instance_matcher=NaiveThresholdMatching(), + verbose=True, ) with cProfile.Profile() as pr: diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py index 4498c5f..2e3076f 100644 --- a/panoptica/panoptic_result.py +++ b/panoptica/panoptic_result.py @@ -13,6 +13,7 @@ MetricType, _compute_centerline_dice_coefficient, _compute_dice_coefficient, + _average_symmetric_surface_distance, ) from panoptica.utils import EdgeCaseHandler @@ -133,6 +134,14 @@ def __init__( global_bin_cldsc, long_name="Global Binary Centerline Dice", ) + # + self.global_bin_assd: int + self._add_metric( + "global_bin_assd", + MetricType.GLOBAL, + global_bin_assd, + long_name="Global Binary Average Symmetric Surface Distance", + ) # endregion # # region IOU @@ -468,6 +477,16 @@ def global_bin_cldsc(res: PanopticaResult): return _compute_centerline_dice_coefficient(ref_binary, pred_binary) +def global_bin_assd(res: PanopticaResult): + if res.tp == 0: + return 0.0 + pred_binary = res._prediction_arr.copy() + ref_binary = res._reference_arr.copy() + 
pred_binary[pred_binary != 0] = 1 + ref_binary[ref_binary != 0] = 1 + return _average_symmetric_surface_distance(ref_binary, pred_binary) + + # endregion From 7f8cccedcd0f4b681005f85d4554225bb98e7c22 Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Thu, 18 Apr 2024 09:44:26 +0000 Subject: [PATCH 4/4] Autoformat with black --- panoptica/instance_matcher.py | 25 +++++++++++++----- panoptica/metrics/metrics.py | 46 +++++++++++++++++++++++++-------- panoptica/panoptic_evaluator.py | 32 +++++++++++++++++------ panoptica/panoptic_result.py | 22 ++++++++++++---- 4 files changed, 95 insertions(+), 30 deletions(-) diff --git a/panoptica/instance_matcher.py b/panoptica/instance_matcher.py index 4273fd5..3b32e63 100644 --- a/panoptica/instance_matcher.py +++ b/panoptica/instance_matcher.py @@ -80,7 +80,9 @@ def match_instances( return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap) -def map_instance_labels(processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap) -> MatchedInstancePair: +def map_instance_labels( + processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap +) -> MatchedInstancePair: """ Map instance labels based on the provided labelmap and create a MatchedInstancePair. @@ -192,13 +194,20 @@ def _match_instances( unmatched_instance_pair.prediction_arr, unmatched_instance_pair.reference_arr, ) - mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric) + mm_pairs = _calc_matching_metric_of_overlapping_labels( + pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric + ) # Loop through matched instances to compute PQ components for matching_score, (ref_label, pred_label) in mm_pairs: - if labelmap.contains_or(pred_label, ref_label) and not self.allow_many_to_one: + if ( + labelmap.contains_or(pred_label, ref_label) + and not self.allow_many_to_one + ): continue # -> doesnt make speed difference - if self.matching_metric.score_beats_threshold(matching_score, self.matching_threshold): + if self.matching_metric.score_beats_threshold( + matching_score, self.matching_threshold + ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) # map label ref_idx to pred_idx @@ -275,11 +284,15 @@ def _match_instances( continue if labelmap.contains_ref(ref_label): pred_labels_ = labelmap.get_pred_labels_matched_to_ref(ref_label) - new_score = self.new_combination_score(pred_labels_, pred_label, ref_label, unmatched_instance_pair) + new_score = self.new_combination_score( + pred_labels_, pred_label, ref_label, unmatched_instance_pair + ) if new_score > score_ref[ref_label]: labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = new_score - elif self.matching_metric.score_beats_threshold(matching_score, self.matching_threshold): + elif self.matching_metric.score_beats_threshold( + matching_score, self.matching_threshold + ): # Match found, increment true positive count and collect IoU and Dice values labelmap.add_labelmap_entry(pred_label, ref_label) score_ref[ref_label] = matching_score diff --git a/panoptica/metrics/metrics.py b/panoptica/metrics/metrics.py index a3d13bc..960e6af 100644 --- a/panoptica/metrics/metrics.py +++ b/panoptica/metrics/metrics.py @@ -37,7 +37,9 @@ def __call__( reference_arr = reference_arr.copy() == ref_instance_idx if isinstance(pred_instance_idx, int): pred_instance_idx = [pred_instance_idx] - 
prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) # type:ignore + prediction_arr = np.isin( + prediction_arr.copy(), pred_instance_idx + ) # type:ignore return self._metric_function(reference_arr, prediction_arr, *args, **kwargs) def __eq__(self, __value: object) -> bool: @@ -61,8 +63,12 @@ def __hash__(self) -> int: def increasing(self): return not self.decreasing - def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: - return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) + def score_beats_threshold( + self, matching_score: float, matching_threshold: float + ) -> bool: + return (self.increasing and matching_score >= matching_threshold) or ( + self.decreasing and matching_score <= matching_threshold + ) class DirectValueMeta(EnumMeta): @@ -117,7 +123,9 @@ def __call__( **kwargs, ) - def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool: + def score_beats_threshold( + self, matching_score: float, matching_threshold: float + ) -> bool: """Calculates whether a score beats a specified threshold Args: @@ -127,7 +135,9 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float Returns: bool: True if the matching_score beats the threshold, False otherwise. """ - return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold) + return (self.increasing and matching_score >= matching_threshold) or ( + self.decreasing and matching_score <= matching_threshold + ) @property def name(self): @@ -222,7 +232,9 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # ERROR if self._error: if self._error_obj is None: - self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed") + self._error_obj = MetricCouldNotBeComputedException( + f"Metric {self.id} requested, but could not be computed" + ) raise self._error_obj # Already calculated? 
if self._was_calculated: @@ -230,8 +242,12 @@ def __call__(self, result_obj: "PanopticaResult") -> Any: # Calculate it try: - assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated" - assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set" + assert ( + not self._was_calculated + ), f"Metric {self.id} was called to compute, but is set to have been already calculated" + assert ( + self._calc_func is not None + ), f"Metric {self.id} was called to compute, but has no calculation function set" value = self._calc_func(result_obj) except MetricCouldNotBeComputedException as e: value = e @@ -274,17 +290,25 @@ def __init__( else: self.AVG = None if self.ALL is None else np.average(self.ALL) self.SUM = None if self.ALL is None else np.sum(self.ALL) - self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) + self.STD = ( + None + if self.ALL is None + else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL) + ) def __getitem__(self, mode: MetricMode | str): if self.error: - raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics") + raise MetricCouldNotBeComputedException( + f"Metric {self.id} has not been calculated, add it to your eval_metrics" + ) if isinstance(mode, MetricMode): mode = mode.name if hasattr(self, mode): return getattr(self, mode) else: - raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member") + raise MetricCouldNotBeComputedException( + f"List_Metric {self.id} does not contain {mode} member" + ) if __name__ == "__main__": diff --git a/panoptica/panoptic_evaluator.py b/panoptica/panoptic_evaluator.py index 744f61d..1daeda9 100644 --- a/panoptica/panoptic_evaluator.py +++ b/panoptica/panoptic_evaluator.py @@ -20,7 +20,9 @@ class Panoptic_Evaluator: def __init__( self, - expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair, + expected_input: ( + Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] + ) = MatchedInstancePair, instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, edge_case_handler: EdgeCaseHandler | None = None, @@ -46,9 +48,13 @@ def __init__( self.__decision_metric = decision_metric self.__decision_threshold = decision_threshold - self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + self.__edge_case_handler = ( + edge_case_handler if edge_case_handler is not None else EdgeCaseHandler() + ) if self.__decision_metric is not None: - assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it" + assert ( + self.__decision_threshold is not None + ), "decision metric set but no decision threshold for it" # self.__log_times = log_times self.__verbose = verbose @@ -57,11 +63,15 @@ def __init__( @measure_time def evaluate( self, - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + processing_pair: ( + SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult + ), result_all: bool = True, verbose: bool | None = None, ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]: - assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}" + assert ( + 
type(processing_pair) == self.__expected_input + ), f"input not of expected type {self.__expected_input}" return panoptic_evaluate( processing_pair=processing_pair, edge_case_handler=self.__edge_case_handler, @@ -77,7 +87,9 @@ def evaluate( def panoptic_evaluate( - processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult, + processing_pair: ( + SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult + ), instance_approximator: InstanceApproximator | None = None, instance_matcher: InstanceMatchingAlgorithm | None = None, eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD], @@ -130,7 +142,9 @@ def panoptic_evaluate( processing_pair.crop_data() if isinstance(processing_pair, SemanticPair): - assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator" + assert ( + instance_approximator is not None + ), "Got SemanticPair but not InstanceApproximator" print("-- Got SemanticPair, will approximate instances") start = perf_counter() processing_pair = instance_approximator.approximate_instances(processing_pair) @@ -148,7 +162,9 @@ def panoptic_evaluate( if isinstance(processing_pair, UnmatchedInstancePair): print("-- Got UnmatchedInstancePair, will match instances") - assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" + assert ( + instance_matcher is not None + ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm" start = perf_counter() processing_pair = instance_matcher.match_instances( processing_pair, diff --git a/panoptica/panoptic_result.py b/panoptica/panoptic_result.py index 2e3076f..0abbba9 100644 --- a/panoptica/panoptic_result.py +++ b/panoptica/panoptic_result.py @@ -244,7 +244,9 @@ def __init__( num_pred_instances=self.num_pred_instances, num_ref_instances=self.num_ref_instances, ) - self._list_metrics[k] = Evaluation_List_Metric(k, empty_list_std, v, is_edge_case, edge_case_result) + self._list_metrics[k] = Evaluation_List_Metric( + k, empty_list_std, v, is_edge_case, edge_case_result + ) def _add_metric( self, @@ -313,13 +315,19 @@ def __str__(self) -> str: return text def to_dict(self) -> dict: - return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)} + return { + k: getattr(self, v.id) + for k, v in self._evaluation_metrics.items() + if (v._error == False and v._was_calculated) + } def get_list_metric(self, metric: Metric, mode: MetricMode): if metric in self._list_metrics: return self._list_metrics[metric][mode] else: - raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?") + raise MetricCouldNotBeComputedException( + f"{metric} could not be found, have you set it in eval_metrics during evaluation?" 
+ ) def _calc_metric(self, metric_name: str, supress_error: bool = False): if metric_name in self._evaluation_metrics: @@ -335,7 +343,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False): self._evaluation_metrics[metric_name]._was_calculated = True return value else: - raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}") + raise MetricCouldNotBeComputedException( + f"could not find metric with name {metric_name}" + ) def __getattribute__(self, __name: str) -> Any: attr = None @@ -348,7 +358,9 @@ def __getattribute__(self, __name: str) -> Any: raise e if attr is None: if self._evaluation_metrics[__name]._error: - raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed") + raise MetricCouldNotBeComputedException( + f"Requested metric {__name} that could not be computed" + ) elif not self._evaluation_metrics[__name]._was_calculated: value = self._calc_metric(__name) setattr(self, __name, value)
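
Usage sketch for the metrics these patches introduce (precision, recall, the global binary ASSD, and the NO_PRINT metric type). This is a minimal illustration, not part of the patches themselves: it assumes the top-level imports and the SemanticPair(prediction_arr, reference_arr) construction used by examples/example_spine_semantic.py, and the input arrays below are synthetic placeholders.

import numpy as np

from panoptica import (
    ConnectedComponentsInstanceApproximator,
    NaiveThresholdMatching,
    Panoptic_Evaluator,
    SemanticPair,
)

# Synthetic, overlapping binary masks standing in for prediction/reference segmentations.
prediction_arr = np.zeros((32, 32, 32), dtype=np.uint8)
reference_arr = np.zeros((32, 32, 32), dtype=np.uint8)
prediction_arr[8:16, 8:16, 8:16] = 1
reference_arr[9:17, 9:17, 9:17] = 1

# Same pipeline as the spine example: approximate instances from the semantic
# masks, then match them with the default IoU threshold.
evaluator = Panoptic_Evaluator(
    expected_input=SemanticPair,
    instance_approximator=ConnectedComponentsInstanceApproximator(),
    instance_matcher=NaiveThresholdMatching(),
    verbose=True,
)

result, intermediate_steps = evaluator.evaluate(SemanticPair(prediction_arr, reference_arr))

# prec and rec are registered with MetricType.NO_PRINT, so print(result) omits them,
# but they are still computed lazily on attribute access.
print(result)                     # NO_PRINT metrics are skipped in the summary
print(result.prec, result.rec)    # precision = tp / (tp + fp), recall = tp / (tp + fn)
print(result.global_bin_assd)     # global binary average symmetric surface distance

Because global_bin_assd is registered under MetricType.GLOBAL, it also appears in the printed summary alongside the other global binary scores, whereas prec and rec are reachable only as attributes or via to_dict().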