Merge pull request #7 from BrainLesion/bring_back_the_instances
Bring back the instances
neuronflow authored Nov 6, 2023
2 parents bf2b89c + fe3df9f commit 4e09b38
Showing 6 changed files with 157 additions and 31 deletions.
25 changes: 2 additions & 23 deletions examples/example_cfos_3d.py
@@ -1,7 +1,6 @@
from auxiliary.nifti.io import read_nifti

from panoptica.panoptica_evaluation import panoptica_evaluation
from panoptica.semantic import SemanticSegmentationEvaluator
from panoptica import CCABackend, SemanticSegmentationEvaluator

pred_masks = read_nifti(
input_nifti_path="/home/florian/flow/cfos_analysis/data/ablation/2021-11-25_23-50-56_2021-10-25_19-38-31_tr_dice_bce_11/patchvolume_695_2.nii.gz"
@@ -10,27 +9,7 @@
input_nifti_path="/home/florian/flow/cfos_analysis/data/reference/patchvolume_695_2/patchvolume_695_2_binary.nii.gz",
)

# Call panoptica_quality to obtain the result
result = panoptica_evaluation(
ref_mask=ref_masks,
pred_mask=pred_masks,
iou_threshold=0.5,
modus="cc",
)

# Print the metrics
print("Panoptic Quality (PQ):", result.pq)
print("Segmentation Quality (SQ):", result.sq)
print("Recognition Quality (RQ):", result.rq)
print("True Positives (tp):", result.tp)
print("False Positives (fp):", result.fp)
print("False Negatives (fn):", result.fn)
print("instance_dice", result.instance_dice)
print("number of instances in prediction:", result.num_pred_instances)
print("number of instances in reference:", result.num_ref_instances)


eva = SemanticSegmentationEvaluator(cca_backend="cc3d")
eva = SemanticSegmentationEvaluator(cca_backend=CCABackend.cc3d)
res = eva.evaluate(
reference_mask=ref_masks,
prediction_mask=pred_masks,
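The diff view is truncated at this point, so the closing of the new evaluate() call and the handling of its result are not visible in this commit view. Below is a minimal sketch of a plausible continuation, assuming that evaluate() accepts the iou_threshold keyword visible in the Evaluator signature further down and that the returned result still exposes the fields printed by the removed code (pq, tp, num_ref_instances); the file paths are placeholders.

from auxiliary.nifti.io import read_nifti

from panoptica import CCABackend, SemanticSegmentationEvaluator

# Placeholder paths; the real example reads the cfos NIfTI volumes shown above.
pred_masks = read_nifti(input_nifti_path="prediction.nii.gz")
ref_masks = read_nifti(input_nifti_path="reference.nii.gz")

eva = SemanticSegmentationEvaluator(cca_backend=CCABackend.cc3d)
res = eva.evaluate(
    reference_mask=ref_masks,
    prediction_mask=pred_masks,
    iou_threshold=0.5,  # value used by the removed panoptica_evaluation call
)

# Field names taken from the removed print block; whether the reworked
# PanopticaResult still exposes all of them is an assumption.
print("Panoptic Quality (PQ):", res.pq)
print("True Positives (tp):", res.tp)
print("Number of reference instances:", res.num_ref_instances)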
10 changes: 7 additions & 3 deletions panoptica/__init__.py
@@ -1,3 +1,7 @@
from panoptica.instance.instance_evaluator import InstanceSegmentationEvaluator
from panoptica.semantic.connected_component_backends import CCABackend
from panoptica.semantic.semantic_evaluator import SemanticSegmentationEvaluator
from panoptica.instance_evaluation.instance_evaluator import (
InstanceSegmentationEvaluator,
)
from panoptica.semantic_evaluation.connected_component_backends import CCABackend
from panoptica.semantic_evaluation.semantic_evaluator import (
SemanticSegmentationEvaluator,
)
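The reorganized __init__.py re-exports both evaluators and the connected-component backend enum from the renamed instance_evaluation and semantic_evaluation subpackages, so downstream code can import everything from the package root. A minimal sketch of the resulting imports; the symbols come from this diff, and the cca_backend keyword is taken from the example above.

from panoptica import (
    CCABackend,
    InstanceSegmentationEvaluator,
    SemanticSegmentationEvaluator,
)

# Both evaluators are reachable without referencing the renamed subpackages.
instance_eva = InstanceSegmentationEvaluator()
semantic_eva = SemanticSegmentationEvaluator(cca_backend=CCABackend.cc3d)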
4 changes: 2 additions & 2 deletions panoptica/evaluator.py
@@ -2,7 +2,7 @@

import numpy as np

from .result import PanopticaResult
from panoptica.result import PanopticaResult


class Evaluator(ABC):
@@ -21,7 +21,7 @@ def evaluate(
reference_mask: np.ndarray,
prediction_mask: np.ndarray,
iou_threshold: float,
)-> PanopticaResult:
) -> PanopticaResult:
"""
Evaluate the instance segmentation results based on the reference and prediction masks.
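evaluator.py defines the abstract Evaluator base class whose evaluate(reference_mask, prediction_mask, iou_threshold) -> PanopticaResult signature both concrete evaluators implement. The sketch below shows a toy conforming subclass; it is not part of this commit and assumes that evaluate is the only method a subclass must override, with the PanopticaResult keyword arguments copied from instance_evaluator.py further down.

import numpy as np

from panoptica.evaluator import Evaluator
from panoptica.result import PanopticaResult


class CountingEvaluator(Evaluator):
    """Illustrative subclass; only counts instances, no matching is performed."""

    def evaluate(
        self,
        reference_mask: np.ndarray,
        prediction_mask: np.ndarray,
        iou_threshold: float,
    ) -> PanopticaResult:
        num_ref = len(np.unique(reference_mask[reference_mask != 0]))
        num_pred = len(np.unique(prediction_mask[prediction_mask != 0]))
        return PanopticaResult(
            num_ref_instances=num_ref,
            num_pred_instances=num_pred,
            tp=0,
            dice_list=[],
            iou_list=[],
        )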
145 changes: 145 additions & 0 deletions panoptica/instance_evaluation/instance_evaluator.py
@@ -0,0 +1,145 @@
from __future__ import annotations

import concurrent.futures
import warnings
from typing import Tuple

import numpy as np

from panoptica.evaluator import Evaluator
from panoptica.result import PanopticaResult
from panoptica.timing import measure_time


class InstanceSegmentationEvaluator(Evaluator):
"""
Evaluator for instance segmentation results.
This class extends the Evaluator class and provides methods for evaluating instance segmentation masks
using metrics such as Intersection over Union (IoU) and Dice coefficient.
Methods:
evaluate(reference_mask, prediction_mask, iou_threshold): Evaluate the instance segmentation masks.
_unique_without_zeros(arr): Get unique non-zero values from a NumPy array.
"""

def __init__(self):
# TODO consider initializing evaluator with metrics it should compute
pass

@measure_time
def evaluate(
self,
reference_mask: np.ndarray,
prediction_mask: np.ndarray,
iou_threshold: float,
) -> PanopticaResult:
"""
Evaluate the intersection over union (IoU) and Dice coefficient for instance segmentation masks.
Args:
reference_mask (np.ndarray): The reference instance segmentation mask.
prediction_mask (np.ndarray): The predicted instance segmentation mask.
iou_threshold (float): The IoU threshold for considering a match.
Returns:
PanopticaResult: A named tuple containing evaluation results.
"""
ref_labels = reference_mask
ref_nonzero_unique_labels = self._unique_without_zeros(arr=ref_labels)
num_ref_instances = len(ref_nonzero_unique_labels)

pred_labels = prediction_mask
pred_nonzero_unique_labels = self._unique_without_zeros(arr=pred_labels)
num_pred_instances = len(pred_nonzero_unique_labels)

self._handle_edge_cases(
num_ref_instances=num_ref_instances,
num_pred_instances=num_pred_instances,
)

# Initialize variables for True Positives (tp)
tp, dice_list, iou_list = 0, [], []

# Use concurrent.futures.ThreadPoolExecutor for parallelization
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(
self._evaluate_instance,
ref_labels,
pred_labels,
ref_idx,
iou_threshold,
)
for ref_idx in ref_nonzero_unique_labels
]

for future in concurrent.futures.as_completed(futures):
tp_i, dice_i, iou_i = future.result()
tp += tp_i
if dice_i is not None:
dice_list.append(dice_i)
if iou_i is not None:
iou_list.append(iou_i)

# Create and return the PanopticaResult object with computed metrics
return PanopticaResult(
num_ref_instances=num_ref_instances,
num_pred_instances=num_pred_instances,
tp=tp,
dice_list=dice_list,
iou_list=iou_list,
)

def _evaluate_instance(
self,
ref_labels: np.ndarray,
pred_labels: np.ndarray,
ref_idx: int,
iou_threshold: float,
) -> Tuple[int, float, float]:
"""
Evaluate a single instance.
Args:
ref_labels (np.ndarray): Reference instance segmentation mask.
pred_labels (np.ndarray): Predicted instance segmentation mask.
ref_idx (int): The label of the current instance.
iou_threshold (float): The IoU threshold for considering a match.
Returns:
Tuple[int, float, float]: Tuple containing True Positives (int), Dice coefficient (float), and IoU (float).
"""
iou = self._compute_iou(
reference=ref_labels == ref_idx,
prediction=pred_labels == ref_idx,
)
if iou > iou_threshold:
tp = 1
dice = self._compute_dice_coefficient(
reference=ref_labels == ref_idx,
prediction=pred_labels == ref_idx,
)
else:
tp = 0
dice = None

return tp, dice, iou

def _unique_without_zeros(self, arr: np.ndarray) -> np.ndarray:
"""
Get unique non-zero values from a NumPy array.
Parameters:
arr (np.ndarray): Input NumPy array.
Returns:
np.ndarray: Unique non-zero values from the input array.
Issues a warning if negative values are present.
"""
if np.any(arr < 0):
warnings.warn("Negative values are present in the input array.")

return np.unique(arr[arr != 0])
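A self-contained usage sketch for the new InstanceSegmentationEvaluator follows; the tiny synthetic label masks only stand in for real instance segmentations, while the evaluate() arguments and the result fields read at the end come from the file above. Note that instances are matched by shared label id, as in _evaluate_instance.

import numpy as np

from panoptica import InstanceSegmentationEvaluator

# Instance masks: 0 = background, positive integers = instance labels.
reference = np.zeros((8, 8), dtype=np.int32)
reference[1:4, 1:4] = 1   # instance 1
reference[5:7, 5:7] = 2   # instance 2

prediction = np.zeros((8, 8), dtype=np.int32)
prediction[1:4, 1:4] = 1  # exact overlap with reference instance 1
prediction[5:6, 5:7] = 2  # partial overlap with reference instance 2

eva = InstanceSegmentationEvaluator()
res = eva.evaluate(
    reference_mask=reference,
    prediction_mask=prediction,
    iou_threshold=0.5,
)

# num_ref_instances, num_pred_instances and tp come straight from the
# PanopticaResult construction above; no further fields are assumed here.
print(res.num_ref_instances, res.num_pred_instances, res.tp)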
@@ -10,9 +10,7 @@

from panoptica.evaluator import Evaluator
from panoptica.result import PanopticaResult
from panoptica.semantic.connected_component_backends import CCABackend

# from panoptica.semantic.connected_component_backends import CCABackend
from panoptica.semantic_evaluation.connected_component_backends import CCABackend
from panoptica.timing import measure_time


