
Commit

Global binary metrics now use a function creator, since they share an underlying pattern, and they now go through the edge case handler. To allow for maximum flexibility, added an EvaluateInstancePair class that makes some functions easier to use; this way, one can jump from an instance approximation algorithm directly to results. Also added global_metrics as an argument so users can decide which global metrics are calculated, with the default set to Dice (DSC). Additionally, renamed eval_metrics to instance_metrics to distinguish it from the global_metrics argument.
Hendrik-code committed Aug 7, 2024
1 parent 2f7d01f commit 735397e
Showing 12 changed files with 182 additions and 289 deletions.
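
Before the file-by-file diff, here is a minimal usage sketch of the renamed and newly added arguments. Only `from panoptica.metrics import Metric`, the argument names, and the `["ungrouped"]` result access appear verbatim in the diff below; the toy masks, the top-level imports of Panoptica_Evaluator and InputType, and the exact return structure are assumptions for illustration, not verified against the repository.

# Minimal sketch of the post-commit API; import paths and return structure are assumed.
import numpy as np
from panoptica import Panoptica_Evaluator, InputType  # assumed top-level exports
from panoptica.metrics import Metric

# Toy matched-instance masks: one instance labelled 1 in each array, slightly offset.
reference_mask = np.zeros((32, 32), dtype=np.uint8)
reference_mask[4:16, 4:16] = 1
prediction_mask = np.zeros((32, 32), dtype=np.uint8)
prediction_mask[6:18, 6:18] = 1

evaluator = Panoptica_Evaluator(
    expected_input=InputType.MATCHED_INSTANCE,
    instance_metrics=[Metric.DSC, Metric.IOU],  # per-instance metrics (formerly eval_metrics)
    global_metrics=[Metric.DSC],                # global metrics; Dice (DSC) is the default
)

# As in the example scripts, evaluate() returns one entry per label group.
result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
print(result)
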
2 changes: 1 addition & 1 deletion examples/example_spine_instance.py
@@ -14,7 +14,7 @@

evaluator = Panoptica_Evaluator(
expected_input=InputType.MATCHED_INSTANCE,
eval_metrics=[Metric.DSC, Metric.IOU],
instance_metrics=[Metric.DSC, Metric.IOU],
segmentation_class_groups=SegmentationClassGroups(
{
"vertebra": LabelGroup([i for i in range(1, 10)]),
4 changes: 1 addition & 3 deletions examples/example_spine_instance_config.py
@@ -10,9 +10,7 @@
reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")
prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

evaluator = Panoptica_Evaluator.load_from_config_name(
"panoptica_evaluator_unmatched_instance"
)
evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance")


with cProfile.Profile() as pr:
4 changes: 1 addition & 3 deletions examples/example_spine_semantic.py
@@ -25,9 +25,7 @@

with cProfile.Profile() as pr:
if __name__ == "__main__":
result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[
"ungrouped"
]
result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
print(result)

pr.dump_stats(directory + "/semantic_example.log")
@@ -19,7 +19,8 @@ edge_case_handler: !EdgeCaseHandler
!Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult NAN}
eval_metrics: [!Metric DSC, !Metric IOU]
instance_metrics: [!Metric DSC, !Metric IOU]
global_metrics: [!Metric DSC, !Metric RVD]
expected_input: !InputType UNMATCHED_INSTANCE
instance_approximator: null
instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
31 changes: 7 additions & 24 deletions panoptica/instance_evaluator.py
@@ -1,21 +1,17 @@
from multiprocessing import Pool

import numpy as np

from panoptica.metrics import Metric
from panoptica.panoptica_result import PanopticaResult
from panoptica.utils import EdgeCaseHandler
from panoptica.utils.processing_pair import MatchedInstancePair
from panoptica.utils.processing_pair import MatchedInstancePair, EvaluateInstancePair


def evaluate_matched_instance(
matched_instance_pair: MatchedInstancePair,
eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
decision_metric: Metric | None = Metric.IOU,
decision_threshold: float | None = None,
edge_case_handler: EdgeCaseHandler | None = None,
**kwargs,
) -> PanopticaResult:
) -> EvaluateInstancePair:
"""
Map instance labels based on the provided labelmap and create a MatchedInstancePair.
@@ -31,12 +27,8 @@ def evaluate_matched_instance(
>>> labelmap = [([1, 2], [3, 4]), ([5], [6])]
>>> result = map_instance_labels(unmatched_instance_pair, labelmap)
"""
if edge_case_handler is None:
edge_case_handler = EdgeCaseHandler()
if decision_metric is not None:
assert decision_metric.name in [
v.name for v in eval_metrics
], "decision metric not contained in eval_metrics"
assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics"
assert decision_threshold is not None, "decision metric set but no threshold"
# Initialize variables for True Positives (tp)
tp = len(matched_instance_pair.matched_instances)
@@ -48,34 +40,25 @@
)
ref_matched_labels = matched_instance_pair.matched_instances

instance_pairs = [
(reference_arr, prediction_arr, ref_idx, eval_metrics)
for ref_idx in ref_matched_labels
]
instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels]
with Pool() as pool:
metric_dicts: list[dict[Metric, float]] = pool.starmap(
_evaluate_instance, instance_pairs
)
metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs)

for metric_dict in metric_dicts:
if decision_metric is None or (
decision_threshold is not None
and decision_metric.score_beats_threshold(
metric_dict[decision_metric], decision_threshold
)
decision_threshold is not None and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold)
):
for k, v in metric_dict.items():
score_dict[k].append(v)

# Create and return the PanopticaResult object with computed metrics
return PanopticaResult(
return EvaluateInstancePair(
reference_arr=matched_instance_pair.reference_arr,
prediction_arr=matched_instance_pair.prediction_arr,
num_pred_instances=matched_instance_pair.n_prediction_instance,
num_ref_instances=matched_instance_pair.n_reference_instance,
tp=tp,
list_metrics=score_dict,
edge_case_handler=edge_case_handler,
)


67 changes: 20 additions & 47 deletions panoptica/metrics/metrics.py
@@ -23,6 +23,7 @@ class _Metric:
"""A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array"""

name: str
long_name: str
decreasing: bool
_metric_function: Callable

@@ -39,9 +40,7 @@ def __call__(
reference_arr = reference_arr.copy() == ref_instance_idx
if isinstance(pred_instance_idx, int):
pred_instance_idx = [pred_instance_idx]
prediction_arr = np.isin(
prediction_arr.copy(), pred_instance_idx
) # type:ignore
prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) # type:ignore
return self._metric_function(reference_arr, prediction_arr, *args, **kwargs)

def __eq__(self, __value: object) -> bool:
@@ -65,12 +64,8 @@ def __hash__(self) -> int:
def increasing(self):
return not self.decreasing

def score_beats_threshold(
self, matching_score: float, matching_threshold: float
) -> bool:
return (self.increasing and matching_score >= matching_threshold) or (
self.decreasing and matching_score <= matching_threshold
)
def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)


class DirectValueMeta(EnumMeta):
@@ -91,11 +86,11 @@ class Metric(_Enum_Compare):
_type_: _description_
"""

DSC = _Metric("DSC", False, _compute_instance_volumetric_dice)
IOU = _Metric("IOU", False, _compute_instance_iou)
ASSD = _Metric("ASSD", True, _compute_instance_average_symmetric_surface_distance)
clDSC = _Metric("clDSC", False, _compute_centerline_dice)
RVD = _Metric("RVD", True, _compute_instance_relative_volume_difference)
DSC = _Metric("DSC", "Dice", False, _compute_instance_volumetric_dice)
IOU = _Metric("IOU", "Intersection over Union", False, _compute_instance_iou)
ASSD = _Metric("ASSD", "Average Symmetric Surface Distance", True, _compute_instance_average_symmetric_surface_distance)
clDSC = _Metric("clDSC", "Centerline Dice", False, _compute_centerline_dice)
RVD = _Metric("RVD", "Relative Volume Difference", True, _compute_instance_relative_volume_difference)
# ST = _Metric("ST", False, _compute_instance_segmentation_tendency)

def __call__(
@@ -127,9 +122,7 @@ def __call__(
**kwargs,
)

def score_beats_threshold(
self, matching_score: float, matching_threshold: float
) -> bool:
def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
"""Calculates whether a score beats a specified threshold
Args:
@@ -139,9 +132,7 @@ def score_beats_threshold(
Returns:
bool: True if the matching_score beats the threshold, False otherwise.
"""
return (self.increasing and matching_score >= matching_threshold) or (
self.decreasing and matching_score <= matching_threshold
)
return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)

@property
def name(self):
@@ -238,22 +229,16 @@ def __call__(self, result_obj: "PanopticaResult") -> Any:
# ERROR
if self._error:
if self._error_obj is None:
self._error_obj = MetricCouldNotBeComputedException(
f"Metric {self.id} requested, but could not be computed"
)
self._error_obj = MetricCouldNotBeComputedException(f"Metric {self.id} requested, but could not be computed")
raise self._error_obj
# Already calculated?
if self._was_calculated:
return self._value

# Calculate it
try:
assert (
not self._was_calculated
), f"Metric {self.id} was called to compute, but is set to have been already calculated"
assert (
self._calc_func is not None
), f"Metric {self.id} was called to compute, but has no calculation function set"
assert not self._was_calculated, f"Metric {self.id} was called to compute, but is set to have been already calculated"
assert self._calc_func is not None, f"Metric {self.id} was called to compute, but has no calculation function set"
value = self._calc_func(result_obj)
except MetricCouldNotBeComputedException as e:
value = e
@@ -298,32 +283,20 @@ def __init__(
else:
self.AVG = None if self.ALL is None else np.average(self.ALL)
self.SUM = None if self.ALL is None else np.sum(self.ALL)
self.MIN = (
None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL)
)
self.MAX = (
None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL)
)

self.STD = (
None
if self.ALL is None
else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)
)
self.MIN = None if self.ALL is None or len(self.ALL) == 0 else np.min(self.ALL)
self.MAX = None if self.ALL is None or len(self.ALL) == 0 else np.max(self.ALL)

self.STD = None if self.ALL is None else empty_list_std if len(self.ALL) == 0 else np.std(self.ALL)

def __getitem__(self, mode: MetricMode | str):
if self.error:
raise MetricCouldNotBeComputedException(
f"Metric {self.id} has not been calculated, add it to your eval_metrics"
)
raise MetricCouldNotBeComputedException(f"Metric {self.id} has not been calculated, add it to your eval_metrics")
if isinstance(mode, MetricMode):
mode = mode.name
if hasattr(self, mode):
return getattr(self, mode)
else:
raise MetricCouldNotBeComputedException(
f"List_Metric {self.id} does not contain {mode} member"
)
raise MetricCouldNotBeComputedException(f"List_Metric {self.id} does not contain {mode} member")


if __name__ == "__main__":
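
The _Metric dataclass above gains a descriptive long_name field, but its call path is unchanged, so a Metric member should still be callable directly on a pair of arrays. A minimal sketch, assuming the ref_instance_idx/pred_instance_idx keyword names visible in __call__ above and toy arrays invented for illustration:

# Sketch only: keyword names and direct callability are inferred from the diff above.
import numpy as np
from panoptica.metrics import Metric

# Two toy masks sharing instance label 1, partially overlapping.
ref = np.zeros((16, 16), dtype=np.uint8)
ref[2:10, 2:10] = 1
pred = np.zeros((16, 16), dtype=np.uint8)
pred[4:12, 4:12] = 1

# Binarizes both arrays on the given instance labels, then applies the metric function.
dsc = Metric.DSC(ref, pred, ref_instance_idx=1, pred_instance_idx=1)
iou = Metric.IOU(ref, pred, ref_instance_idx=1, pred_instance_idx=1)
print(Metric.DSC.name, dsc)
print(Metric.IOU.name, iou)
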

