Modular metrics #71

Merged
merged 20 commits on Jan 15, 2024
15 changes: 11 additions & 4 deletions benchmark/modules_speedtest.py
@@ -9,6 +9,8 @@
ConnectedComponentsInstanceApproximator,
NaiveThresholdMatching,
SemanticPair,
UnmatchedInstancePair,
MatchedInstancePair,
)
from panoptica.instance_evaluator import evaluate_matched_instance
from time import perf_counter
@@ -80,16 +82,21 @@ def test_input(processing_pair: SemanticPair):
processing_pair.crop_data()
#
start1 = perf_counter()
processing_pair = instance_approximator.approximate_instances(processing_pair)
unmatched_instance_pair = instance_approximator.approximate_instances(
semantic_pair=processing_pair
)
time1 = perf_counter() - start1
#
start2 = perf_counter()
processing_pair = instance_matcher.match_instances(processing_pair)
matched_instance_pair = instance_matcher.match_instances(
unmatched_instance_pair=unmatched_instance_pair
)
time2 = perf_counter() - start2
#
start3 = perf_counter()
processing_pair = evaluate_matched_instance(
processing_pair, iou_threshold=iou_threshold
result = evaluate_matched_instance(
matched_instance_pair,
decision_threshold=iou_threshold,
)
time3 = perf_counter() - start3
return time1, time2, time3
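The benchmark above exercises the refactored three-stage API end to end. A minimal standalone sketch of that flow, using hypothetical toy arrays instead of the benchmark's real volumes (array shapes/dtypes and the reliance on default arguments are assumptions; the keyword names come from the diff):

import numpy as np

from panoptica import (
    ConnectedComponentsInstanceApproximator,
    NaiveThresholdMatching,
    SemanticPair,
)
from panoptica.instance_evaluator import evaluate_matched_instance

if __name__ == "__main__":  # the pipeline parallelizes with multiprocessing.Pool
    # Hypothetical binary masks standing in for prediction/reference volumes.
    pred_masks = np.zeros((32, 32, 32), dtype=np.uint8)
    ref_masks = np.zeros((32, 32, 32), dtype=np.uint8)
    pred_masks[4:10, 4:10, 4:10] = 1
    ref_masks[5:11, 5:11, 5:11] = 1

    processing_pair = SemanticPair(pred_masks, ref_masks)

    # Stage 1: semantic pair -> unmatched instance pair
    unmatched_instance_pair = ConnectedComponentsInstanceApproximator().approximate_instances(
        semantic_pair=processing_pair
    )
    # Stage 2: unmatched instance pair -> matched instance pair
    matched_instance_pair = NaiveThresholdMatching().match_instances(
        unmatched_instance_pair=unmatched_instance_pair
    )
    # Stage 3: matched instances -> PanopticaResult (IoU is the default decision metric)
    result = evaluate_matched_instance(matched_instance_pair, decision_threshold=0.5)
    print(result)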
28 changes: 0 additions & 28 deletions examples/example_cfos_3d.py

This file was deleted.

10 changes: 7 additions & 3 deletions examples/example_spine_instance.py
@@ -4,6 +4,7 @@
from auxiliary.turbopath import turbopath

from panoptica import MatchedInstancePair, Panoptic_Evaluator
from panoptica.metrics import Metrics

directory = turbopath(__file__).parent

@@ -16,13 +17,16 @@

evaluator = Panoptic_Evaluator(
expected_input=MatchedInstancePair,
instance_approximator=None,
instance_matcher=None,
iou_threshold=0.5,
eval_metrics=[Metrics.ASSD, Metrics.IOU],
decision_metric=Metrics.IOU,
decision_threshold=0.5,
)


with cProfile.Profile() as pr:
if __name__ == "__main__":
result, debug_data = evaluator.evaluate(sample)

print(result)

pr.dump_stats(directory + "/instance_example.log")
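Since eval_metrics, decision_metric, and decision_threshold are now plain arguments, the decision rule is swappable. A hypothetical variation of the evaluator above (not part of this PR's example file) that gates true positives on Dice instead of IoU:

from panoptica import MatchedInstancePair, Panoptic_Evaluator
from panoptica.metrics import Metrics

# Compute DSC and ASSD per matched instance; count an instance as a true positive
# only if its Dice score beats 0.5.
evaluator = Panoptic_Evaluator(
    expected_input=MatchedInstancePair,
    eval_metrics=[Metrics.DSC, Metrics.ASSD],
    decision_metric=Metrics.DSC,
    decision_threshold=0.5,
)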
5 changes: 3 additions & 2 deletions examples/example_spine_semantic.py
@@ -3,13 +3,13 @@
from auxiliary.nifti.io import read_nifti
from auxiliary.turbopath import turbopath


from panoptica import (
ConnectedComponentsInstanceApproximator,
NaiveThresholdMatching,
Panoptic_Evaluator,
SemanticPair,
)
from panoptica.metrics import Metrics

directory = turbopath(__file__).parent

@@ -18,12 +18,13 @@

sample = SemanticPair(pred_masks, ref_masks)


evaluator = Panoptic_Evaluator(
expected_input=SemanticPair,
instance_approximator=ConnectedComponentsInstanceApproximator(),
instance_matcher=NaiveThresholdMatching(),
iou_threshold=0.5,
)

with cProfile.Profile() as pr:
if __name__ == "__main__":
result, debug_data = evaluator.evaluate(sample)
6 changes: 3 additions & 3 deletions panoptica/__init__.py
@@ -3,9 +3,9 @@
CCABackend,
)
from panoptica.instance_matcher import NaiveThresholdMatching
from panoptica.evaluator import Panoptic_Evaluator
from panoptica.result import PanopticaResult
from panoptica.utils.datatypes import (
from panoptica.panoptic_evaluator import Panoptic_Evaluator
from panoptica.panoptic_result import PanopticaResult
from panoptica.utils.processing_pair import (
SemanticPair,
UnmatchedInstancePair,
MatchedInstancePair,
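The renames above only move internals (evaluator -> panoptic_evaluator, result -> panoptic_result, utils.datatypes -> utils.processing_pair); the package-level exports stay the same, so downstream imports of this form should keep working unchanged:

from panoptica import (
    ConnectedComponentsInstanceApproximator,
    NaiveThresholdMatching,
    Panoptic_Evaluator,
    PanopticaResult,
    SemanticPair,
    UnmatchedInstancePair,
    MatchedInstancePair,
)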
42 changes: 40 additions & 2 deletions panoptica/_functionals.py
@@ -1,8 +1,10 @@
from multiprocessing import Pool

import numpy as np
from panoptica.metrics import _compute_instance_iou

from panoptica.metrics import _compute_instance_iou, _MatchingMetric
from panoptica.utils.constants import CCABackend
from panoptica.utils.numpy_utils import _get_bbox_nd
from multiprocessing import Pool


def _calc_overlapping_labels(
@@ -35,6 +37,42 @@ def _calc_overlapping_labels(
]


def _calc_matching_metric_of_overlapping_labels(
prediction_arr: np.ndarray,
reference_arr: np.ndarray,
ref_labels: tuple[int, ...],
matching_metric: _MatchingMetric,
) -> list[tuple[float, tuple[int, int]]]:
"""Calculates the MatchingMetric for all overlapping labels (fast!)

Args:
prediction_arr (np.ndarray): Numpy array containing the prediction labels.
reference_arr (np.ndarray): Numpy array containing the reference labels.
ref_labels (list[int]): List of unique reference labels.

Returns:
list[tuple[float, tuple[int, int]]]: List of pairs in style: (iou, (ref_label, pred_label))
"""
instance_pairs = [
(reference_arr == i[0], prediction_arr == i[1], i[0], i[1])
for i in _calc_overlapping_labels(
prediction_arr=prediction_arr,
reference_arr=reference_arr,
ref_labels=ref_labels,
)
]
with Pool() as pool:
mm_values = pool.starmap(matching_metric._metric_function, instance_pairs)

mm_pairs = [
(i, (instance_pairs[idx][2], instance_pairs[idx][3]))
for idx, i in enumerate(mm_values)
]
mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=matching_metric.decreasing)

return mm_pairs


def _calc_iou_of_overlapping_labels(
prediction_arr: np.ndarray,
reference_arr: np.ndarray,
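For reference, a hypothetical standalone call to the new helper (toy arrays; assumes Metrics.IOU from panoptica.metrics is a valid _MatchingMetric whose _metric_function accepts the (ref_mask, pred_mask, ref_label, pred_label) tuples built above):

import numpy as np

from panoptica._functionals import _calc_matching_metric_of_overlapping_labels
from panoptica.metrics import Metrics

if __name__ == "__main__":  # Pool() needs the main guard when run as a script
    # Two toy instance maps: reference instances 1 and 2, partially overlapped by the prediction.
    reference_arr = np.array([[1, 1, 0, 0], [0, 0, 2, 2]])
    prediction_arr = np.array([[1, 0, 0, 0], [0, 0, 2, 2]])

    mm_pairs = _calc_matching_metric_of_overlapping_labels(
        prediction_arr=prediction_arr,
        reference_arr=reference_arr,
        ref_labels=(1, 2),
        matching_metric=Metrics.IOU,
    )
    # mm_pairs is a list of (metric_value, (ref_label, pred_label)), best matches first
    print(mm_pairs)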
2 changes: 1 addition & 1 deletion panoptica/instance_approximator.py
@@ -1,5 +1,5 @@
from abc import abstractmethod, ABC
from panoptica.utils.datatypes import (
from panoptica.utils.processing_pair import (
SemanticPair,
UnmatchedInstancePair,
MatchedInstancePair,
131 changes: 53 additions & 78 deletions panoptica/instance_evaluator.py
@@ -1,19 +1,26 @@
import concurrent.futures
from panoptica.utils.datatypes import MatchedInstancePair
from panoptica.result import PanopticaResult
import gc
from multiprocessing import Pool

import numpy as np

from panoptica.metrics import (
_compute_iou,
_compute_dice_coefficient,
_average_symmetric_surface_distance,
Metrics,
_MatchingMetric,
)
from panoptica.panoptic_result import PanopticaResult
from panoptica.timing import measure_time
import numpy as np
import gc
from multiprocessing import Pool
from panoptica.utils import EdgeCaseHandler
from panoptica.utils.processing_pair import MatchedInstancePair


def evaluate_matched_instance(
matched_instance_pair: MatchedInstancePair, iou_threshold: float, **kwargs
matched_instance_pair: MatchedInstancePair,
eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD],
decision_metric: _MatchingMetric | None = Metrics.IOU,
decision_threshold: float | None = None,
edge_case_handler: EdgeCaseHandler | None = None,
**kwargs,
) -> PanopticaResult:
"""
Evaluate a MatchedInstancePair, computing the configured metrics for every matched instance.
@@ -30,75 +37,58 @@ def evaluate_matched_instance(
>>> labelmap = [([1, 2], [3, 4]), ([5], [6])]
>>> result = map_instance_labels(unmatched_instance_pair, labelmap)
"""
if edge_case_handler is None:
edge_case_handler = EdgeCaseHandler()
if decision_metric is not None:
assert decision_metric.name in [
v.name for v in eval_metrics
], "decision metric not contained in eval_metrics"
assert decision_threshold is not None, "decision metric set but no threshold"
# Initialize variables for True Positives (tp)
tp, dice_list, iou_list, assd_list = 0, [], [], []
tp = len(matched_instance_pair.matched_instances)
score_dict: dict[str | _MatchingMetric, list[float]] = {
m.name: [] for m in eval_metrics
}

reference_arr, prediction_arr = (
matched_instance_pair._reference_arr,
matched_instance_pair._prediction_arr,
)
ref_labels = matched_instance_pair._ref_labels

# instance_pairs = _calc_overlapping_labels(
# prediction_arr=prediction_arr,
# reference_arr=reference_arr,
# ref_labels=ref_labels,
# )
# instance_pairs = [(ra, pa, rl, iou_threshold) for (ra, pa, rl, pl) in instance_pairs]
ref_matched_labels = matched_instance_pair.matched_instances

instance_pairs = [
(reference_arr, prediction_arr, ref_idx, iou_threshold)
for ref_idx in ref_labels
(reference_arr, prediction_arr, ref_idx, eval_metrics)
for ref_idx in ref_matched_labels
]
with Pool() as pool:
metric_values = pool.starmap(_evaluate_instance, instance_pairs)

for tp_i, dice_i, iou_i, assd_i in metric_values:
tp += tp_i
if dice_i is not None and iou_i is not None and assd_i is not None:
dice_list.append(dice_i)
iou_list.append(iou_i)
assd_list.append(assd_i)

# Use concurrent.futures.ThreadPoolExecutor for parallelization
# with concurrent.futures.ThreadPoolExecutor() as executor:
# futures = [
# executor.submit(
# _evaluate_instance,
# reference_arr,
# prediction_arr,
# ref_idx,
# iou_threshold,
# )
# for ref_idx in ref_labels
# ]
#
# for future in concurrent.futures.as_completed(futures):
# tp_i, dice_i, iou_i, assd_i = future.result()
# tp += tp_i
# if dice_i is not None and iou_i is not None and assd_i is not None:
# dice_list.append(dice_i)
# iou_list.append(iou_i)
# assd_list.append(assd_i)
# del future
# gc.collect()
metric_dicts = pool.starmap(_evaluate_instance, instance_pairs)

for metric_dict in metric_dicts:
if decision_metric is None or (
decision_threshold is not None
and decision_metric.score_beats_threshold(
metric_dict[decision_metric.name], decision_threshold
)
):
for k, v in metric_dict.items():
score_dict[k].append(v)

# Create and return the PanopticaResult object with computed metrics
return PanopticaResult(
num_ref_instances=matched_instance_pair.n_reference_instance,
num_pred_instances=matched_instance_pair.n_prediction_instance,
tp=tp,
dice_list=dice_list,
iou_list=iou_list,
assd_list=assd_list,
list_metrics=score_dict,
edge_case_handler=edge_case_handler,
)


def _evaluate_instance(
reference_arr: np.ndarray,
prediction_arr: np.ndarray,
ref_idx: int,
iou_threshold: float,
) -> tuple[int, float | None, float | None, float | None]:
eval_metrics: list[_MatchingMetric],
) -> dict[str, float]:
"""
Evaluate a single instance.

@@ -113,27 +103,12 @@ def _evaluate_instance(
"""
ref_arr = reference_arr == ref_idx
pred_arr = prediction_arr == ref_idx
result: dict[str, float] = {}
if ref_arr.sum() == 0 or pred_arr.sum() == 0:
tp = 0
dice = None
iou = None
assd = None
return result
else:
iou: float | None = _compute_iou(
reference=ref_arr,
prediction=pred_arr,
)
if iou > iou_threshold:
tp = 1
dice = _compute_dice_coefficient(
reference=ref_arr,
prediction=pred_arr,
)
assd = _average_symmetric_surface_distance(pred_arr, ref_arr)
else:
tp = 0
dice = None
iou = None
assd = None

return tp, dice, iou, assd
for metric in eval_metrics:
value = metric._metric_function(ref_arr, pred_arr)
result[metric.name] = value

return result
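To spell out the new aggregation: every matched instance now yields a dict of metric scores, and the decision metric only gates which instances contribute to the aggregated lists handed to PanopticaResult. A self-contained sketch of that filtering with made-up scores (assuming score_beats_threshold reduces to a simple comparison in the metric's direction, and that Metrics.<NAME>.name matches the shorthand strings used as keys):

# Made-up per-instance scores, keyed like score_dict above.
metric_dicts = [
    {"IOU": 0.80, "DSC": 0.89, "ASSD": 0.40},
    {"IOU": 0.30, "DSC": 0.46, "ASSD": 2.10},
]
decision_threshold = 0.5

score_dict: dict[str, list[float]] = {"IOU": [], "DSC": [], "ASSD": []}
for metric_dict in metric_dicts:
    # stands in for Metrics.IOU.score_beats_threshold(value, decision_threshold)
    if metric_dict["IOU"] >= decision_threshold:
        for k, v in metric_dict.items():
            score_dict[k].append(v)

# Only the first instance passes the IoU gate, so each metric list keeps one entry.
assert score_dict == {"IOU": [0.80], "DSC": [0.89], "ASSD": [0.40]}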