
Commit 5808e46
Autoformat with black
brainless-bot[bot] committed Aug 7, 2024
1 parent 735397e commit 5808e46
Showing 9 changed files with 205 additions and 60 deletions.
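Every hunk in this commit follows the same mechanical pattern: black wraps any statement longer than its default 88-character limit into a parenthesized, 4-space-indented continuation without changing behavior. A minimal sketch of the pattern (the function and variable names are placeholders, not code from this repository):

# Before: a single call that exceeds black's default 88-character line limit.
scores = compute_overlap_scores(reference_mask_array, prediction_mask_array, use_voxel_spacing=True)

# After: black moves the arguments onto a 4-space-indented continuation line.
scores = compute_overlap_scores(
    reference_mask_array, prediction_mask_array, use_voxel_spacing=True
)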
4 changes: 3 additions & 1 deletion examples/example_spine_instance_config.py
@@ -10,7 +10,9 @@
reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")
prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

-evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance")
+evaluator = Panoptica_Evaluator.load_from_config_name(
+    "panoptica_evaluator_unmatched_instance"
+)


with cProfile.Profile() as pr:
4 changes: 3 additions & 1 deletion examples/example_spine_semantic.py
@@ -25,7 +25,9 @@

with cProfile.Profile() as pr:
    if __name__ == "__main__":
-        result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
+        result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[
+            "ungrouped"
+        ]
        print(result)

        pr.dump_stats(directory + "/semantic_example.log")
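For readers following the example: evaluate() returns one entry per segmentation group, so the subscript with "ungrouped" picks the only group when no SegmentationClassGroups are configured. A hedged sketch of iterating all groups instead (group names depend on your configuration):

# Sketch (assumption: the evaluator was built with SegmentationClassGroups, so
# evaluate() yields one (PanopticaResult, debug_data) tuple per group name).
for group_name, (result, debug_data) in evaluator.evaluate(
    prediction_mask, reference_mask
).items():
    print(group_name)
    print(result)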
18 changes: 14 additions & 4 deletions panoptica/instance_evaluator.py
@@ -28,7 +28,9 @@ def evaluate_matched_instance(
    >>> result = map_instance_labels(unmatched_instance_pair, labelmap)
    """
    if decision_metric is not None:
-        assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics"
+        assert decision_metric.name in [
+            v.name for v in eval_metrics
+        ], "decision metric not contained in eval_metrics"
        assert decision_threshold is not None, "decision metric set but no threshold"
    # Initialize variables for True Positives (tp)
    tp = len(matched_instance_pair.matched_instances)
@@ -40,13 +42,21 @@
    )
    ref_matched_labels = matched_instance_pair.matched_instances

-    instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels]
+    instance_pairs = [
+        (reference_arr, prediction_arr, ref_idx, eval_metrics)
+        for ref_idx in ref_matched_labels
+    ]
    with Pool() as pool:
-        metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs)
+        metric_dicts: list[dict[Metric, float]] = pool.starmap(
+            _evaluate_instance, instance_pairs
+        )

    for metric_dict in metric_dicts:
        if decision_metric is None or (
-            decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold)
+            decision_threshold is not None
+            and decision_metric.score_beats_threshold(
+                metric_dict[decision_metric], decision_threshold
+            )
        ):
            for k, v in metric_dict.items():
                score_dict[k].append(v)
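The reflowed block keeps the original logic: one work item per matched reference label, scored in parallel via multiprocessing.Pool.starmap, then filtered against the decision metric. A generic sketch of that pattern with placeholder names (not the library's _evaluate_instance):

from multiprocessing import Pool

def _score_one_label(reference_arr, prediction_arr, label, metrics):
    # Placeholder worker: one dict of metric scores per matched label.
    return {m: m(reference_arr, prediction_arr, label) for m in metrics}

def score_matched_labels(reference_arr, prediction_arr, labels, metrics,
                         decision_metric=None, decision_threshold=None):
    tasks = [(reference_arr, prediction_arr, lbl, metrics) for lbl in labels]
    with Pool() as pool:  # one task per matched label
        per_label = pool.starmap(_score_one_label, tasks)
    if decision_metric is None:
        return per_label
    # Keep only labels whose decision-metric score beats the threshold.
    return [
        d for d in per_label
        if decision_metric.score_beats_threshold(d[decision_metric], decision_threshold)
    ]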
70 changes: 54 additions & 16 deletions panoptica/metrics/metrics.py
@@ -40,7 +40,9 @@ def __call__(
        reference_arr = reference_arr.copy() == ref_instance_idx
        if isinstance(pred_instance_idx, int):
            pred_instance_idx = [pred_instance_idx]
-        prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx)  # type:ignore
+        prediction_arr = np.isin(
+            prediction_arr.copy(), pred_instance_idx
+        )  # type:ignore
        return self._metric_function(reference_arr, prediction_arr, *args, **kwargs)

    def __eq__(self, __value: object) -> bool:
@@ -64,8 +66,12 @@ def __hash__(self) -> int:
    def increasing(self):
        return not self.decreasing

-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )


class DirectValueMeta(EnumMeta):
@@ -88,9 +94,19 @@ class Metric(_Enum_Compare):

    DSC = _Metric("DSC", "Dice", False, _compute_instance_volumetric_dice)
    IOU = _Metric("IOU", "Intersection over Union", False, _compute_instance_iou)
-    ASSD = _Metric("ASSD", "Average Symmetric Surface Distance", True, _compute_instance_average_symmetric_surface_distance)
+    ASSD = _Metric(
+        "ASSD",
+        "Average Symmetric Surface Distance",
+        True,
+        _compute_instance_average_symmetric_surface_distance,
+    )
    clDSC = _Metric("clDSC", "Centerline Dice", False, _compute_centerline_dice)
-    RVD = _Metric("RVD", "Relative Volume Difference", True, _compute_instance_relative_volume_difference)
+    RVD = _Metric(
+        "RVD",
+        "Relative Volume Difference",
+        True,
+        _compute_instance_relative_volume_difference,
+    )
    # ST = _Metric("ST", False, _compute_instance_segmentation_tendency)

    def __call__(
@@ -122,7 +138,9 @@ def __call__(
            **kwargs,
        )

-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
        """Calculates whether a score beats a specified threshold

@@ -132,7 +150,9 @@ def score_beats_threshold(
        Returns:
            bool: True if the matching_score beats the threshold, False otherwise.
        """
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )

    @property
    def name(self):
@@ -229,16 +249,22 @@ def __call__(self, result_obj: "PanopticaResult") -> Any:
        # ERROR
        if self._error:
            if self._error_obj is None:
-                self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed")
+                self._error_obj = MetricCouldNotBeComputedException(
+                    f"Metric {self.id} requested, but could not be computed"
+                )
            raise self._error_obj
        # Already calculated?
        if self._was_calculated:
            return self._value

        # Calculate it
        try:
-            assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated"
-            assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set"
+            assert (
+                not self._was_calculated
+            ), f"Metric {self.id} was called to compute, but is set to have been already calculated"
+            assert (
+                self._calc_func is not None
+            ), f"Metric {self.id} was called to compute, but has no calculation function set"
            value = self._calc_func(result_obj)
        except MetricCouldNotBeComputedException as e:
            value = e
@@ -283,20 +309,32 @@ def __init__(
        else:
            self.AVG = None if self.ALL is None else np.average(self.ALL)
            self.SUM = None if self.ALL is None else np.sum(self.ALL)
-        self.MIN = None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL)
-        self.MAX = None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL)
-
-        self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
+        self.MIN = (
+            None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL)
+        )
+        self.MAX = (
+            None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL)
+        )
+
+        self.STD = (
+            None
+            if self.ALL is None
+            else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
+        )

    def __getitem__(self, mode: MetricMode | str):
        if self.error:
-            raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics")
+            raise MetricCouldNotBeComputedException(
+                f"Metric {self.id} has not been calculated, add it to your eval_metrics"
+            )
        if isinstance(mode, MetricMode):
            mode = mode.name
        if hasattr(self, mode):
            return getattr(self, mode)
        else:
-            raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member")
+            raise MetricCouldNotBeComputedException(
+                f"List_Metric {self.id} does not contain {mode} member"
+            )


if __name__ == "__main__":
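score_beats_threshold encodes each metric's direction: for an increasing metric such as DSC a candidate match must reach the threshold from above, while for a decreasing one such as ASSD it must stay at or below it. A small illustrative check (the numeric values are made up):

# Illustrative only: the same direction-aware rule as score_beats_threshold above.
def beats_threshold(score: float, threshold: float, decreasing: bool) -> bool:
    increasing = not decreasing
    return (increasing and score >= threshold) or (decreasing and score <= threshold)

assert beats_threshold(0.62, 0.5, decreasing=False)    # DSC 0.62 vs threshold 0.5: match
assert not beats_threshold(3.0, 1.0, decreasing=True)  # ASSD 3.0 mm vs 1.0 mm: no match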
43 changes: 33 additions & 10 deletions panoptica/panoptica_evaluator.py
@@ -31,7 +31,12 @@ def __init__(
        instance_matcher: InstanceMatchingAlgorithm | None = None,
        edge_case_handler: EdgeCaseHandler | None = None,
        segmentation_class_groups: SegmentationClassGroups | None = None,
-        instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD],
+        instance_metrics: list[Metric] = [
+            Metric.DSC,
+            Metric.IOU,
+            Metric.ASSD,
+            Metric.RVD,
+        ],
        global_metrics: list[Metric] = [Metric.DSC],
        decision_metric: Metric | None = None,
        decision_threshold: float | None = None,
@@ -65,9 +70,13 @@ def __init__(

        self.__segmentation_class_groups = segmentation_class_groups

-        self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        self.__edge_case_handler = (
+            edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        )
        if self.__decision_metric is not None:
-            assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it"
+            assert (
+                self.__decision_threshold is not None
+            ), "decision metric set but no decision threshold for it"
        #
        self.__log_times = log_times
        self.__verbose = verbose
@@ -98,7 +107,9 @@ def evaluate(
        verbose: bool | None = None,
    ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]:
        processing_pair = self.__expected_input(prediction_arr, reference_arr)
-        assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}"
+        assert isinstance(
+            processing_pair, self.__expected_input.value
+        ), f"input not of expected type {self.__expected_input}"

        if self.__segmentation_class_groups is None:
            return {
Expand All @@ -118,8 +129,12 @@ def evaluate(
)
}

self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True)
self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True)
self.__segmentation_class_groups.has_defined_labels_for(
processing_pair.prediction_arr, raise_error=True
)
self.__segmentation_class_groups.has_defined_labels_for(
processing_pair.reference_arr, raise_error=True
)

result_grouped = {}
for group_name, label_group in self.__segmentation_class_groups.items():
@@ -131,7 +146,9 @@
            single_instance_mode = label_group.single_instance
            processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped)  # type: ignore
            decision_threshold = self.__decision_threshold
-            if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair):
+            if single_instance_mode and not isinstance(
+                processing_pair, MatchedInstancePair
+            ):
                processing_pair_grouped = MatchedInstancePair(
                    prediction_arr=processing_pair_grouped.prediction_arr,
                    reference_arr=processing_pair_grouped.reference_arr,
@@ -156,7 +173,9 @@


def panoptic_evaluate(
-    processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
+    processing_pair: (
+        SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult
+    ),
    instance_approximator: InstanceApproximator | None = None,
    instance_matcher: InstanceMatchingAlgorithm | None = None,
    instance_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
@@ -213,7 +232,9 @@ def panoptic_evaluate(
    processing_pair.crop_data()

    if isinstance(processing_pair, SemanticPair):
-        assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator"
+        assert (
+            instance_approximator is not None
+        ), "Got SemanticPair but not InstanceApproximator"
        if verbose:
            print("-- Got SemanticPair, will approximate instances")
        start = perf_counter()
@@ -234,7 +255,9 @@
    if isinstance(processing_pair, UnmatchedInstancePair):
        if verbose:
            print("-- Got UnmatchedInstancePair, will match instances")
-        assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
+        assert (
+            instance_matcher is not None
+        ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
        start = perf_counter()
        processing_pair = instance_matcher.match_instances(
            processing_pair,
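panoptic_evaluate advances the pair through whichever stages it still needs: a SemanticPair must first be approximated into instances, an UnmatchedInstancePair must then be matched, and only a MatchedInstancePair is scored. A hedged construction sketch: the approximator and matcher classes are assumed panoptica components used for illustration, not defaults introduced by this commit, and the array names are placeholders:

# Sketch under assumptions: ConnectedComponentsInstanceApproximator and
# NaiveThresholdMatching are assumed imports standing in for whichever
# components you actually configure.
evaluator = Panoptica_Evaluator(
    instance_approximator=ConnectedComponentsInstanceApproximator(),  # assumed import
    instance_matcher=NaiveThresholdMatching(),                        # assumed import
    instance_metrics=[Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD],
    global_metrics=[Metric.DSC],
    decision_metric=Metric.IOU,
    decision_threshold=0.5,
)
result, debug_data = evaluator.evaluate(prediction_arr, reference_arr)["ungrouped"]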
26 changes: 20 additions & 6 deletions panoptica/panoptica_result.py
@@ -249,7 +249,9 @@ def __init__(
                num_pred_instances=self.num_pred_instances,
                num_ref_instances=self.num_ref_instances,
            )
-            self._list_metrics[m] = Evaluation_List_Metric(m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result)
+            self._list_metrics[m] = Evaluation_List_Metric(
+                m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result
+            )
            # even if not available, set the global vars
            self._add_metric(
                f"global_bin_{m.name.lower()}",
@@ -327,13 +329,19 @@ def __str__(self) -> str:
        return text

    def to_dict(self) -> dict:
-        return {k: getattr(self, v.id) for k, v in self._evaluation_metrics.items() if (v._error == False and v._was_calculated)}
+        return {
+            k: getattr(self, v.id)
+            for k, v in self._evaluation_metrics.items()
+            if (v._error == False and v._was_calculated)
+        }

    def get_list_metric(self, metric: Metric, mode: MetricMode):
        if metric in self._list_metrics:
            return self._list_metrics[metric][mode]
        else:
-            raise MetricCouldNotBeComputedException(f"{metric} could not be found, have you set it in eval_metrics during evaluation?")
+            raise MetricCouldNotBeComputedException(
+                f"{metric} could not be found, have you set it in eval_metrics during evaluation?"
+            )

    def _calc_metric(self, metric_name: str, supress_error: bool = False):
        if metric_name in self._evaluation_metrics:
@@ -349,7 +357,9 @@ def _calc_metric(self, metric_name: str, supress_error: bool = False):
            self._evaluation_metrics[metric_name]._was_calculated = True
            return value
        else:
-            raise MetricCouldNotBeComputedException(f"could not find metric with name {metric_name}")
+            raise MetricCouldNotBeComputedException(
+                f"could not find metric with name {metric_name}"
+            )

    def __getattribute__(self, __name: str) -> Any:
        attr = None
@@ -362,7 +372,9 @@ def __getattribute__(self, __name: str) -> Any:
                raise e
        if attr is None:
            if self._evaluation_metrics[__name]._error:
-                raise MetricCouldNotBeComputedException(f"Requested metric {__name} that could not be computed")
+                raise MetricCouldNotBeComputedException(
+                    f"Requested metric {__name} that could not be computed"
+                )
            elif not self._evaluation_metrics[__name]._was_calculated:
                value = self._calc_metric(__name)
                setattr(self, __name, value)
@@ -488,7 +500,9 @@ def function_template(res: PanopticaResult):
        if metric not in res._global_metrics:
            raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set")
        if res.tp == 0:
-            is_edgecase, result = res._edge_case_handler.handle_zero_tp(metric, res.tp, res.num_pred_instances, res.num_ref_instances)
+            is_edgecase, result = res._edge_case_handler.handle_zero_tp(
+                metric, res.tp, res.num_pred_instances, res.num_ref_instances
+            )
            if is_edgecase:
                return result
        pred_binary = res._prediction_arr.copy()
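PanopticaResult resolves its metrics lazily: reading an attribute triggers _calc_metric, failures surface as MetricCouldNotBeComputedException, and per-instance lists are reachable through get_list_metric. A hedged usage sketch: the attribute name follows the global_bin_<metric> pattern added above (assuming Metric.DSC is among the configured global_metrics), and the MetricMode members are assumed to mirror the fields of Evaluation_List_Metric (ALL, AVG, STD, ...):

# Hedged sketch: names follow the patterns visible in this diff, not a verified API listing.
result, _ = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]

print(result.global_bin_dsc)                               # computed lazily on first access
print(result.get_list_metric(Metric.DSC, MetricMode.AVG))  # average over matched instances
print(result.to_dict())                                    # only metrics computed without error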