Added unit tests for the definition of segmentation labels. Tweaked some things. Added single_instance mode, which disables matching and sets the decision threshold to zero, since it assumes there is only one instance of the class. Renamed files consistently to panoptica.
Hendrik-code committed Jun 10, 2024
1 parent a578687 commit b900eb5
Showing 15 changed files with 197 additions and 135 deletions.
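The commit message above describes the new single_instance mode. The following is a hedged sketch, not taken from this diff, of how that mode might be used together with the LabelGroup / SegmentationClassGroups / Panoptica_Evaluator API touched below. The group names, label values, and toy arrays are made-up placeholders, and it assumes a single_instance group takes exactly one label (the new unit tests reject multi-label single_instance groups).

```python
# Hedged sketch, not part of this commit: label values (1 for a vertebra,
# 26 for the sacrum) and the tiny arrays are illustrative placeholders.
import numpy as np

from panoptica import MatchedInstancePair, Panoptica_Evaluator
from panoptica.metrics import Metric
from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups

# Tiny synthetic prediction/reference masks with one vertebra and one sacrum.
pred_masks = np.zeros((10, 10), dtype=np.uint8)
ref_masks = np.zeros((10, 10), dtype=np.uint8)
pred_masks[0:3, 0:3] = 1
ref_masks[0:3, 1:4] = 1
pred_masks[5:8, 5:8] = 26
ref_masks[6:9, 6:9] = 26

sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks)

evaluator = Panoptica_Evaluator(
    expected_input=MatchedInstancePair,
    eval_metrics=[Metric.DSC, Metric.IOU],
    segmentation_class_groups=SegmentationClassGroups(
        groups={
            # Regular group: instances are matched against each other.
            "vertebra": LabelGroup([1, 2, 3, 4, 5]),
            # single_instance group: matching is skipped and the decision
            # threshold is forced to 0, since at most one instance is expected.
            "sacrum": LabelGroup(26, single_instance=True),
        }
    ),
)

# evaluate() now returns one (result, intermediate data) entry per group.
result, debug_data = evaluator.evaluate(sample)["sacrum"]
print(result)
```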
8 changes: 2 additions & 6 deletions examples/example_spine_instance.py
@@ -3,7 +3,7 @@
from auxiliary.nifti.io import read_nifti
from auxiliary.turbopath import turbopath

from panoptica import MatchedInstancePair, Panoptic_Evaluator
from panoptica import MatchedInstancePair, Panoptica_Evaluator
from panoptica.metrics import Metric
from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups

@@ -15,11 +15,7 @@

sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks)

import numpy as np

print(np.unique(pred_masks))

evaluator = Panoptic_Evaluator(
evaluator = Panoptica_Evaluator(
expected_input=MatchedInstancePair,
eval_metrics=[Metric.DSC, Metric.IOU],
segmentation_class_groups=SegmentationClassGroups(
6 changes: 3 additions & 3 deletions examples/example_spine_semantic.py
@@ -6,7 +6,7 @@
from panoptica import (
ConnectedComponentsInstanceApproximator,
NaiveThresholdMatching,
Panoptic_Evaluator,
Panoptica_Evaluator,
SemanticPair,
)

@@ -17,7 +17,7 @@

sample = SemanticPair(pred_masks, ref_masks)

evaluator = Panoptic_Evaluator(
evaluator = Panoptica_Evaluator(
expected_input=SemanticPair,
instance_approximator=ConnectedComponentsInstanceApproximator(),
instance_matcher=NaiveThresholdMatching(),
@@ -26,7 +26,7 @@

with cProfile.Profile() as pr:
if __name__ == "__main__":
result, debug_data = evaluator.evaluate(sample)
result, debug_data = evaluator.evaluate(sample)["ungrouped"]
print(result)

pr.dump_stats(directory + "/semantic_example.log")
4 changes: 2 additions & 2 deletions panoptica/__init__.py
@@ -3,8 +3,8 @@
CCABackend,
)
from panoptica.instance_matcher import NaiveThresholdMatching
from panoptica.panoptic_evaluator import Panoptic_Evaluator
from panoptica.panoptic_result import PanopticaResult
from panoptica.panoptica_evaluator import Panoptica_Evaluator
from panoptica.panoptica_result import PanopticaResult
from panoptica.utils.processing_pair import (
SemanticPair,
UnmatchedInstancePair,
81 changes: 0 additions & 81 deletions panoptica/_functionals.py
@@ -79,87 +79,6 @@ def _calc_matching_metric_of_overlapping_labels(
return mm_pairs


def _calc_iou_of_overlapping_labels(
prediction_arr: np.ndarray,
reference_arr: np.ndarray,
ref_labels: tuple[int, ...],
**kwargs,
) -> list[tuple[float, tuple[int, int]]]:
"""Calculates the IOU for all overlapping labels (fast!)
Args:
prediction_arr (np.ndarray): Numpy array containing the prediction labels.
reference_arr (np.ndarray): Numpy array containing the reference labels.
ref_labels (list[int]): List of unique reference labels.
pred_labels (list[int]): List of unique prediction labels.
Returns:
list[tuple[float, tuple[int, int]]]: List of pairs in style: (iou, (ref_label, pred_label))
"""
instance_pairs = [
(reference_arr, prediction_arr, i[0], i[1])
for i in _calc_overlapping_labels(
prediction_arr=prediction_arr,
reference_arr=reference_arr,
ref_labels=ref_labels,
)
]
with Pool() as pool:
iou_values = pool.starmap(_compute_instance_iou, instance_pairs)

iou_pairs = [
(i, (instance_pairs[idx][2], instance_pairs[idx][3]))
for idx, i in enumerate(iou_values)
]
iou_pairs = sorted(iou_pairs, key=lambda x: x[0], reverse=True)

return iou_pairs


def _calc_iou_matrix(
prediction_arr: np.ndarray,
reference_arr: np.ndarray,
ref_labels: tuple[int, ...],
pred_labels: tuple[int, ...],
):
"""
Calculate the Intersection over Union (IoU) matrix between reference and prediction arrays.
Args:
prediction_arr (np.ndarray): Numpy array containing the prediction labels.
reference_arr (np.ndarray): Numpy array containing the reference labels.
ref_labels (list[int]): List of unique reference labels.
pred_labels (list[int]): List of unique prediction labels.
Returns:
np.ndarray: IoU matrix where each element represents the IoU between a reference and prediction instance.
Example:
>>> _calc_iou_matrix(np.array([1, 2, 3]), np.array([4, 5, 6]), [1, 2, 3], [4, 5, 6])
array([[0. , 0. , 0. ],
[0. , 0. , 0. ],
[0. , 0. , 0. ]])
"""
num_ref_instances = len(ref_labels)
num_pred_instances = len(pred_labels)

# Create a pool of worker processes to parallelize the computation
with Pool() as pool:
# # Generate all possible pairs of instance indices for IoU computation
instance_pairs = [
(reference_arr, prediction_arr, ref_idx, pred_idx)
for ref_idx in ref_labels
for pred_idx in pred_labels
]

# Calculate IoU for all instance pairs in parallel using starmap
iou_values = pool.starmap(_compute_instance_iou, instance_pairs)

# Reshape the resulting IoU values into a matrix
iou_matrix = np.array(iou_values).reshape((num_ref_instances, num_pred_instances))
return iou_matrix


def _map_labels(
arr: np.ndarray,
label_map: dict[np.integer, np.integer],
2 changes: 1 addition & 1 deletion panoptica/instance_evaluator.py
@@ -3,7 +3,7 @@
import numpy as np

from panoptica.metrics import Metric
from panoptica.panoptic_result import PanopticaResult
from panoptica.panoptica_result import PanopticaResult
from panoptica.utils import EdgeCaseHandler
from panoptica.utils.processing_pair import MatchedInstancePair

2 changes: 1 addition & 1 deletion panoptica/metrics/metrics.py
@@ -15,7 +15,7 @@
from panoptica.utils.constants import _Enum_Compare, auto

if TYPE_CHECKING:
from panoptic_result import PanopticaResult
from panoptica.panoptica_result import PanopticaResult


@dataclass
2 changes: 1 addition & 1 deletion panoptica/metrics/relative_volume_difference.py
@@ -66,5 +66,5 @@ def _compute_relative_volume_difference(
return 0.0

# Calculate relative volume difference
rvd = (prediction_mask - reference_mask) / reference_mask
rvd = float(prediction_mask - reference_mask) / reference_mask
return rvd
panoptica/panoptic_evaluator.py renamed to panoptica/panoptica_evaluator.py
@@ -5,8 +5,8 @@
from panoptica.instance_evaluator import evaluate_matched_instance
from panoptica.instance_matcher import InstanceMatchingAlgorithm
from panoptica.metrics import Metric, _Metric
from panoptica.panoptic_result import PanopticaResult
from panoptica.timing import measure_time
from panoptica.panoptica_result import PanopticaResult
from panoptica.utils.timing import measure_time
from panoptica.utils import EdgeCaseHandler
from panoptica.utils.citation_reminder import citation_reminder
from panoptica.utils.processing_pair import (
@@ -18,7 +18,7 @@
from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup


class Panoptic_Evaluator:
class Panoptica_Evaluator:

def __init__(
self,
@@ -66,7 +66,7 @@ def evaluate(
processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
result_all: bool = True,
verbose: bool | None = None,
) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]:
assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}"

if self.__segmentation_class_groups is None:
@@ -90,22 +90,30 @@
self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True)

result_grouped = {}
for group_name in self.__segmentation_class_groups:
label_group = self.__segmentation_class_groups[group_name]
for group_name, label_group in self.__segmentation_class_groups.items():
assert isinstance(label_group, LabelGroup)

prediction_arr_grouped = label_group(processing_pair.prediction_arr)
reference_arr_grouped = label_group(processing_pair.reference_arr)

processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped)
single_instance_mode = label_group.single_instance
processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore
decision_threshold = self.__decision_threshold
if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair):
processing_pair_grouped = MatchedInstancePair(
prediction_arr=processing_pair_grouped.prediction_arr,
reference_arr=processing_pair_grouped.reference_arr,
)
decision_threshold = 0.0

result_grouped[group_name] = panoptic_evaluate(
processing_pair=processing_pair_grouped,
edge_case_handler=self.__edge_case_handler,
instance_approximator=self.__instance_approximator,
instance_matcher=self.__instance_matcher,
eval_metrics=self.__eval_metrics,
decision_metric=self.__decision_metric,
decision_threshold=self.__decision_threshold,
decision_threshold=decision_threshold,
result_all=result_all,
log_times=self.__log_times,
verbose=True if verbose is None else verbose,
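Because evaluate() now returns a dict keyed by group name rather than a bare (result, debug_data) tuple, call sites have to unpack per group. Below is a minimal, hypothetical caller-side sketch, assuming an evaluator and sample set up as in the examples above; with no segmentation class groups configured, results land under the "ungrouped" key, as in examples/example_spine_semantic.py.

```python
# Hypothetical call-site update for the new dict return type; "evaluator" and
# "sample" are assumed to exist as in the examples above.
results = evaluator.evaluate(sample)

# Without segmentation_class_groups, everything lives under "ungrouped".
result, debug_data = results["ungrouped"]
print(result)

# With groups configured, iterate over all of them instead:
for group_name, (group_result, group_debug) in results.items():
    print(group_name, group_result)
```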
File renamed without changes.
12 changes: 8 additions & 4 deletions panoptica/utils/segmentation_class.py
@@ -1,10 +1,6 @@
import numpy as np


# TODO also support LabelMergedGroup which takes multi labels and convert them into one before the evaluation
# Useful for BraTs with hierarchical labels (then define one generic Group class and then two more specific subgroups, one for hierarchical, the other for the current one)


class LabelGroup:
"""Defines a group of labels that semantically belong to each other. Only labels within a group will be matched with each other"""

@@ -21,6 +17,7 @@ def __init__(
"""
if isinstance(value_labels, int):
value_labels = [value_labels]
assert len(value_labels) >= 1, f"You tried to define a LabelGroup without any specified labels, got {value_labels}"
self.__value_labels = value_labels
assert np.all([v > 0 for v in self.__value_labels]), f"Given value labels are not >0, got {value_labels}"
self.__single_instance = single_instance
@@ -119,6 +116,13 @@ def __getitem__(self, key):
def __iter__(self):
yield from self.__group_dictionary

def keys(self) -> list[str]:
return list(self.__group_dictionary.keys())

def items(self):
for k in self:
yield k, self[k]


def list_duplicates(seq):
seen = set()
File renamed without changes.
89 changes: 89 additions & 0 deletions unit_tests/test_labelgroup.py
@@ -0,0 +1,89 @@
# Call 'python -m unittest' on this folder
# coverage run -m unittest
# coverage report
# coverage html
import os
import unittest
import numpy as np

from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups


class Test_DefinitionOfSegmentationLabels(unittest.TestCase):
def setUp(self) -> None:
os.environ["PANOPTICA_CITATION_REMINDER"] = "False"
return super().setUp()

def test_labelgroup(self):
group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False)

print(group1)
arr = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
group1_arr = group1(arr, True)

print(group1_arr)
self.assertEqual(group1_arr.sum(), 5)

group1_arr_ind = np.argwhere(group1_arr).flatten()
print(group1_arr_ind)
group1_labels = np.asarray(group1.value_labels)
print(group1_labels)
self.assertTrue(np.all(group1_arr_ind == group1_labels))

def test_labelgroup_notpresent(self):
group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False)

print(group1)
arr = np.array([0, 6, 7, 8, 0, 15, 6, 7, 8, 9, 10])
group1_arr = group1(arr, True)

print(group1_arr)
self.assertEqual(group1_arr.sum(), 0)

group1_arr_ind = np.argwhere(group1_arr).flatten()
self.assertEqual(len(group1_arr_ind), 0)

def test_wrong_labelgroup_definitions(self):

with self.assertRaises(AssertionError):
group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=True)

with self.assertRaises(AssertionError):
group1 = LabelGroup([], single_instance=False)

with self.assertRaises(AssertionError):
group1 = LabelGroup([1, 0, -1, 5], single_instance=False)

def test_segmentationclassgroup_easy(self):
group1 = LabelGroup([1, 2, 3, 4, 5], single_instance=False)
classgroups = SegmentationClassGroups(
groups={
"vertebra": group1,
"ivds": LabelGroup([100, 101, 102]),
}
)

print(classgroups)

self.assertTrue(classgroups.has_defined_labels_for([1, 2, 3]))

self.assertTrue(classgroups.has_defined_labels_for([1, 100, 3]))

self.assertFalse(classgroups.has_defined_labels_for([1, 99, 3]))

self.assertTrue("ivds" in classgroups)

for i in classgroups:
self.assertTrue(i in ["vertebra", "ivds"])

for i, lg in classgroups.items():
print(i, lg)
self.assertTrue(isinstance(i, str))
self.assertTrue(isinstance(lg, LabelGroup))

def test_segmentationclassgroup_decarations(self):
classgroups = SegmentationClassGroups(groups=[LabelGroup(i) for i in range(1, 5)])

keys = classgroups.keys()
for i in range(1, 5):
self.assertTrue(f"group_{i-1}" in keys, f"not {i} in {keys}")
2 changes: 1 addition & 1 deletion unit_tests/test_metrics.py
@@ -8,7 +8,7 @@
import numpy as np

from panoptica.metrics import Metric
from panoptica.panoptic_result import MetricCouldNotBeComputedException, PanopticaResult
from panoptica.panoptica_result import MetricCouldNotBeComputedException, PanopticaResult
from panoptica.utils.edge_case_handling import EdgeCaseHandler, EdgeCaseResult

