polished dynamic result object. added global dice score, added centerlineDSC, fixed some bugs, added unit tests for the new panopticaresult object
Hendrik-code committed Jan 22, 2024
1 parent 63b5f87 commit 293156e
Showing 12 changed files with 419 additions and 133 deletions.
6 changes: 3 additions & 3 deletions examples/example_spine_instance.py
@@ -4,7 +4,7 @@
from auxiliary.turbopath import turbopath

from panoptica import MatchedInstancePair, Panoptic_Evaluator
-from panoptica.metrics import Metrics
+from panoptica.metrics import MatchingMetrics, ListMetric, ListMetricMode

directory = turbopath(__file__).parent

@@ -17,8 +17,8 @@

evaluator = Panoptic_Evaluator(
expected_input=MatchedInstancePair,
-eval_metrics=[Metrics.ASSD, Metrics.IOU],
-decision_metric=Metrics.IOU,
+eval_metrics=[MatchingMetrics.clDSC, MatchingMetrics.DSC],
+decision_metric=MatchingMetrics.DSC,
decision_threshold=0.5,
)

2 changes: 1 addition & 1 deletion examples/example_spine_semantic.py
@@ -9,7 +9,7 @@
Panoptic_Evaluator,
SemanticPair,
)
-from panoptica.metrics import Metrics
+from panoptica.metrics import MatchingMetrics

directory = turbopath(__file__).parent

10 changes: 6 additions & 4 deletions panoptica/instance_evaluator.py
@@ -11,13 +11,13 @@
from panoptica.timing import measure_time
from panoptica.utils import EdgeCaseHandler
from panoptica.utils.processing_pair import MatchedInstancePair
-from panoptica.metrics import Metrics, ListMetric
+from panoptica.metrics import MatchingMetrics, ListMetric


def evaluate_matched_instance(
matched_instance_pair: MatchedInstancePair,
-eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD],
-decision_metric: _MatchingMetric | None = Metrics.IOU,
+eval_metrics: list[_MatchingMetric] = [MatchingMetrics.DSC, MatchingMetrics.IOU, MatchingMetrics.ASSD],
+decision_metric: _MatchingMetric | None = MatchingMetrics.IOU,
decision_threshold: float | None = None,
edge_case_handler: EdgeCaseHandler | None = None,
**kwargs,
@@ -67,8 +67,10 @@ def evaluate_matched_instance(

# Create and return the PanopticaResult object with computed metrics
return PanopticaResult(
-num_ref_instances=matched_instance_pair.n_reference_instance,
+reference_arr=matched_instance_pair.reference_arr,
+prediction_arr=matched_instance_pair.prediction_arr,
num_pred_instances=matched_instance_pair.n_prediction_instance,
+num_ref_instances=matched_instance_pair.n_reference_instance,
tp=tp,
list_metrics=score_dict,
edge_case_handler=edge_case_handler,
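The result object now receives the full reference and prediction arrays in addition to the instance counts, presumably to support the global dice score mentioned in the commit message. As a rough illustration (my own sketch, not code from this commit), a global, instance-agnostic Dice can be computed by binarizing both arrays and applying the usual Dice formula:

import numpy as np

def global_binary_dice(reference: np.ndarray, prediction: np.ndarray) -> float:
    """Dice over the binarized arrays, ignoring instance labels (illustrative only)."""
    ref = reference > 0
    pred = prediction > 0
    denom = ref.sum() + pred.sum()
    if denom == 0:
        return 1.0  # both empty: treated here as a perfect match
    return 2.0 * np.logical_and(ref, pred).sum() / denom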
6 changes: 3 additions & 3 deletions panoptica/instance_matcher.py
@@ -6,7 +6,7 @@
_calc_matching_metric_of_overlapping_labels,
_map_labels,
)
-from panoptica.metrics import Metrics, _MatchingMetric
+from panoptica.metrics import MatchingMetrics, _MatchingMetric
from panoptica.utils.processing_pair import (
InstanceLabelMap,
MatchedInstancePair,
@@ -153,7 +153,7 @@ class NaiveThresholdMatching(InstanceMatchingAlgorithm):

def __init__(
self,
-matching_metric: _MatchingMetric = Metrics.IOU,
+matching_metric: _MatchingMetric = MatchingMetrics.IOU,
matching_threshold: float = 0.5,
allow_many_to_one: bool = False,
) -> None:
@@ -228,7 +228,7 @@ class MaximizeMergeMatching(InstanceMatchingAlgorithm):

def __init__(
self,
-matching_metric: _MatchingMetric = Metrics.IOU,
+matching_metric: _MatchingMetric = MatchingMetrics.IOU,
matching_threshold: float = 0.5,
) -> None:
"""
11 changes: 9 additions & 2 deletions panoptica/metrics/__init__.py
@@ -6,5 +6,12 @@
_compute_dice_coefficient,
_compute_instance_volumetric_dice,
)
-from panoptica.metrics.iou import _compute_instance_iou, _compute_iou
-from panoptica.metrics.metrics import Metrics, ListMetric, _MatchingMetric, ListMetricMode
+from panoptica.metrics.iou import (
+_compute_instance_iou,
+_compute_iou,
+)
+from panoptica.metrics.cldice import (
+_compute_centerline_dice,
+_compute_centerline_dice_coefficient,
+)
+from panoptica.metrics.metrics import MatchingMetrics, ListMetric, _MatchingMetric, ListMetricMode
58 changes: 58 additions & 0 deletions panoptica/metrics/cldice.py
@@ -0,0 +1,58 @@
from skimage.morphology import skeletonize, skeletonize_3d
import numpy as np


def cl_score(volume: np.ndarray, skeleton: np.ndarray):
"""Computes the skeleton volume overlap
Args:
volume (np.ndarray): volume
skeleton (np.ndarray): skeleton
Returns:
_type_: skeleton overlap
"""
return np.sum(volume * skeleton) / np.sum(skeleton)


def _compute_centerline_dice(
ref_labels: np.ndarray,
pred_labels: np.ndarray,
ref_instance_idx: int,
pred_instance_idx: int,
) -> float:
"""Compute the centerline Dice (clDice) coefficient between a specific pair of instances.
Args:
ref_labels (np.ndarray): Reference instance labels.
pred_labels (np.ndarray): Prediction instance labels.
ref_instance_idx (int): Index of the reference instance.
pred_instance_idx (int): Index of the prediction instance.
Returns:
float: clDice coefficient
"""
ref_instance_mask = ref_labels == ref_instance_idx
pred_instance_mask = pred_labels == pred_instance_idx
return _compute_centerline_dice_coefficient(
reference=ref_instance_mask,
prediction=pred_instance_mask,
)



def _compute_centerline_dice_coefficient(
reference: np.ndarray,
prediction: np.ndarray,
*args,
) -> float:
ndim = reference.ndim
assert 2 <= ndim <= 3, "clDice only implemented for 2D or 3D"
if ndim == 2:
tprec = cl_score(prediction,skeletonize(reference))
tsens = cl_score(reference,skeletonize(prediction))
elif ndim == 3:
tprec = cl_score(prediction,skeletonize_3d(reference))
tsens = cl_score(reference,skeletonize_3d(prediction))

return 2 * tprec * tsens / (tprec + tsens)
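For orientation, a minimal usage sketch of the new clDice helper (not part of this commit's diff; the toy arrays and the expected value of 1.0 for identical masks are my own illustration):

import numpy as np
from panoptica.metrics.cldice import _compute_centerline_dice_coefficient

# A thin horizontal structure; an identical prediction should yield clDice = 1.0,
# since each skeleton is fully contained in the other mask.
reference = np.zeros((10, 10), dtype=np.uint8)
reference[4, 2:8] = 1
prediction = reference.copy()

print(_compute_centerline_dice_coefficient(reference=reference, prediction=prediction))  # 1.0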
33 changes: 16 additions & 17 deletions panoptica/metrics/metrics.py
@@ -8,6 +8,7 @@
_average_symmetric_surface_distance,
_compute_dice_coefficient,
_compute_iou,
+_compute_centerline_dice_coefficient,
)
from panoptica.utils.constants import _Enum_Compare, auto

@@ -57,14 +58,11 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float


# Important metrics that must be calculated in the evaluator, can be set for thresholding in matching and evaluation
-# TODO make abstract class for metric, make enum with references to these classes for referenciation and user exposure
-class Metrics:
-# TODO make this with meta above, and then it can function without the double name, right?
-DSC = _MatchingMetric("DSC", False, _compute_dice_coefficient)
-IOU = _MatchingMetric("IOU", False, _compute_iou)
-ASSD = _MatchingMetric("ASSD", True, _average_symmetric_surface_distance)
-# These are all lists of values
-
+class MatchingMetrics:
+DSC: _MatchingMetric = _MatchingMetric("DSC", False, _compute_dice_coefficient)
+IOU: _MatchingMetric = _MatchingMetric("IOU", False, _compute_iou)
+ASSD: _MatchingMetric = _MatchingMetric("ASSD", True, _average_symmetric_surface_distance)
+clDSC: _MatchingMetric = _MatchingMetric("clDSC", False, _compute_centerline_dice_coefficient)

class ListMetricMode(_Enum_Compare):
ALL = auto()
@@ -74,21 +72,22 @@ class ListMetricMode(_Enum_Compare):


class ListMetric(_Enum_Compare):
-DSC = Metrics.DSC.name
-IOU = Metrics.IOU.name
-ASSD = Metrics.ASSD.name
+DSC = MatchingMetrics.DSC.name
+IOU = MatchingMetrics.IOU.name
+ASSD = MatchingMetrics.ASSD.name
+clDSC = MatchingMetrics.clDSC.name

def __hash__(self) -> int:
return abs(hash(self.value)) % (10**8)


if __name__ == "__main__":
-print(Metrics.DSC)
+print(MatchingMetrics.DSC)
# print(MatchingMetric.DSC.name)

-print(Metrics.DSC == Metrics.DSC)
-print(Metrics.DSC == "DSC")
-print(Metrics.DSC.name == "DSC")
+print(MatchingMetrics.DSC == MatchingMetrics.DSC)
+print(MatchingMetrics.DSC == "DSC")
+print(MatchingMetrics.DSC.name == "DSC")
#
-print(Metrics.DSC == Metrics.IOU)
-print(Metrics.DSC == "IOU")
+print(MatchingMetrics.DSC == MatchingMetrics.IOU)
+print(MatchingMetrics.DSC == "IOU")
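Since the Metrics container is renamed to MatchingMetrics, downstream code referencing Metrics.DSC and friends has to switch to the new name. A small sketch of the renamed API (assuming panoptica at this commit; the return semantics of score_beats_threshold and the reading of the boolean constructor flag as "lower is better" are assumptions on my part):

from panoptica.metrics import MatchingMetrics

# Name-based comparisons, as exercised in the __main__ block above
print(MatchingMetrics.clDSC.name)    # "clDSC"
print(MatchingMetrics.DSC == "DSC")  # True

# Thresholding: DSC/IOU/clDSC are similarity scores, ASSD is a distance
# (its boolean flag in the constructor is True, read here as "decreasing").
print(MatchingMetrics.DSC.score_beats_threshold(0.7, 0.5))   # assumed True: 0.7 >= 0.5
print(MatchingMetrics.ASSD.score_beats_threshold(0.3, 0.5))  # assumed True: 0.3 <= 0.5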
61 changes: 33 additions & 28 deletions panoptica/panoptic_evaluator.py
@@ -5,7 +5,7 @@
from panoptica.instance_approximator import InstanceApproximator
from panoptica.instance_evaluator import evaluate_matched_instance
from panoptica.instance_matcher import InstanceMatchingAlgorithm
-from panoptica.metrics import Metrics, _MatchingMetric
+from panoptica.metrics import MatchingMetrics, _MatchingMetric, ListMetric
from panoptica.panoptic_result import PanopticaResult
from panoptica.timing import measure_time
from panoptica.utils import EdgeCaseHandler
@@ -25,7 +25,7 @@ def __init__(
instance_approximator: InstanceApproximator | None = None,
instance_matcher: InstanceMatchingAlgorithm | None = None,
edge_case_handler: EdgeCaseHandler | None = None,
-eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD],
+eval_metrics: list[_MatchingMetric] = [MatchingMetrics.DSC, MatchingMetrics.IOU, MatchingMetrics.ASSD],
decision_metric: _MatchingMetric | None = None,
decision_threshold: float | None = None,
log_times: bool = False,
@@ -80,7 +80,7 @@ def panoptic_evaluate(
processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
instance_approximator: InstanceApproximator | None = None,
instance_matcher: InstanceMatchingAlgorithm | None = None,
-eval_metrics: list[_MatchingMetric] = [Metrics.DSC, Metrics.IOU, Metrics.ASSD],
+eval_metrics: list[_MatchingMetric] = [MatchingMetrics.DSC, MatchingMetrics.IOU, MatchingMetrics.ASSD],
decision_metric: _MatchingMetric | None = None,
decision_threshold: float | None = None,
edge_case_handler: EdgeCaseHandler | None = None,
@@ -141,7 +141,7 @@ def panoptic_evaluate(

# Second Phase: Instance Matching
if isinstance(processing_pair, UnmatchedInstancePair):
-processing_pair = _handle_zero_instances_cases(processing_pair, edge_case_handler=edge_case_handler)
+processing_pair = _handle_zero_instances_cases(processing_pair, eval_metrics=eval_metrics, edge_case_handler=edge_case_handler)

if isinstance(processing_pair, UnmatchedInstancePair):
print("-- Got UnmatchedInstancePair, will match instances")
@@ -157,7 +157,7 @@

# Third Phase: Instance Evaluation
if isinstance(processing_pair, MatchedInstancePair):
-processing_pair = _handle_zero_instances_cases(processing_pair, edge_case_handler=edge_case_handler)
+processing_pair = _handle_zero_instances_cases(processing_pair, eval_metrics=eval_metrics, edge_case_handler=edge_case_handler)

if isinstance(processing_pair, MatchedInstancePair):
print("-- Got MatchedInstancePair, will evaluate instances")
@@ -182,6 +182,7 @@ def panoptic_evaluate(
def _handle_zero_instances_cases(
processing_pair: UnmatchedInstancePair | MatchedInstancePair,
edge_case_handler: EdgeCaseHandler,
+eval_metrics: list[_MatchingMetric] = [MatchingMetrics.DSC, MatchingMetrics.IOU, MatchingMetrics.ASSD],
) -> UnmatchedInstancePair | MatchedInstancePair | PanopticaResult:
"""
Handle edge cases when comparing reference and prediction masks.
@@ -196,32 +197,36 @@
n_reference_instance = processing_pair.n_reference_instance
n_prediction_instance = processing_pair.n_prediction_instance

+panoptica_result_args = {
+"list_metrics": {ListMetric[k.name]: [] for k in eval_metrics},
+"tp": 0,
+"edge_case_handler": edge_case_handler,
+"reference_arr": processing_pair.reference_arr,
+"prediction_arr": processing_pair.prediction_arr,
+}
+
+is_edge_case = False
+
# Handle cases where either the reference or the prediction is empty
if n_prediction_instance == 0 and n_reference_instance == 0:
# Both references and predictions are empty, perfect match
-return PanopticaResult(
-num_ref_instances=0,
-num_pred_instances=0,
-tp=0,
-list_metrics={},
-edge_case_handler=edge_case_handler,
-)
-if n_reference_instance == 0:
+n_reference_instance=0
+n_prediction_instance=0
+is_edge_case=True
+elif n_reference_instance == 0:
# All references are missing, only false positives
-return PanopticaResult(
-num_ref_instances=0,
-num_pred_instances=n_prediction_instance,
-tp=0,
-list_metrics={},
-edge_case_handler=edge_case_handler,
-)
-if n_prediction_instance == 0:
+n_reference_instance=0
+n_prediction_instance=n_prediction_instance
+is_edge_case=True
+elif n_prediction_instance == 0:
# All predictions are missing, only false negatives
-return PanopticaResult(
-num_ref_instances=n_reference_instance,
-num_pred_instances=0,
-tp=0,
-list_metrics={},
-edge_case_handler=edge_case_handler,
-)
+n_reference_instance=n_reference_instance
+n_prediction_instance=0
+is_edge_case=True

+if is_edge_case:
+panoptica_result_args["num_ref_instances"] = n_reference_instance
+panoptica_result_args["num_pred_instances"] = n_prediction_instance
+return PanopticaResult(**panoptica_result_args)

return processing_pair