Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global metrics hotfix #118

Merged
merged 7 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion panoptica/_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np

from panoptica.metrics.iou import _compute_instance_iou
from panoptica.utils.constants import CCABackend
from panoptica.utils.numpy_utils import _get_bbox_nd

Expand Down
51 changes: 42 additions & 9 deletions panoptica/panoptica_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
MetricCouldNotBeComputedException,
MetricMode,
MetricType,
_compute_centerline_dice_coefficient,
_compute_dice_coefficient,
_average_symmetric_surface_distance,
_compute_relative_volume_difference,
)
from panoptica.utils import EdgeCaseHandler

Expand Down Expand Up @@ -44,8 +40,6 @@ def __init__(
"""
self._edge_case_handler = edge_case_handler
empty_list_std = self._edge_case_handler.handle_empty_list_std().value
self._prediction_arr = prediction_arr
self._reference_arr = reference_arr
self._global_metrics: list[Metric] = global_metrics
######################
# Evaluation Metrics #
Expand Down Expand Up @@ -253,12 +247,42 @@ def __init__(
m, empty_list_std, list_metrics[m], is_edge_case, edge_case_result
)
# even if not available, set the global vars
default_value = None
was_calculated = False
if m in self._global_metrics:
default_value = self._calc_global_bin_metric(
m, prediction_arr, reference_arr
)
was_calculated = True

self._add_metric(
f"global_bin_{m.name.lower()}",
MetricType.GLOBAL,
_build_global_bin_metric_function(m),
lambda x: MetricCouldNotBeComputedException(
f"Global Metric {m} not set"
),
long_name="Global Binary " + m.value.long_name,
default_value=default_value,
was_calculated=was_calculated,
)

def _calc_global_bin_metric(self, metric: Metric, prediction_arr, reference_arr):
    """Compute *metric* globally on binarized prediction/reference arrays.

    The arrays are binarized (every non-zero label becomes 1) and the metric
    is evaluated once over the whole volume, ignoring instance identity.

    Args:
        metric (Metric): The global metric to compute. Must be one of the
            metrics this result was configured with (``self._global_metrics``).
        prediction_arr: Prediction label array (numpy-compatible).
        reference_arr: Reference label array (numpy-compatible).

    Returns:
        The metric value, or the configured edge-case result when there are
        zero true positives and the edge-case handler defines one.

    Raises:
        MetricCouldNotBeComputedException: If *metric* was not configured as
            a global metric.
    """
    if metric not in self._global_metrics:
        raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set")
    if self.tp == 0:
        is_edgecase, result = self._edge_case_handler.handle_zero_tp(
            metric, self.tp, self.num_pred_instances, self.num_ref_instances
        )
        if is_edgecase:
            return result
    # Binarize on copies: the original assigned plain aliases here, so the
    # in-place writes below clobbered the caller's prediction/reference
    # arrays as a side effect.
    pred_binary = prediction_arr.copy()
    ref_binary = reference_arr.copy()
    pred_binary[pred_binary != 0] = 1
    ref_binary[ref_binary != 0] = 1
    return metric(
        reference_arr=ref_binary,
        prediction_arr=pred_binary,
    )

def _add_metric(
self,
Expand Down Expand Up @@ -292,6 +316,7 @@ def calculate_all(self, print_errors: bool = False):
print_errors (bool, optional): If true, will print every metric that could not be computed and its reason. Defaults to False.
"""
metric_errors: dict[str, Exception] = {}

for k, v in self._evaluation_metrics.items():
try:
v = getattr(self, k)
Expand All @@ -302,6 +327,13 @@ def calculate_all(self, print_errors: bool = False):
for k, v in metric_errors.items():
print(f"Metric {k}: {v}")

def _calc(self, k, v):
try:
v = getattr(self, k)
return False, v
except Exception as e:
return True, e

def __str__(self) -> str:
text = ""
for metric_type in MetricType:
Expand Down Expand Up @@ -366,6 +398,8 @@ def __getattribute__(self, __name: str) -> Any:
try:
attr = object.__getattribute__(self, __name)
except AttributeError as e:
if __name == "_evaluation_metrics":
raise e
if __name in self._evaluation_metrics.keys():
pass
else:
Expand Down Expand Up @@ -514,12 +548,11 @@ def function_template(res: PanopticaResult):
prediction_arr=res._prediction_arr,
)

return function_template
return lambda x: function_template(x)


# endregion


if __name__ == "__main__":
c = PanopticaResult(
reference_arr=np.zeros([5, 5, 5]),
Expand Down
22 changes: 22 additions & 0 deletions unit_tests/test_panoptic_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@ def test_simple_evaluation(self):
self.assertEqual(result.fp, 0)
self.assertEqual(result.sq, 0.75)
self.assertEqual(result.pq, 0.75)
self.assertAlmostEqual(result.global_bin_dsc, 0.8571428571428571)

def test_simple_evaluation_instance_multiclass(self):
    """Evaluate unmatched instances where one prediction overlaps two references.

    Reference ``a`` holds two stacked 10x10 instances (labels 1 and 3);
    prediction ``b`` holds a single 15x10 instance (label 2) covering all of
    the first reference instance and half of the second. Asserts one matched
    instance (tp=1), one missed reference (fn=1), and the resulting sq/pq and
    global binary Dice values.
    """
    a = np.zeros([50, 50], dtype=np.uint16)
    b = a.copy().astype(a.dtype)
    # Two reference instances sharing a border; one taller prediction instance.
    a[20:30, 10:20] = 1
    a[30:40, 10:20] = 3
    b[20:35, 10:20] = 2

    evaluator = Panoptica_Evaluator(
        expected_input=InputType.UNMATCHED_INSTANCE,
        instance_matcher=NaiveThresholdMatching(),
    )

    # NOTE(review): arguments are passed as (b, a) — prediction first,
    # reference second, presumably; confirm against Panoptica_Evaluator.evaluate.
    result, debug_data = evaluator.evaluate(b, a)["ungrouped"]
    print(result)
    # Global binary Dice treats all non-zero labels as foreground:
    # 2*150/(150+200) = 6/7.
    self.assertAlmostEqual(result.global_bin_dsc, 0.8571428571428571)
    self.assertEqual(result.tp, 1)
    self.assertEqual(result.fp, 0)
    self.assertEqual(result.fn, 1)
    self.assertAlmostEqual(result.sq, 0.6666666666666666)
    self.assertAlmostEqual(result.pq, 0.4444444444444444)

def test_simple_evaluation_DSC(self):
a = np.zeros([50, 50], dtype=np.uint16)
Expand Down
Loading