fixed #160 by embedding the intermediate_steps_data into the PanopticaResult object
Hendrik-code committed Dec 17, 2024
1 parent e6d64d2 commit c1fdb83
Showing 7 changed files with 37 additions and 38 deletions.
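
In short: Panoptica_Evaluator.evaluate() no longer returns a (PanopticaResult, IntermediateStepsData) tuple per group; each group name now maps directly to a PanopticaResult, and the intermediate arrays are reached through its intermediate_steps_data attribute. A minimal before/after sketch of the usage change, assuming a configured evaluator and prediction_mask/reference_mask arrays as in the bundled examples:

# Before this commit: each group yielded a (result, intermediate_steps_data) tuple.
# result, intermediate_steps_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]

# After this commit: each group yields the PanopticaResult directly.
result = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
print(result)

# The intermediate steps now travel on the result itself (None if not recorded).
intermediate_steps_data = result.intermediate_steps_data
if intermediate_steps_data is not None:
    intermediate_steps_data.original_prediction_arr  # input prediction array, untouched
    intermediate_steps_data.original_reference_arr  # input reference array, untouched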
2 changes: 1 addition & 1 deletion examples/example_spine_instance.py
@@ -32,7 +32,7 @@
def main():
with cProfile.Profile() as pr:
results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False)
- for groupname, (result, intermediate_steps_data) in results.items():
+ for groupname, result in results.items():
print()
print("### Group", groupname)
print(result)
2 changes: 1 addition & 1 deletion examples/example_spine_instance_config.py
@@ -18,7 +18,7 @@
def main():
with cProfile.Profile() as pr:
results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False)
- for groupname, (result, intermediate_steps_data) in results.items():
+ for groupname, result in results.items():
print()
print("### Group", groupname)
print(result)
6 changes: 3 additions & 3 deletions examples/example_spine_semantic.py
@@ -27,13 +27,13 @@

def main():
with cProfile.Profile() as pr:
- result, intermediate_steps_data = evaluator.evaluate(
-     prediction_mask, reference_mask
- )["ungrouped"]
+ result = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]

# To print the results, just call print
print(result)

+ intermediate_steps_data = result.intermediate_steps_data
+ assert intermediate_steps_data is not None
# To get the different intermediate arrays, just use the result's intermediate_steps_data attribute
intermediate_steps_data.original_prediction_arr # Input prediction array, untouched
intermediate_steps_data.original_reference_arr # Input reference array, untouched
2 changes: 1 addition & 1 deletion panoptica/panoptica_aggregator.py
@@ -192,7 +192,7 @@ def _save_one_subject(self, subject_name, result_grouped):
#
content = [subject_name]
for groupname in self.__class_group_names:
- result: PanopticaResult = result_grouped[groupname][0]
+ result: PanopticaResult = result_grouped[groupname]
result_dict = result.to_dict()
if result.computation_time is not None:
result_dict[COMPUTATION_TIME_KEY] = result.computation_time
28 changes: 12 additions & 16 deletions panoptica/panoptica_evaluator.py
@@ -3,7 +3,7 @@
from panoptica.instance_approximator import InstanceApproximator
from panoptica.instance_evaluator import evaluate_matched_instance
from panoptica.instance_matcher import InstanceMatchingAlgorithm
- from panoptica.metrics import Metric, _Metric
+ from panoptica.metrics import Metric
from panoptica.panoptica_result import PanopticaResult
from panoptica.utils.timing import measure_time
from panoptica.utils import EdgeCaseHandler
@@ -12,7 +12,6 @@
MatchedInstancePair,
SemanticPair,
UnmatchedInstancePair,
- _ProcessingPair,
InputType,
EvaluateInstancePair,
IntermediateStepsData,
@@ -121,7 +120,7 @@ def evaluate(
save_group_times: bool | None = None,
log_times: bool | None = None,
verbose: bool | None = None,
- ) -> dict[str, tuple[PanopticaResult, IntermediateStepsData]]:
+ ) -> dict[str, PanopticaResult]:
processing_pair = self.__expected_input(prediction_arr, reference_arr)
assert isinstance(
processing_pair, self.__expected_input.value
@@ -134,21 +133,17 @@
processing_pair.reference_arr, raise_error=True
)

- result_grouped: dict[str, tuple[PanopticaResult, IntermediateStepsData]] = {}
+ result_grouped: dict[str, PanopticaResult] = {}
for group_name, label_group in self.__segmentation_class_groups.items():
result_grouped[group_name] = self._evaluate_group(
group_name,
label_group,
processing_pair,
result_all,
- save_group_times=(
-     self.__save_group_times
-     if save_group_times is None
-     else save_group_times
- ),
+ save_group_times=(self.__save_group_times if save_group_times is None else save_group_times),
log_times=log_times,
verbose=verbose,
- )[1:]
+ )
return result_grouped

@property
@@ -170,7 +165,7 @@ def resulting_metric_keys(self) -> list[str]:
dummy_input = MatchedInstancePair(
np.ones((1, 1, 1), dtype=np.uint8), np.ones((1, 1, 1), dtype=np.uint8)
)
- _, res, _ = self._evaluate_group(
+ res = self._evaluate_group(
group_name="",
label_group=LabelGroup(1, single_instance=False),
processing_pair=dummy_input,
@@ -192,7 +187,7 @@ def _evaluate_group(
verbose: bool | None = None,
log_times: bool | None = None,
save_group_times: bool = False,
- ):
+ ) -> PanopticaResult:
assert isinstance(label_group, LabelGroup)
if self.__save_group_times:
start_time = perf_counter()
@@ -212,7 +207,7 @@
)
decision_threshold = 0.0

- result, intermediate_steps_data = panoptic_evaluate(
+ result = panoptic_evaluate(
input_pair=processing_pair_grouped,
edge_case_handler=self.__edge_case_handler,
instance_approximator=self.__instance_approximator,
@@ -229,7 +224,7 @@
if save_group_times:
duration = perf_counter() - start_time
result.computation_time = duration
- return group_name, result, intermediate_steps_data
+ return result


def panoptic_evaluate(
@@ -246,7 +241,7 @@ def panoptic_evaluate(
verbose=False,
verbose_calc=False,
**kwargs,
- ) -> tuple[PanopticaResult, IntermediateStepsData]:
+ ) -> PanopticaResult:
"""
Perform panoptic evaluation on the given processing pair.
@@ -368,13 +363,14 @@
list_metrics=processing_pair.list_metrics,
global_metrics=global_metrics,
edge_case_handler=edge_case_handler,
+ intermediate_steps_data=intermediate_steps_data,
)

if isinstance(processing_pair, PanopticaResult):
processing_pair._global_metrics = global_metrics
if result_all:
processing_pair.calculate_all(print_errors=verbose_calc)
- return processing_pair, intermediate_steps_data
+ return processing_pair

raise RuntimeError("End of panoptic pipeline reached without results")

3 changes: 3 additions & 0 deletions panoptica/panoptica_result.py
@@ -13,6 +13,7 @@
MetricType,
)
from panoptica.utils import EdgeCaseHandler
+ from panoptica.utils.processing_pair import IntermediateStepsData


class PanopticaResult(object):
@@ -27,6 +28,7 @@ def __init__(
list_metrics: dict[Metric, list[float]],
edge_case_handler: EdgeCaseHandler,
global_metrics: list[Metric] = [],
+ intermediate_steps_data: IntermediateStepsData | None = None,
computation_time: float | None = None,
):
"""Result object for Panoptica, contains all calculatable metrics
@@ -45,6 +47,7 @@
empty_list_std = self._edge_case_handler.handle_empty_list_std().value
self._global_metrics: list[Metric] = global_metrics
self.computation_time = computation_time
+ self.intermediate_steps_data = intermediate_steps_data
######################
# Evaluation Metrics #
######################
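
Because intermediate_steps_data defaults to None on PanopticaResult, existing construction sites keep working unchanged; in this commit only panoptic_evaluate passes the new keyword. Downstream code should therefore guard the attribute before dereferencing it. A short sketch, assuming a result obtained from evaluator.evaluate() as above:

steps = result.intermediate_steps_data
if steps is None:
    raise ValueError("no intermediate steps were recorded for this result")
# Safe to use the recorded arrays from here on.
prediction_arr = steps.original_prediction_arr
reference_arr = steps.original_reference_arr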
32 changes: 16 additions & 16 deletions unit_tests/test_panoptic_evaluator.py
@@ -63,7 +63,7 @@ def test_simple_evaluation(self):
instance_matcher=NaiveThresholdMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
@@ -83,7 +83,7 @@ def test_simple_evaluation_instance_multiclass(self):
instance_matcher=NaiveThresholdMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertAlmostEqual(result.global_bin_dsc, 0.8571428571428571)
self.assertEqual(result.tp, 1)
@@ -104,7 +104,7 @@ def test_simple_evaluation_DSC(self):
instance_matcher=NaiveThresholdMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
@@ -124,7 +124,7 @@ def test_simple_evaluation_DSC_partial(self):
instance_metrics=[Metric.DSC],
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
@@ -150,7 +150,7 @@ def test_simple_evaluation_ASSD(self):
),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
@@ -172,7 +172,7 @@ def test_simple_evaluation_ASSD_negative(self):
),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 0)
self.assertEqual(result.fp, 1)
@@ -192,7 +192,7 @@ def test_pred_empty(self):
instance_matcher=NaiveThresholdMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 0)
self.assertEqual(result.fp, 0)
@@ -214,7 +214,7 @@ def test_no_TP_but_overlap(self):
instance_matcher=NaiveThresholdMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 0)
self.assertEqual(result.fp, 1)
@@ -237,7 +237,7 @@ def test_ref_empty(self):
instance_matcher=NaiveThresholdMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 0)
self.assertEqual(result.fp, 1)
@@ -258,7 +258,7 @@ def test_both_empty(self):
instance_matcher=NaiveThresholdMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 0)
self.assertEqual(result.fp, 0)
@@ -291,7 +291,7 @@ def test_dtype_evaluation(self):
instance_matcher=NaiveThresholdMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
@@ -310,7 +310,7 @@ def test_simple_evaluation_maximize_matcher(self):
instance_matcher=MaximizeMergeMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
@@ -330,7 +330,7 @@ def test_simple_evaluation_maximize_matcher_overlaptwo(self):
instance_matcher=MaximizeMergeMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
@@ -352,7 +352,7 @@ def test_simple_evaluation_maximize_matcher_overlap(self):
instance_matcher=MaximizeMergeMatching(),
)

- result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
+ result = evaluator.evaluate(b, a)["ungrouped"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 1)
@@ -374,7 +374,7 @@ def test_single_instance_mode(self):
segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}),
)

- result, debug_data = evaluator.evaluate(b, a)["organ"]
+ result = evaluator.evaluate(b, a)["organ"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
@@ -394,7 +394,7 @@ def test_single_instance_mode_nooverlap(self):
segmentation_class_groups=SegmentationClassGroups({"organ": (5, True)}),
)

- result, debug_data = evaluator.evaluate(b, a)["organ"]
+ result = evaluator.evaluate(b, a)["organ"]
print(result)
self.assertEqual(result.tp, 1)
self.assertEqual(result.fp, 0)
