Commit

Merge branch 'dynamic_result' of github.com:BrainLesion/panoptica into dynamic_result
Hendrik-code committed Jan 23, 2024
2 parents 72c872e + e7510f6 commit 3ac52a6
Showing 8 changed files with 201 additions and 100 deletions.
17 changes: 13 additions & 4 deletions panoptica/instance_evaluator.py
@@ -33,7 +33,9 @@ def evaluate_matched_instance(
     if edge_case_handler is None:
         edge_case_handler = EdgeCaseHandler()
     if decision_metric is not None:
-        assert decision_metric.name in [v.name for v in eval_metrics], "decision metric not contained in eval_metrics"
+        assert decision_metric.name in [
+            v.name for v in eval_metrics
+        ], "decision metric not contained in eval_metrics"
         assert decision_threshold is not None, "decision metric set but no threshold"
     # Initialize variables for True Positives (tp)
     tp = len(matched_instance_pair.matched_instances)
@@ -45,14 +47,21 @@ def evaluate_matched_instance(
     )
     ref_matched_labels = matched_instance_pair.matched_instances

-    instance_pairs = [(reference_arr, prediction_arr, ref_idx, eval_metrics) for ref_idx in ref_matched_labels]
+    instance_pairs = [
+        (reference_arr, prediction_arr, ref_idx, eval_metrics)
+        for ref_idx in ref_matched_labels
+    ]
     with Pool() as pool:
-        metric_dicts: list[dict[Metric, float]] = pool.starmap(_evaluate_instance, instance_pairs)
+        metric_dicts: list[dict[Metric, float]] = pool.starmap(
+            _evaluate_instance, instance_pairs
+        )

     for metric_dict in metric_dicts:
         if decision_metric is None or (
             decision_threshold is not None
-            and decision_metric.score_beats_threshold(metric_dict[decision_metric], decision_threshold)
+            and decision_metric.score_beats_threshold(
+                metric_dict[decision_metric], decision_threshold
+            )
         ):
             for k, v in metric_dict.items():
                 score_dict[k].append(v)
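
The hunk above fans the per-instance metric computation out across a process pool: one (reference, prediction, label, metrics) tuple per matched instance, scored in parallel via pool.starmap. A minimal, self-contained sketch of the same pattern (_evaluate_instance here is a simplified stand-in with a toy Dice score, not panoptica's actual helper):

from multiprocessing import Pool

import numpy as np


def _evaluate_instance(reference_arr, prediction_arr, ref_idx):
    # Binarize both arrays to the current instance label, then score the overlap.
    ref_mask = reference_arr == ref_idx
    pred_mask = prediction_arr == ref_idx
    overlap = np.logical_and(ref_mask, pred_mask).sum()
    dsc = 2.0 * overlap / (ref_mask.sum() + pred_mask.sum())
    return {"DSC": dsc}


if __name__ == "__main__":
    reference = np.array([[1, 1, 0], [0, 2, 2]])
    prediction = np.array([[1, 0, 0], [0, 2, 2]])
    instance_pairs = [(reference, prediction, idx) for idx in (1, 2)]
    with Pool() as pool:
        metric_dicts = pool.starmap(_evaluate_instance, instance_pairs)
    print(metric_dicts)  # one metric dict per matched instance
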
2 changes: 1 addition & 1 deletion panoptica/metrics/__init__.py
@@ -7,7 +7,7 @@
     _compute_instance_volumetric_dice,
 )
 from panoptica.metrics.iou import (
-        _compute_instance_iou,
+    _compute_instance_iou,
     _compute_iou,
 )
 from panoptica.metrics.cldice import (
19 changes: 9 additions & 10 deletions panoptica/metrics/cldice.py
@@ -16,10 +16,10 @@ def cl_score(volume: np.ndarray, skeleton: np.ndarray):


 def _compute_centerline_dice(
-        ref_labels: np.ndarray,
-        pred_labels: np.ndarray,
-        ref_instance_idx: int,
-        pred_instance_idx: int,
+    ref_labels: np.ndarray,
+    pred_labels: np.ndarray,
+    ref_instance_idx: int,
+    pred_instance_idx: int,
 ) -> float:
     """Compute the centerline Dice (clDice) coefficient between a specific pair of instances.
@@ -38,7 +38,6 @@ def _compute_centerline_dice(
         reference=ref_instance_mask,
         prediction=pred_instance_mask,
     )
-


 def _compute_centerline_dice_coefficient(
@@ -49,10 +48,10 @@ def _compute_centerline_dice_coefficient(
     ndim = reference.ndim
     assert 2 <= ndim <= 3, "clDice only implemented for 2D or 3D"
     if ndim == 2:
-        tprec = cl_score(prediction,skeletonize(reference))
-        tsens = cl_score(reference,skeletonize(prediction))
+        tprec = cl_score(prediction, skeletonize(reference))
+        tsens = cl_score(reference, skeletonize(prediction))
     elif ndim == 3:
-        tprec = cl_score(prediction,skeletonize_3d(reference))
-        tsens = cl_score(reference,skeletonize_3d(prediction))
+        tprec = cl_score(prediction, skeletonize_3d(reference))
+        tsens = cl_score(reference, skeletonize_3d(prediction))

-    return 2 * tprec * tsens / (tprec + tsens)
\ No newline at end of file
+    return 2 * tprec * tsens / (tprec + tsens)
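
As the hunk above shows, clDice is the harmonic mean of two skeleton-coverage fractions: tprec measures how much of the reference skeleton lies inside the prediction mask, and tsens how much of the prediction skeleton lies inside the reference mask. A minimal 2D sketch, assuming scikit-image is installed; this cl_score follows the coverage definition implied by its call sites above and is illustrative, not panoptica's exact implementation:

import numpy as np
from skimage.morphology import skeletonize


def cl_score(volume: np.ndarray, skeleton: np.ndarray) -> float:
    # Fraction of the skeleton pixels that lie inside the binary volume.
    return float(np.sum(volume * skeleton) / np.sum(skeleton))


reference = np.zeros((9, 9), dtype=np.uint8)
reference[4, 1:8] = 1  # a thin horizontal vessel
prediction = np.zeros((9, 9), dtype=np.uint8)
prediction[4, 2:8] = 1  # the same vessel, slightly truncated

tprec = cl_score(prediction, skeletonize(reference))
tsens = cl_score(reference, skeletonize(prediction))
cldice = 2 * tprec * tsens / (tprec + tsens)
print(tprec, tsens, cldice)
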
50 changes: 35 additions & 15 deletions panoptica/metrics/metrics.py
@@ -15,8 +15,8 @@

 @dataclass
 class _Metric:
-    """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array
-    """
+    """A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array"""
+
     name: str
     decreasing: bool
     _metric_function: Callable
@@ -34,7 +34,9 @@ def __call__(
         reference_arr = reference_arr.copy() == ref_instance_idx
         if isinstance(pred_instance_idx, int):
             pred_instance_idx = [pred_instance_idx]
-        prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx) #type:ignore
+        prediction_arr = np.isin(
+            prediction_arr.copy(), pred_instance_idx
+        )  # type:ignore
         return self._metric_function(reference_arr, prediction_arr, *args, **kwargs)

     def __eq__(self, __value: object) -> bool:
@@ -50,34 +52,40 @@ def __str__(self) -> str:

     def __repr__(self) -> str:
         return str(self)

     def __hash__(self) -> int:
         return abs(hash(self.name)) % (10**8)

     @property
     def increasing(self):
         return not self.decreasing

-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )


-class DirectValueMeta(EnumMeta):
+class DirectValueMeta(EnumMeta):
     "Metaclass that allows for directly getting an enum attribute"
+
     def __getattribute__(cls, name) -> _Metric:
         value = super().__getattribute__(name)
         if isinstance(value, cls):
             value = value.value
         return value


 class Metric(_Enum_Compare):
     """Enum containing important metrics that must be calculated in the evaluator, can be set for thresholding in matching and evaluation
     Never call the .value member here, use the properties directly

     Returns:
         _type_: _description_
     """
+
     DSC = _Metric("DSC", False, _compute_dice_coefficient)
     IOU = _Metric("IOU", False, _compute_iou)
     ASSD = _Metric("ASSD", True, _average_symmetric_surface_distance)
@@ -103,9 +111,18 @@ def __call__(
         Returns:
             int | float: The metric value
         """
-        return self.value(reference_arr=reference_arr, prediction_arr=prediction_arr, ref_instance_idx=ref_instance_idx, pred_instance_idx=pred_instance_idx, *args, **kwargs,)
+        return self.value(
+            reference_arr=reference_arr,
+            prediction_arr=prediction_arr,
+            ref_instance_idx=ref_instance_idx,
+            pred_instance_idx=pred_instance_idx,
+            *args,
+            **kwargs,
+        )

-    def score_beats_threshold(self, matching_score: float, matching_threshold: float) -> bool:
+    def score_beats_threshold(
+        self, matching_score: float, matching_threshold: float
+    ) -> bool:
         """Calculates whether a score beats a specified threshold

         Args:
@@ -115,12 +132,14 @@ def score_beats_threshold(self, matching_score: float, matching_threshold: float
         Returns:
             bool: True if the matching_score beats the threshold, False otherwise.
         """
-        return (self.increasing and matching_score >= matching_threshold) or (self.decreasing and matching_score <= matching_threshold)
+        return (self.increasing and matching_score >= matching_threshold) or (
+            self.decreasing and matching_score <= matching_threshold
+        )

     @property
     def name(self):
         return self.value.name

     @property
     def decreasing(self):
         return self.value.decreasing
@@ -139,6 +158,7 @@ class MetricMode(_Enum_Compare):
     Args:
         _Enum_Compare (_type_): _description_
     """
+
     ALL = auto()
     AVG = auto()
     SUM = auto()
@@ -154,4 +174,4 @@ class MetricMode(_Enum_Compare):
     print(Metric.DSC.name == "DSC")
     #
     print(Metric.DSC == Metric.IOU)
-    print(Metric.DSC == "IOU")
\ No newline at end of file
+    print(Metric.DSC == "IOU")
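
The thresholding semantics defined above read as follows: an increasing metric (DSC, IOU, where higher is better) beats a threshold when score >= threshold, while a decreasing metric (ASSD, where lower is better) beats it when score <= threshold. A quick sketch against the enum defined in this file (assuming Metric is re-exported from panoptica.metrics, which this diff does not show):

from panoptica.metrics import Metric

# DSC is increasing (decreasing=False): 0.8 >= 0.5, so the score beats the threshold.
print(Metric.DSC.score_beats_threshold(0.8, 0.5))   # True
# ASSD is decreasing (decreasing=True): 2.0 <= 1.0 is False, so it does not.
print(Metric.ASSD.score_beats_threshold(2.0, 1.0))  # False
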
68 changes: 47 additions & 21 deletions panoptica/panoptic_evaluator.py
@@ -21,7 +21,9 @@
 class Panoptic_Evaluator:
     def __init__(
         self,
-        expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair,
+        expected_input: Type[SemanticPair]
+        | Type[UnmatchedInstancePair]
+        | Type[MatchedInstancePair] = MatchedInstancePair,
         instance_approximator: InstanceApproximator | None = None,
         instance_matcher: InstanceMatchingAlgorithm | None = None,
         edge_case_handler: EdgeCaseHandler | None = None,
@@ -47,9 +49,13 @@ def __init__(
         self.__decision_metric = decision_metric
         self.__decision_threshold = decision_threshold

-        self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        self.__edge_case_handler = (
+            edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
+        )
         if self.__decision_metric is not None:
-            assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it"
+            assert (
+                self.__decision_threshold is not None
+            ), "decision metric set but no decision threshold for it"
         #
         self.__log_times = log_times
         self.__verbose = verbose
@@ -58,11 +64,16 @@ def __init__(
     @measure_time
     def evaluate(
         self,
-        processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
+        processing_pair: SemanticPair
+        | UnmatchedInstancePair
+        | MatchedInstancePair
+        | PanopticaResult,
         result_all: bool = True,
         verbose: bool | None = None,
     ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
-        assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}"
+        assert (
+            type(processing_pair) == self.__expected_input
+        ), f"input not of expected type {self.__expected_input}"
         return panoptic_evaluate(
             processing_pair=processing_pair,
             edge_case_handler=self.__edge_case_handler,
@@ -78,7 +89,10 @@ def evaluate(


 def panoptic_evaluate(
-    processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
+    processing_pair: SemanticPair
+    | UnmatchedInstancePair
+    | MatchedInstancePair
+    | PanopticaResult,
     instance_approximator: InstanceApproximator | None = None,
     instance_matcher: InstanceMatchingAlgorithm | None = None,
     eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
@@ -131,7 +145,9 @@ def panoptic_evaluate(
     processing_pair.crop_data()

     if isinstance(processing_pair, SemanticPair):
-        assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator"
+        assert (
+            instance_approximator is not None
+        ), "Got SemanticPair but not InstanceApproximator"
         print("-- Got SemanticPair, will approximate instances")
         processing_pair = instance_approximator.approximate_instances(processing_pair)
         start = perf_counter()
@@ -142,11 +158,17 @@ def panoptic_evaluate(

     # Second Phase: Instance Matching
     if isinstance(processing_pair, UnmatchedInstancePair):
-        processing_pair = _handle_zero_instances_cases(processing_pair, eval_metrics=eval_metrics, edge_case_handler=edge_case_handler)
+        processing_pair = _handle_zero_instances_cases(
+            processing_pair,
+            eval_metrics=eval_metrics,
+            edge_case_handler=edge_case_handler,
+        )

     if isinstance(processing_pair, UnmatchedInstancePair):
         print("-- Got UnmatchedInstancePair, will match instances")
-        assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
+        assert (
+            instance_matcher is not None
+        ), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
         start = perf_counter()
         processing_pair = instance_matcher.match_instances(
             processing_pair,
@@ -158,7 +180,11 @@ def panoptic_evaluate(

     # Third Phase: Instance Evaluation
     if isinstance(processing_pair, MatchedInstancePair):
-        processing_pair = _handle_zero_instances_cases(processing_pair, eval_metrics=eval_metrics, edge_case_handler=edge_case_handler)
+        processing_pair = _handle_zero_instances_cases(
+            processing_pair,
+            eval_metrics=eval_metrics,
+            edge_case_handler=edge_case_handler,
+        )

     if isinstance(processing_pair, MatchedInstancePair):
         print("-- Got MatchedInstancePair, will evaluate instances")
@@ -211,23 +237,23 @@ def _handle_zero_instances_cases(
     # Handle cases where either the reference or the prediction is empty
     if n_prediction_instance == 0 and n_reference_instance == 0:
         # Both references and predictions are empty, perfect match
-        n_reference_instance=0
-        n_prediction_instance=0
-        is_edge_case=True
+        n_reference_instance = 0
+        n_prediction_instance = 0
+        is_edge_case = True
     elif n_reference_instance == 0:
         # All references are missing, only false positives
-        n_reference_instance=0
-        n_prediction_instance=n_prediction_instance
-        is_edge_case=True
+        n_reference_instance = 0
+        n_prediction_instance = n_prediction_instance
+        is_edge_case = True
     elif n_prediction_instance == 0:
         # All predictions are missing, only false negatives
-        n_reference_instance=n_reference_instance
-        n_prediction_instance=0
-        is_edge_case=True
+        n_reference_instance = n_reference_instance
+        n_prediction_instance = 0
+        is_edge_case = True

     if is_edge_case:
         panoptica_result_args["num_ref_instances"] = n_reference_instance
         panoptica_result_args["num_pred_instances"] = n_prediction_instance
         return PanopticaResult(**panoptica_result_args)

     return processing_pair
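
Taken together, the phases above form the pipeline SemanticPair -> instance approximation -> UnmatchedInstancePair -> instance matching -> MatchedInstancePair -> instance evaluation -> PanopticaResult, with _handle_zero_instances_cases short-circuiting to an edge-case result whenever either side contains no instances. A usage sketch: the evaluator API matches this diff, but ConnectedComponentsInstanceApproximator, NaiveThresholdMatching, the top-level import path, and the SemanticPair argument order are assumptions not confirmed by the hunks shown here:

import numpy as np

from panoptica import (  # assumed import path
    ConnectedComponentsInstanceApproximator,  # assumed component name
    NaiveThresholdMatching,  # assumed component name
    Panoptic_Evaluator,
    SemanticPair,
)

# Toy binary segmentations (prediction vs. reference).
pred = np.zeros((32, 32), dtype=np.uint8)
ref = np.zeros((32, 32), dtype=np.uint8)
pred[4:10, 4:10] = 1
ref[5:11, 5:11] = 1

evaluator = Panoptic_Evaluator(
    expected_input=SemanticPair,
    instance_approximator=ConnectedComponentsInstanceApproximator(),
    instance_matcher=NaiveThresholdMatching(),
)

# evaluate() returns the result plus the intermediate processing pairs.
result, intermediate_steps = evaluator.evaluate(SemanticPair(pred, ref))
print(result)
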