Semantic classes #109

Merged · 4 commits · Jul 19, 2024

Changes from all commits
21 changes: 16 additions & 5 deletions examples/example_spine_instance.py
```diff
@@ -3,8 +3,9 @@
 from auxiliary.nifti.io import read_nifti
 from auxiliary.turbopath import turbopath
 
-from panoptica import MatchedInstancePair, Panoptic_Evaluator
+from panoptica import MatchedInstancePair, Panoptica_Evaluator
 from panoptica.metrics import Metric
+from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups
 
 directory = turbopath(__file__).parent
 
@@ -14,18 +15,28 @@
 
 sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks)
 
-
-evaluator = Panoptic_Evaluator(
+evaluator = Panoptica_Evaluator(
     expected_input=MatchedInstancePair,
     eval_metrics=[Metric.DSC, Metric.IOU],
+    segmentation_class_groups=SegmentationClassGroups(
+        {
+            "vertebra": LabelGroup([i for i in range(1, 10)]),
+            "ivd": LabelGroup([i for i in range(101, 109)]),
+            "sacrum": (26, True),
+            "endplate": LabelGroup([i for i in range(201, 209)]),
+        }
+    ),
     decision_metric=Metric.DSC,
     decision_threshold=0.5,
 )
 
 
 with cProfile.Profile() as pr:
     if __name__ == "__main__":
-        result, debug_data = evaluator.evaluate(sample, verbose=True)
-        print(result)
+        results = evaluator.evaluate(sample, verbose=False)
+        for groupname, (result, debug) in results.items():
+            print()
+            print("### Group", groupname)
+            print(result)
 
     pr.dump_stats(directory + "/instance_example.log")
```
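In the grouped example above, each group is evaluated on its own masked copy of the input arrays; the `(26, True)` shorthand for `sacrum` reads as a single label evaluated in single-instance mode (see the evaluator changes further down). As a rough sketch of the masking idea, here is a hypothetical helper that assumes `LabelGroup` zeroes out everything outside its labels when applied to an array; it is not panoptica's actual code:

```python
import numpy as np

def mask_to_group(arr: np.ndarray, labels: list[int]) -> np.ndarray:
    # Keep only this group's labels; zero out every other voxel.
    out = arr.copy()
    out[~np.isin(out, labels)] = 0
    return out

arr = np.array([0, 3, 7, 26, 104, 205])
print(mask_to_group(arr, list(range(1, 10))))     # [0 3 7 0 0 0]   -> "vertebra"
print(mask_to_group(arr, list(range(101, 109))))  # [0 0 0 0 104 0] -> "ivd"
```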
7 changes: 3 additions & 4 deletions examples/example_spine_semantic.py
```diff
@@ -6,7 +6,7 @@
 from panoptica import (
     ConnectedComponentsInstanceApproximator,
     NaiveThresholdMatching,
-    Panoptic_Evaluator,
+    Panoptica_Evaluator,
     SemanticPair,
 )
 
@@ -17,8 +17,7 @@
 
 sample = SemanticPair(pred_masks, ref_masks)
 
-
-evaluator = Panoptic_Evaluator(
+evaluator = Panoptica_Evaluator(
     expected_input=SemanticPair,
     instance_approximator=ConnectedComponentsInstanceApproximator(),
     instance_matcher=NaiveThresholdMatching(),
@@ -27,7 +26,7 @@
 
 with cProfile.Profile() as pr:
     if __name__ == "__main__":
-        result, debug_data = evaluator.evaluate(sample)
+        result, debug_data = evaluator.evaluate(sample)["ungrouped"]
         print(result)
 
     pr.dump_stats(directory + "/semantic_example.log")
```
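`evaluate` now always returns a dict keyed by group name, so a caller without configured class groups unpacks the single `"ungrouped"` entry, as above. A loop keeps the calling code agnostic to whether groups were set (usage sketch for the script above):

```python
for group_name, (result, debug_data) in evaluator.evaluate(sample).items():
    print(group_name, result)  # exactly one entry, "ungrouped", when no groups are set
```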
4 changes: 2 additions & 2 deletions panoptica/__init__.py
```diff
@@ -3,8 +3,8 @@
     CCABackend,
 )
 from panoptica.instance_matcher import NaiveThresholdMatching
-from panoptica.panoptic_evaluator import Panoptic_Evaluator
-from panoptica.panoptic_result import PanopticaResult
+from panoptica.panoptica_evaluator import Panoptica_Evaluator
+from panoptica.panoptica_result import PanopticaResult
 from panoptica.utils.processing_pair import (
     SemanticPair,
     UnmatchedInstancePair,
```
81 changes: 0 additions & 81 deletions panoptica/_functionals.py
```diff
@@ -79,87 +79,6 @@ def _calc_matching_metric_of_overlapping_labels(
     return mm_pairs
 
 
-def _calc_iou_of_overlapping_labels(
-    prediction_arr: np.ndarray,
-    reference_arr: np.ndarray,
-    ref_labels: tuple[int, ...],
-    **kwargs,
-) -> list[tuple[float, tuple[int, int]]]:
-    """Calculates the IOU for all overlapping labels (fast!)
-
-    Args:
-        prediction_arr (np.ndarray): Numpy array containing the prediction labels.
-        reference_arr (np.ndarray): Numpy array containing the reference labels.
-        ref_labels (list[int]): List of unique reference labels.
-        pred_labels (list[int]): List of unique prediction labels.
-
-    Returns:
-        list[tuple[float, tuple[int, int]]]: List of pairs in style: (iou, (ref_label, pred_label))
-    """
-    instance_pairs = [
-        (reference_arr, prediction_arr, i[0], i[1])
-        for i in _calc_overlapping_labels(
-            prediction_arr=prediction_arr,
-            reference_arr=reference_arr,
-            ref_labels=ref_labels,
-        )
-    ]
-    with Pool() as pool:
-        iou_values = pool.starmap(_compute_instance_iou, instance_pairs)
-
-    iou_pairs = [
-        (i, (instance_pairs[idx][2], instance_pairs[idx][3]))
-        for idx, i in enumerate(iou_values)
-    ]
-    iou_pairs = sorted(iou_pairs, key=lambda x: x[0], reverse=True)
-
-    return iou_pairs
-
-
-def _calc_iou_matrix(
-    prediction_arr: np.ndarray,
-    reference_arr: np.ndarray,
-    ref_labels: tuple[int, ...],
-    pred_labels: tuple[int, ...],
-):
-    """
-    Calculate the Intersection over Union (IoU) matrix between reference and prediction arrays.
-
-    Args:
-        prediction_arr (np.ndarray): Numpy array containing the prediction labels.
-        reference_arr (np.ndarray): Numpy array containing the reference labels.
-        ref_labels (list[int]): List of unique reference labels.
-        pred_labels (list[int]): List of unique prediction labels.
-
-    Returns:
-        np.ndarray: IoU matrix where each element represents the IoU between a reference and prediction instance.
-
-    Example:
-        >>> _calc_iou_matrix(np.array([1, 2, 3]), np.array([4, 5, 6]), [1, 2, 3], [4, 5, 6])
-        array([[0. , 0. , 0. ],
-               [0. , 0. , 0. ],
-               [0. , 0. , 0. ]])
-    """
-    num_ref_instances = len(ref_labels)
-    num_pred_instances = len(pred_labels)
-
-    # Create a pool of worker processes to parallelize the computation
-    with Pool() as pool:
-        # # Generate all possible pairs of instance indices for IoU computation
-        instance_pairs = [
-            (reference_arr, prediction_arr, ref_idx, pred_idx)
-            for ref_idx in ref_labels
-            for pred_idx in pred_labels
-        ]
-
-        # Calculate IoU for all instance pairs in parallel using starmap
-        iou_values = pool.starmap(_compute_instance_iou, instance_pairs)
-
-    # Reshape the resulting IoU values into a matrix
-    iou_matrix = np.array(iou_values).reshape((num_ref_instances, num_pred_instances))
-    return iou_matrix
-
-
 def _map_labels(
     arr: np.ndarray,
     label_map: dict[np.integer, np.integer],
```
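Both deleted helpers were IoU-only specializations of pair scoring; the surviving `_calc_matching_metric_of_overlapping_labels` (its tail is the `return mm_pairs` context line above) generalizes them over any pairwise metric. Below is a rough, self-contained sketch of that generalization with hypothetical names, assuming a metric is any `(ref_mask, pred_mask) -> float` callable; it is not the library's implementation:

```python
import numpy as np
from typing import Callable

def calc_metric_of_overlapping_labels(
    prediction_arr: np.ndarray,
    reference_arr: np.ndarray,
    overlapping_pairs: list[tuple[int, int]],
    metric: Callable[[np.ndarray, np.ndarray], float],
) -> list[tuple[float, tuple[int, int]]]:
    # Score each overlapping (ref_label, pred_label) pair, then sort best-first.
    scored = [
        (metric(reference_arr == ref, prediction_arr == pred), (ref, pred))
        for ref, pred in overlapping_pairs
    ]
    return sorted(scored, key=lambda x: x[0], reverse=True)

def iou(ref_mask: np.ndarray, pred_mask: np.ndarray) -> float:
    # IoU becomes just one pluggable metric; DSC and friends drop in the same way.
    union = np.logical_or(ref_mask, pred_mask).sum()
    return float(np.logical_and(ref_mask, pred_mask).sum() / union) if union else 0.0
```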
2 changes: 1 addition & 1 deletion panoptica/instance_evaluator.py
```diff
@@ -3,7 +3,7 @@
 import numpy as np
 
 from panoptica.metrics import Metric
-from panoptica.panoptic_result import PanopticaResult
+from panoptica.panoptica_result import PanopticaResult
 from panoptica.utils import EdgeCaseHandler
 from panoptica.utils.processing_pair import MatchedInstancePair
```
2 changes: 1 addition & 1 deletion panoptica/metrics/metrics.py
```diff
@@ -15,7 +15,7 @@
 from panoptica.utils.constants import _Enum_Compare, auto
 
 if TYPE_CHECKING:
-    from panoptic_result import PanopticaResult
+    from panoptica.panoptica_result import PanopticaResult
 
 
 @dataclass
```
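The old form only resolved when the package directory itself happened to be on `sys.path`, and because a `TYPE_CHECKING` block never executes at runtime, the broken path surfaced only under a type checker. For context, the standard pattern the fix restores (`summarize` is a hypothetical function for illustration):

```python
from __future__ import annotations  # annotations stay strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only, so it must use the fully qualified path.
    from panoptica.panoptica_result import PanopticaResult

def summarize(result: PanopticaResult) -> str:
    # The annotation never triggers a runtime import of PanopticaResult.
    return repr(result)
```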
2 changes: 1 addition & 1 deletion panoptica/metrics/relative_volume_difference.py
```diff
@@ -66,5 +66,5 @@ def _compute_relative_volume_difference(
         return 0.0
 
-    # Calculate Dice coefficient
-    rvd = (prediction_mask - reference_mask) / reference_mask
+    # Calculate relative volume difference
+    rvd = float(prediction_mask - reference_mask) / reference_mask
     return rvd
```
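For reference, this computes RVD = (V_pred - V_ref) / V_ref: negative when the prediction under-segments, positive when it over-segments, and the added `float(...)` makes the floating-point result explicit regardless of the operands' integer types. A quick sanity check with hypothetical volumes:

```python
def relative_volume_difference(pred_vol: int, ref_vol: int) -> float:
    # Mirrors the fixed line above; assumes ref_vol > 0 (the real function
    # short-circuits with 0.0 on an empty reference, per the branch above).
    return float(pred_vol - ref_vol) / ref_vol

print(relative_volume_difference(80, 100))   # -0.2 -> 20% smaller than reference
print(relative_volume_difference(130, 100))  #  0.3 -> 30% larger than reference
```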
panoptica/panoptic_evaluator.py → panoptica/panoptica_evaluator.py (renamed)

```diff
@@ -5,8 +5,8 @@
 from panoptica.instance_evaluator import evaluate_matched_instance
 from panoptica.instance_matcher import InstanceMatchingAlgorithm
 from panoptica.metrics import Metric, _Metric
-from panoptica.panoptic_result import PanopticaResult
-from panoptica.timing import measure_time
+from panoptica.panoptica_result import PanopticaResult
+from panoptica.utils.timing import measure_time
 from panoptica.utils import EdgeCaseHandler
 from panoptica.utils.citation_reminder import citation_reminder
 from panoptica.utils.processing_pair import (
@@ -15,18 +15,21 @@
     UnmatchedInstancePair,
     _ProcessingPair,
 )
+from panoptica.utils.segmentation_class import SegmentationClassGroups, LabelGroup
 
 
-class Panoptic_Evaluator:
+class Panoptica_Evaluator:
+
     def __init__(
         self,
         # TODO let users give prediction and reference arr instead of the processing pair, so let this create the processing pair itself
         expected_input: (
             Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair]
         ) = MatchedInstancePair,
         instance_approximator: InstanceApproximator | None = None,
         instance_matcher: InstanceMatchingAlgorithm | None = None,
         edge_case_handler: EdgeCaseHandler | None = None,
+        segmentation_class_groups: SegmentationClassGroups | None = None,
         eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD, Metric.RVD],
         decision_metric: Metric | None = None,
         decision_threshold: float | None = None,
@@ -49,6 +52,8 @@ def __init__(
         self.__decision_metric = decision_metric
         self.__decision_threshold = decision_threshold
 
+        self.__segmentation_class_groups = segmentation_class_groups
+
         self.__edge_case_handler = (
             edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
         )
@@ -69,24 +74,69 @@ def evaluate(
         ),
         result_all: bool = True,
         verbose: bool | None = None,
-    ) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
+    ) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]:
         assert (
             type(processing_pair) == self.__expected_input
         ), f"input not of expected type {self.__expected_input}"
-        return panoptic_evaluate(
-            processing_pair=processing_pair,
-            edge_case_handler=self.__edge_case_handler,
-            instance_approximator=self.__instance_approximator,
-            instance_matcher=self.__instance_matcher,
-            eval_metrics=self.__eval_metrics,
-            decision_metric=self.__decision_metric,
-            decision_threshold=self.__decision_threshold,
-            result_all=result_all,
-            log_times=self.__log_times,
-            verbose=True if verbose is None else verbose,
-            verbose_calc=self.__verbose if verbose is None else verbose,
-        )
+
+        if self.__segmentation_class_groups is None:
+            return {
+                "ungrouped": panoptic_evaluate(
+                    processing_pair=processing_pair,
+                    edge_case_handler=self.__edge_case_handler,
+                    instance_approximator=self.__instance_approximator,
+                    instance_matcher=self.__instance_matcher,
+                    eval_metrics=self.__eval_metrics,
+                    decision_metric=self.__decision_metric,
+                    decision_threshold=self.__decision_threshold,
+                    result_all=result_all,
+                    log_times=self.__log_times,
+                    verbose=True if verbose is None else verbose,
+                    verbose_calc=self.__verbose if verbose is None else verbose,
+                )
+            }
+
+        self.__segmentation_class_groups.has_defined_labels_for(
+            processing_pair.prediction_arr, raise_error=True
+        )
+        self.__segmentation_class_groups.has_defined_labels_for(
+            processing_pair.reference_arr, raise_error=True
+        )
+
+        result_grouped = {}
+        for group_name, label_group in self.__segmentation_class_groups.items():
+            assert isinstance(label_group, LabelGroup)
+
+            prediction_arr_grouped = label_group(processing_pair.prediction_arr)
+            reference_arr_grouped = label_group(processing_pair.reference_arr)
+
+            single_instance_mode = label_group.single_instance
+            processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped)  # type: ignore
+            decision_threshold = self.__decision_threshold
+            if single_instance_mode and not isinstance(
+                processing_pair, MatchedInstancePair
+            ):
+                processing_pair_grouped = MatchedInstancePair(
+                    prediction_arr=processing_pair_grouped.prediction_arr,
+                    reference_arr=processing_pair_grouped.reference_arr,
+                )
+                decision_threshold = 0.0
+
+            result_grouped[group_name] = panoptic_evaluate(
+                processing_pair=processing_pair_grouped,
+                edge_case_handler=self.__edge_case_handler,
+                instance_approximator=self.__instance_approximator,
+                instance_matcher=self.__instance_matcher,
+                eval_metrics=self.__eval_metrics,
+                decision_metric=self.__decision_metric,
+                decision_threshold=decision_threshold,
+                result_all=result_all,
+                log_times=self.__log_times,
+                verbose=True if verbose is None else verbose,
+                verbose_calc=self.__verbose if verbose is None else verbose,
+            )
+        return result_grouped
 
 
 def panoptic_evaluate(
     processing_pair: (
```
panoptica/timing.py → panoptica/utils/timing.py (file renamed without changes)
4 changes: 4 additions & 0 deletions panoptica/utils/__init__.py
```diff
@@ -15,3 +15,7 @@
 )
 
 # from utils.constants import
+from panoptica.utils.segmentation_class import (
+    SegmentationClassGroups,
+    LabelGroup,
+)
```