Skip to content

Commit

Permalink
Merge pull request #99 from BrainLesion/more_metrics
Browse files Browse the repository at this point in the history
More metrics
  • Loading branch information
Hendrik-code authored Apr 18, 2024
2 parents c687ec8 + 7f8ccce commit 06a37e9
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 1 deletion.
1 change: 1 addition & 0 deletions examples/example_spine_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
expected_input=SemanticPair,
instance_approximator=ConnectedComponentsInstanceApproximator(),
instance_matcher=NaiveThresholdMatching(),
verbose=True,
)

with cProfile.Profile() as pr:
Expand Down
1 change: 1 addition & 0 deletions panoptica/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class MetricType(_Enum_Compare):
_Enum_Compare (_type_): _description_
"""

NO_PRINT = auto()
MATCHING = auto()
GLOBAL = auto()
INSTANCE = auto()
Expand Down
2 changes: 1 addition & 1 deletion panoptica/panoptic_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def panoptic_evaluate(
instance_approximator is not None
), "Got SemanticPair but not InstanceApproximator"
print("-- Got SemanticPair, will approximate instances")
processing_pair = instance_approximator.approximate_instances(processing_pair)
start = perf_counter()
processing_pair = instance_approximator.approximate_instances(processing_pair)
if log_times:
Expand Down Expand Up @@ -185,6 +184,7 @@ def panoptic_evaluate(

if isinstance(processing_pair, MatchedInstancePair):
print("-- Got MatchedInstancePair, will evaluate instances")
start = perf_counter()
processing_pair = evaluate_matched_instance(
processing_pair,
eval_metrics=eval_metrics,
Expand Down
43 changes: 43 additions & 0 deletions panoptica/panoptic_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
MetricType,
_compute_centerline_dice_coefficient,
_compute_dice_coefficient,
_average_symmetric_surface_distance,
)
from panoptica.utils import EdgeCaseHandler

Expand Down Expand Up @@ -94,6 +95,20 @@ def __init__(
fn,
long_name="False Negatives",
)
self.prec: int
self._add_metric(
"prec",
MetricType.NO_PRINT,
prec,
long_name="Precision (positive predictive value)",
)
self.rec: int
self._add_metric(
"rec",
MetricType.NO_PRINT,
rec,
long_name="Recall (sensitivity)",
)
self.rq: float
self._add_metric(
"rq",
Expand All @@ -119,6 +134,14 @@ def __init__(
global_bin_cldsc,
long_name="Global Binary Centerline Dice",
)
#
self.global_bin_assd: int
self._add_metric(
"global_bin_assd",
MetricType.GLOBAL,
global_bin_assd,
long_name="Global Binary Average Symmetric Surface Distance",
)
# endregion
#
# region IOU
Expand Down Expand Up @@ -270,6 +293,8 @@ def calculate_all(self, print_errors: bool = False):
def __str__(self) -> str:
text = ""
for metric_type in MetricType:
if metric_type == MetricType.NO_PRINT:
continue
text += f"\n+++ {metric_type.name} +++\n"
for k, v in self._evaluation_metrics.items():
if v.metric_type != metric_type:
Expand Down Expand Up @@ -360,6 +385,14 @@ def fn(res: PanopticaResult):
return res.num_ref_instances - res.tp


def prec(res: PanopticaResult):
return res.tp / (res.tp + res.fp)


def rec(res: PanopticaResult):
return res.tp / (res.tp + res.fn)


def rq(res: PanopticaResult):
"""
Calculate the Recognition Quality (RQ) based on TP, FP, and FN.
Expand Down Expand Up @@ -456,6 +489,16 @@ def global_bin_cldsc(res: PanopticaResult):
return _compute_centerline_dice_coefficient(ref_binary, pred_binary)


def global_bin_assd(res: PanopticaResult):
if res.tp == 0:
return 0.0
pred_binary = res._prediction_arr.copy()
ref_binary = res._reference_arr.copy()
pred_binary[pred_binary != 0] = 1
ref_binary[ref_binary != 0] = 1
return _average_symmetric_surface_distance(ref_binary, pred_binary)


# endregion


Expand Down

0 comments on commit 06a37e9

Please sign in to comment.