Merge pull request #115 from BrainLesion/config
Config
neuronflow authored Aug 6, 2024
2 parents 10bbbb2 + 52cfe7f commit 50889d3
Showing 26 changed files with 919 additions and 273 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -58,6 +58,16 @@ For this case, the matcher module can be utilized to match instances and the eva

If your predicted instances already match the reference instances, you can directly compute metrics using the evaluator module.


### Using Configs (saving and loading)

You can construct Panoptica_Evaluator objects (among many others), save their arguments to a config file, and load them again later. This lets you keep project-specific configurations and reuse them.

[Jupyter notebook tutorial](https://github.com/BrainLesion/tutorials/tree/main/panoptica/example_config.ipynb)

Under the hood, configs are saved and loaded with ruamel.yaml in a human-readable YAML format.
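
A minimal sketch of loading a bundled config by name, mirroring the new examples/example_spine_instance_config.py in this PR (the commented-out save call is an assumed method name; see the tutorial notebook for the actual saving API):

```python
from panoptica import Panoptica_Evaluator

# Load a config that ships with panoptica (see panoptica/configs/) by its name
evaluator = Panoptica_Evaluator.load_from_config_name(
    "panoptica_evaluator_unmatched_instance"
)

# Persisting a project-specific configuration is also supported; the exact
# method name below is an assumption, consult the tutorial notebook above.
# evaluator.save_to_config("my_project_evaluator.yaml")
```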


## Citation

If you use panoptica in your research, please cite it to support the development!
14 changes: 6 additions & 8 deletions examples/example_spine_instance.py
@@ -3,20 +3,17 @@
from auxiliary.nifti.io import read_nifti
from auxiliary.turbopath import turbopath

from panoptica import MatchedInstancePair, Panoptica_Evaluator
from panoptica import Panoptica_Evaluator, InputType
from panoptica.metrics import Metric
from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups

directory = turbopath(__file__).parent

ref_masks = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")

pred_masks = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks)
reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")
prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

evaluator = Panoptica_Evaluator(
expected_input=MatchedInstancePair,
expected_input=InputType.MATCHED_INSTANCE,
eval_metrics=[Metric.DSC, Metric.IOU],
segmentation_class_groups=SegmentationClassGroups(
{
@@ -28,12 +25,13 @@
),
decision_metric=Metric.DSC,
decision_threshold=0.5,
log_times=True,
)


with cProfile.Profile() as pr:
if __name__ == "__main__":
results = evaluator.evaluate(sample, verbose=False)
results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False)
for groupname, (result, debug) in results.items():
print()
print("### Group", groupname)
26 changes: 26 additions & 0 deletions examples/example_spine_instance_config.py
@@ -0,0 +1,26 @@
import cProfile

from auxiliary.nifti.io import read_nifti
from auxiliary.turbopath import turbopath

from panoptica import Panoptica_Evaluator

directory = turbopath(__file__).parent

reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")
prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

evaluator = Panoptica_Evaluator.load_from_config_name(
"panoptica_evaluator_unmatched_instance"
)


with cProfile.Profile() as pr:
if __name__ == "__main__":
results = evaluator.evaluate(prediction_mask, reference_mask, verbose=False)
for groupname, (result, debug) in results.items():
print()
print("### Group", groupname)
print(result)

pr.dump_stats(directory + "/instance_example.log")
13 changes: 7 additions & 6 deletions examples/example_spine_semantic.py
@@ -7,26 +7,27 @@
ConnectedComponentsInstanceApproximator,
NaiveThresholdMatching,
Panoptica_Evaluator,
SemanticPair,
InputType,
)

directory = turbopath(__file__).parent

ref_masks = read_nifti(directory + "/spine_seg/semantic/ref.nii.gz")
pred_masks = read_nifti(directory + "/spine_seg/semantic/pred.nii.gz")
reference_mask = read_nifti(directory + "/spine_seg/semantic/ref.nii.gz")
prediction_mask = read_nifti(directory + "/spine_seg/semantic/pred.nii.gz")

sample = SemanticPair(pred_masks, ref_masks)

evaluator = Panoptica_Evaluator(
expected_input=SemanticPair,
expected_input=InputType.SEMANTIC,
instance_approximator=ConnectedComponentsInstanceApproximator(),
instance_matcher=NaiveThresholdMatching(),
verbose=True,
)

with cProfile.Profile() as pr:
if __name__ == "__main__":
result, debug_data = evaluator.evaluate(sample)["ungrouped"]
result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[
"ungrouped"
]
print(result)

pr.dump_stats(directory + "/semantic_example.log")
1 change: 1 addition & 0 deletions panoptica/__init__.py
@@ -6,6 +6,7 @@
from panoptica.panoptica_evaluator import Panoptica_Evaluator
from panoptica.panoptica_result import PanopticaResult
from panoptica.utils.processing_pair import (
InputType,
SemanticPair,
UnmatchedInstancePair,
MatchedInstancePair,
@@ -0,0 +1,14 @@
!SegmentationClassGroups
groups:
endplate: !LabelGroup
single_instance: false
value_labels: [201, 202, 203, 204, 205, 206, 207, 208, 209, 210]
ivd: !LabelGroup
single_instance: false
value_labels: [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
sacrum: !LabelGroup
single_instance: true
value_labels: [26]
vertebra: !LabelGroup
single_instance: false
value_labels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
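
For reference, a sketch of building the same groups in Python as in the YAML above (the LabelGroup keyword names are inferred from the YAML keys and may not match the actual constructor signature):

```python
from panoptica.utils.segmentation_class import LabelGroup, SegmentationClassGroups

# Keyword names (value_labels, single_instance) are assumed to mirror the YAML keys.
segmentation_class_groups = SegmentationClassGroups(
    {
        "endplate": LabelGroup(value_labels=list(range(201, 211)), single_instance=False),
        "ivd": LabelGroup(value_labels=list(range(101, 111)), single_instance=False),
        "sacrum": LabelGroup(value_labels=[26], single_instance=True),
        "vertebra": LabelGroup(value_labels=list(range(1, 11)), single_instance=False),
    }
)
```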
42 changes: 42 additions & 0 deletions panoptica/configs/panoptica_evaluator_unmatched_instance.yaml
@@ -0,0 +1,42 @@
!Panoptica_Evaluator
decision_metric: !Metric DSC
decision_threshold: 0.5
edge_case_handler: !EdgeCaseHandler
empty_list_std: !EdgeCaseResult NAN
listmetric_zeroTP_handling:
!Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult ZERO}
!Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult INF}
!Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
normal: !EdgeCaseResult NAN}
eval_metrics: [!Metric DSC, !Metric IOU]
expected_input: !InputType UNMATCHED_INSTANCE
instance_approximator: null
instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
matching_threshold: 0.5}
log_times: true
segmentation_class_groups: !SegmentationClassGroups
groups:
endplate: !LabelGroup
single_instance: false
value_labels: [201, 202, 203, 204, 205, 206, 207, 208, 209, 210]
ivd: !LabelGroup
single_instance: false
value_labels: [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
sacrum: !LabelGroup
single_instance: true
value_labels: [26]
vertebra: !LabelGroup
single_instance: false
value_labels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
verbose: false
17 changes: 14 additions & 3 deletions panoptica/instance_approximator.py
@@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import ABC, abstractmethod, ABCMeta

import numpy as np

@@ -10,9 +10,10 @@
SemanticPair,
UnmatchedInstancePair,
)
from panoptica.utils.config import SupportsConfig


class InstanceApproximator(ABC):
class InstanceApproximator(SupportsConfig, metaclass=ABCMeta):
"""
Abstract base class for instance approximation algorithms in panoptic segmentation evaluation.
@@ -56,6 +57,12 @@ def _approximate_instances(
"""
pass

def _yaml_repr(cls, node) -> dict:
raise NotImplementedError(
f"Tried to get yaml representation of abstract class {cls.__name__}"
)
return {}

def approximate_instances(
self, semantic_pair: SemanticPair, verbose: bool = False, **kwargs
) -> UnmatchedInstancePair | MatchedInstancePair:
@@ -140,7 +147,7 @@ def _approximate_instances(
UnmatchedInstancePair: The result of the instance approximation.
"""
cca_backend = self.cca_backend
if self.cca_backend is None:
if cca_backend is None:
cca_backend = (
CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy
)
@@ -164,3 +171,7 @@ def _approximate_instances(
n_prediction_instance=n_prediction_instance,
n_reference_instance=n_reference_instance,
)

@classmethod
def _yaml_repr(cls, node) -> dict:
return {"cca_backend": node.cca_backend}
54 changes: 38 additions & 16 deletions panoptica/instance_matcher.py
@@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import ABCMeta, abstractmethod

import numpy as np

@@ -8,13 +8,14 @@
)
from panoptica.metrics import Metric
from panoptica.utils.processing_pair import (
InstanceLabelMap,
MatchedInstancePair,
UnmatchedInstancePair,
)
from panoptica.utils.instancelabelmap import InstanceLabelMap
from panoptica.utils.config import SupportsConfig


class InstanceMatchingAlgorithm(ABC):
class InstanceMatchingAlgorithm(SupportsConfig, metaclass=ABCMeta):
"""
Abstract base class for instance matching algorithms in panoptic segmentation evaluation.
@@ -79,6 +80,12 @@ def match_instances(
# print("instance_labelmap:", instance_labelmap)
return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap)

def _yaml_repr(cls, node) -> dict:
raise NotImplementedError(
f"Tried to get yaml representation of abstract class {cls.__name__}"
)
return {}


def map_instance_labels(
processing_pair: UnmatchedInstancePair, labelmap: InstanceLabelMap
@@ -166,9 +173,9 @@ def __init__(
Raises:
AssertionError: If the specified IoU threshold is not within the valid range.
"""
self.allow_many_to_one = allow_many_to_one
self.matching_metric = matching_metric
self.matching_threshold = matching_threshold
self._allow_many_to_one = allow_many_to_one
self._matching_metric = matching_metric
self._matching_threshold = matching_threshold

def _match_instances(
self,
@@ -195,24 +202,32 @@ def _match_instances(
unmatched_instance_pair.reference_arr,
)
mm_pairs = _calc_matching_metric_of_overlapping_labels(
pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric
pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric
)

# Loop through matched instances to compute PQ components
for matching_score, (ref_label, pred_label) in mm_pairs:
if (
labelmap.contains_or(pred_label, ref_label)
and not self.allow_many_to_one
and not self._allow_many_to_one
):
continue # -> doesnt make speed difference
if self.matching_metric.score_beats_threshold(
matching_score, self.matching_threshold
if self._matching_metric.score_beats_threshold(
matching_score, self._matching_threshold
):
# Match found, increment true positive count and collect IoU and Dice values
labelmap.add_labelmap_entry(pred_label, ref_label)
# map label ref_idx to pred_idx
return labelmap

@classmethod
def _yaml_repr(cls, node) -> dict:
return {
"matching_metric": node._matching_metric,
"matching_threshold": node._matching_threshold,
"allow_many_to_one": node._allow_many_to_one,
}


class MaximizeMergeMatching(InstanceMatchingAlgorithm):
"""
@@ -241,8 +256,8 @@ def __init__(
Raises:
AssertionError: If the specified IoU threshold is not within the valid range.
"""
self.matching_metric = matching_metric
self.matching_threshold = matching_threshold
self._matching_metric = matching_metric
self._matching_threshold = matching_threshold

def _match_instances(
self,
@@ -274,7 +289,7 @@ def _match_instances(
prediction_arr=pred_arr,
reference_arr=ref_arr,
ref_labels=ref_labels,
matching_metric=self.matching_metric,
matching_metric=self._matching_metric,
)

# Loop through matched instances to compute PQ components
@@ -290,8 +305,8 @@ def _match_instances(
if new_score > score_ref[ref_label]:
labelmap.add_labelmap_entry(pred_label, ref_label)
score_ref[ref_label] = new_score
elif self.matching_metric.score_beats_threshold(
matching_score, self.matching_threshold
elif self._matching_metric.score_beats_threshold(
matching_score, self._matching_threshold
):
# Match found, increment true positive count and collect IoU and Dice values
labelmap.add_labelmap_entry(pred_label, ref_label)
@@ -307,14 +322,21 @@ def new_combination_score(
unmatched_instance_pair: UnmatchedInstancePair,
):
pred_labels.append(new_pred_label)
score = self.matching_metric(
score = self._matching_metric(
unmatched_instance_pair.reference_arr,
prediction_arr=unmatched_instance_pair.prediction_arr,
ref_instance_idx=ref_label,
pred_instance_idx=pred_labels,
)
return score

@classmethod
def _yaml_repr(cls, node) -> dict:
return {
"matching_metric": node._matching_metric,
"matching_threshold": node._matching_threshold,
}


class MatchUntilConvergenceMatching(InstanceMatchingAlgorithm):
# Match like the naive matcher (so each to their best reference) and then again and again until no overlapping labels are left
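The `_yaml_repr` classmethod pattern added above is what lets matcher objects round-trip through the YAML configs: it returns the constructor keyword arguments needed to rebuild the object. A minimal sketch of a user-defined matcher following the same pattern (the class, its arguments, and the abbreviated `_match_instances` signature are illustrative assumptions, not part of this commit):

```python
from panoptica.instance_matcher import InstanceMatchingAlgorithm
from panoptica.metrics import Metric


class GreedyIoUMatching(InstanceMatchingAlgorithm):
    """Hypothetical matcher illustrating the SupportsConfig/_yaml_repr pattern."""

    def __init__(self, matching_threshold: float = 0.5):
        self._matching_metric = Metric.IOU
        self._matching_threshold = matching_threshold

    def _match_instances(self, unmatched_instance_pair, **kwargs):
        # Matching logic omitted; signature abbreviated for illustration.
        raise NotImplementedError

    @classmethod
    def _yaml_repr(cls, node) -> dict:
        # Return the constructor kwargs so the object can be rebuilt from YAML.
        return {"matching_threshold": node._matching_threshold}
```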
9 changes: 4 additions & 5 deletions panoptica/metrics/relative_volume_difference.py
@@ -55,16 +55,15 @@ def _compute_relative_volume_difference(
prediction (np.ndarray): Prediction binary mask.
Returns:
float: Relative volume Error between the two binary masks. A value between 0 and 1, where higher values
indicate better overlap and similarity between masks.
float: Relative volume Error between the two binary masks. A value of zero means perfect volume match, while >0 means oversegmentation and <0 undersegmentation.
"""
reference_mask = np.sum(reference)
prediction_mask = np.sum(prediction)
reference_mask = float(np.sum(reference))
prediction_mask = float(np.sum(prediction))

# Handle division by zero
if reference_mask == 0 and prediction_mask == 0:
return 0.0

# Calculate Dice coefficient
rvd = float(prediction_mask - reference_mask) / reference_mask
rvd = (prediction_mask - reference_mask) / reference_mask
return rvd
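
To make the sign convention concrete, a small standalone computation following the formula above (this re-implements the arithmetic for illustration and does not call panoptica's internal function):

```python
import numpy as np

# Reference covers 100 voxels, prediction covers 110 voxels.
reference = np.zeros((10, 10, 10), dtype=np.uint8)
reference[:4, :5, :5] = 1    # 4 * 5 * 5 = 100 voxels
prediction = reference.copy()
prediction[4, :2, :5] = 1    # 10 extra voxels -> 110 total

rvd = (float(prediction.sum()) - float(reference.sum())) / float(reference.sum())
print(rvd)  # 0.1, i.e. 10% oversegmentation; a negative value means undersegmentation
```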