Skip to content

Commit

Permalink
Merge pull request #33 from BrainLesion/logging_first
Browse files Browse the repository at this point in the history
Logging first
  • Loading branch information
Hendrik-code authored Nov 15, 2023
2 parents bdd0d3a + 6dbaf7d commit 6d805f4
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 19 deletions.
3 changes: 1 addition & 2 deletions examples/example_spine_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

pred_masks = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

sample = MatchedInstancePair(
prediction_arr=pred_masks, reference_arr=ref_masks)
sample = MatchedInstancePair(prediction_arr=pred_masks, reference_arr=ref_masks)


evaluator = Panoptic_Evaluator(
Expand Down
8 changes: 2 additions & 6 deletions examples/example_spine_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@

directory = turbopath(__file__).parent

ref_masks = read_nifti(
directory + "/spine_seg/semantic/ref.nii.gz"
)
pred_masks = read_nifti(
directory + "/spine_seg/semantic/pred.nii.gz"
)
ref_masks = read_nifti(directory + "/spine_seg/semantic/ref.nii.gz")
pred_masks = read_nifti(directory + "/spine_seg/semantic/pred.nii.gz")

sample = SemanticPair(pred_masks, ref_masks)

Expand Down
19 changes: 17 additions & 2 deletions panoptica/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(
expected_input: Type[SemanticPair] | Type[UnmatchedInstancePair] | Type[MatchedInstancePair] = MatchedInstancePair,
instance_approximator: InstanceApproximator | None = None,
instance_matcher: InstanceMatchingAlgorithm | None = None,
log_times: bool = False,
verbose: bool = False,
iou_threshold: float = 0.5,
) -> None:
"""Creates a Panoptic_Evaluator, that saves some parameters to be used for all subsequent evaluations
Expand All @@ -36,22 +38,30 @@ def __init__(
self.__instance_approximator = instance_approximator
self.__instance_matcher = instance_matcher
self.__iou_threshold = iou_threshold
self.__log_times = log_times
self.__verbose = verbose

@measure_time
def evaluate(self, processing_pair: _ProcessingPair) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
def evaluate(
self, processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult
) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
assert type(processing_pair) == self.__expected_input, f"input not of expected type {self.__expected_input}"
return panoptic_evaluate(
processing_pair=processing_pair,
instance_approximator=self.__instance_approximator,
instance_matcher=self.__instance_matcher,
iou_threshold=self.__iou_threshold,
log_times=self.__log_times,
verbose=self.__verbose,
)


def panoptic_evaluate(
processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
instance_approximator: InstanceApproximator | None = None,
instance_matcher: InstanceMatchingAlgorithm | None = None,
log_times: bool = False,
verbose: bool = False,
iou_threshold: float = 0.5,
**kwargs,
) -> tuple[PanopticaResult, dict[str, _ProcessingPair]]:
Expand Down Expand Up @@ -82,13 +92,16 @@ def panoptic_evaluate(
>>> panoptic_evaluate(SemanticPair(...), instance_approximator=InstanceApproximator(), iou_threshold=0.6)
(PanopticaResult(...), {'UnmatchedInstanceMap': _ProcessingPair(...), 'MatchedInstanceMap': _ProcessingPair(...)})
"""
print("Panoptic: Start Evaluation")
debug_data: dict[str, _ProcessingPair] = {}
# First Phase: Instance Approximation
if isinstance(processing_pair, PanopticaResult):
print("-- Input was Panoptic Result, will just return")
return processing_pair, debug_data

if isinstance(processing_pair, SemanticPair):
assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator"
print("-- Got SemanticPair, will approximate instances")
processing_pair = instance_approximator.approximate_instances(processing_pair)
debug_data["UnmatchedInstanceMap"] = processing_pair.copy()

Expand All @@ -97,6 +110,7 @@ def panoptic_evaluate(
processing_pair = _handle_zero_instances_cases(processing_pair)

if isinstance(processing_pair, UnmatchedInstancePair):
print("-- Got UnmatchedInstancePair, will match instances")
assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
processing_pair = instance_matcher.match_instances(processing_pair)

Expand All @@ -107,6 +121,7 @@ def panoptic_evaluate(
processing_pair = _handle_zero_instances_cases(processing_pair)

if isinstance(processing_pair, MatchedInstancePair):
print("-- Got MatchedInstancePair, will evaluate instances")
processing_pair = evaluate_matched_instance(processing_pair, iou_threshold=iou_threshold)

if isinstance(processing_pair, PanopticaResult):
Expand All @@ -131,7 +146,7 @@ def _handle_zero_instances_cases(
n_reference_instance = processing_pair.n_reference_instance
n_prediction_instance = processing_pair.n_prediction_instance
# Handle cases where either the reference or the prediction is empty
if n_prediction_instance == 0 or n_reference_instance == 0:
if n_prediction_instance == 0 and n_reference_instance == 0:
# Both references and predictions are empty, perfect match
return PanopticaResult(
num_ref_instances=0,
Expand Down
2 changes: 2 additions & 0 deletions panoptica/instance_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from panoptica.utils.datatypes import SemanticPair, UnmatchedInstancePair, MatchedInstancePair
from panoptica._functionals import _connected_components, CCABackend
from panoptica.utils.numpy_utils import _get_smallest_fitting_uint
from panoptica.timing import measure_time
import numpy as np


Expand Down Expand Up @@ -47,6 +48,7 @@ def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> Unmat
"""
pass

@measure_time
def approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> UnmatchedInstancePair | MatchedInstancePair:
"""
Perform instance approximation on the given SemanticPair.
Expand Down
2 changes: 2 additions & 0 deletions panoptica/instance_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from panoptica.utils.datatypes import MatchedInstancePair
from panoptica.result import PanopticaResult
from panoptica.metrics import _compute_iou, _compute_dice_coefficient, _average_symmetric_surface_distance
from panoptica.timing import measure_time
import numpy as np


@measure_time
def evaluate_matched_instance(matched_instance_pair: MatchedInstancePair, iou_threshold: float, **kwargs) -> PanopticaResult:
"""
Map instance labels based on the provided labelmap and create a MatchedInstancePair.
Expand Down
3 changes: 3 additions & 0 deletions panoptica/instance_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MatchedInstancePair,
UnmatchedInstancePair,
)
from panoptica.timing import measure_time
from scipy.optimize import linear_sum_assignment


Expand Down Expand Up @@ -50,6 +51,7 @@ def _match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwa
"""
pass

@measure_time
def match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs) -> MatchedInstancePair:
"""
Perform instance matching on the given UnmatchedInstancePair.
Expand Down Expand Up @@ -114,6 +116,7 @@ def _match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwa
"""
ref_labels = unmatched_instance_pair.ref_labels
pred_labels = unmatched_instance_pair.pred_labels
# TODO bounding boxes first, then only calc iou over bboxes collisions
iou_matrix = _calc_iou_matrix(
unmatched_instance_pair.prediction_arr.flatten(),
unmatched_instance_pair.reference_arr.flatten(),
Expand Down
8 changes: 4 additions & 4 deletions panoptica/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def rq(self) -> float:
float: Recognition Quality (RQ).
"""
if self.tp == 0:
return 0.0
return 0.0 if self.num_pred_instances + self.num_ref_instances > 0 else np.nan
return self.tp / (self.tp + 0.5 * self.fp + 0.5 * self.fn)

@property
Expand All @@ -146,7 +146,7 @@ def sq(self) -> float:
float: Segmentation Quality (SQ).
"""
if self.tp == 0:
return 0.0
return 0.0 if self.num_pred_instances + self.num_ref_instances > 0 else np.nan
return np.sum(self._iou_list) / self.tp

@property
Expand Down Expand Up @@ -178,7 +178,7 @@ def sq_dsc(self) -> float:
float: Average Dice coefficient.
"""
if self.tp == 0:
return 0.0
return 0.0 if self.num_pred_instances + self.num_ref_instances > 0 else np.nan
return np.sum(self._dice_list) / self.tp

@property
Expand Down Expand Up @@ -210,7 +210,7 @@ def instance_assd(self) -> float:
float: average symmetric surface distance.
"""
if self.tp == 0:
return 0.0
return np.nan if self.num_pred_instances + self.num_ref_instances == 0 else np.inf
return np.sum(self._assd_list) / self.tp

@property
Expand Down
2 changes: 1 addition & 1 deletion panoptica/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"{func.__name__} took {elapsed_time} seconds to execute.")
print(f"-- {func.__name__} took {elapsed_time} seconds to execute.")
return result

return wrapper
8 changes: 4 additions & 4 deletions panoptica/utils/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class _ProcessingPair(ABC):
prediction_arr: np.ndarray
reference_arr: np.ndarray
# unique labels without zero
ref_labels: tuple[int]
pred_labels: tuple[int]
ref_labels: tuple[int, ...]
pred_labels: tuple[int, ...]
n_dim: int

def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None:
Expand All @@ -35,8 +35,8 @@ def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype:
self.reference_arr = reference_arr
self.dtype = dtype
self.n_dim = reference_arr.ndim
self.ref_labels: tuple[int] = tuple(_unique_without_zeros(reference_arr)) # type:ignore
self.pred_labels: tuple[int] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore
self.ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore
self.pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore

# Make all variables read-only!
def __setattr__(self, attr, value):
Expand Down

0 comments on commit 6d805f4

Please sign in to comment.