Skip to content

Commit

Permalink
Merge pull request #41 from BrainLesion/speedup_matcher
Browse files Browse the repository at this point in the history
Speedup instance matcher and evaluator
  • Loading branch information
Hendrik-code authored Nov 15, 2023
2 parents a5366ba + 87ab267 commit 5df2885
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 68 deletions.
66 changes: 66 additions & 0 deletions panoptica/_functionals.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,67 @@
import numpy as np
from panoptica.metrics import _compute_instance_iou
from panoptica.utils.constants import CCABackend
from panoptica.utils.numpy_utils import _get_bbox_nd
from multiprocessing import Pool


def _calc_overlapping_labels(
prediction_arr: np.ndarray,
reference_arr: np.ndarray,
ref_labels: tuple[int, ...],
) -> list[tuple[int, int]]:
"""Calculates the pairs of labels that are overlapping in at least one voxel (fast)
Args:
prediction_arr (np.ndarray): Numpy array containing the prediction labels.
reference_arr (np.ndarray): Numpy array containing the reference labels.
ref_labels (list[int]): List of unique reference labels.
Returns:
_type_: _description_
"""
overlap_arr = prediction_arr.astype(np.uint32)
max_ref = max(ref_labels) + 1
overlap_arr = (overlap_arr * max_ref) + reference_arr
overlap_arr[reference_arr == 0] = 0
# overlapping_indices = [(i % (max_ref), i // (max_ref)) for i in np.unique(overlap_arr) if i > max_ref]
# instance_pairs = [(reference_arr, prediction_arr, i, j) for i, j in overlapping_indices]

# (ref, pred)
return [(i % (max_ref), i // (max_ref)) for i in np.unique(overlap_arr) if i > max_ref]


def _calc_iou_of_overlapping_labels(
    prediction_arr: np.ndarray, reference_arr: np.ndarray, ref_labels: tuple[int, ...], pred_labels: tuple[int, ...]
) -> list[tuple[float, tuple[int, int]]]:
    """Calculates the IOU for all overlapping labels (fast!)

    Only label pairs that share at least one voxel are evaluated; the per-pair IoU
    computation (``_compute_instance_iou``) is distributed over a process pool.

    Args:
        prediction_arr (np.ndarray): Numpy array containing the prediction labels.
        reference_arr (np.ndarray): Numpy array containing the reference labels.
        ref_labels (tuple[int, ...]): Unique reference labels.
        pred_labels (tuple[int, ...]): Unique prediction labels. Not used by this
            function; kept so the signature mirrors ``_calc_iou_matrix``.

    Returns:
        list[tuple[float, tuple[int, int]]]: List of pairs in style: (iou, (ref_label, pred_label)),
        sorted by descending IoU.
    """
    label_pairs = _calc_overlapping_labels(
        prediction_arr=prediction_arr,
        reference_arr=reference_arr,
        ref_labels=ref_labels,
    )
    if not label_pairs:
        # Nothing overlaps, so do not pay for starting worker processes.
        return []

    # One argument tuple per overlapping pair: (reference, prediction, ref_label, pred_label)
    instance_pairs = [(reference_arr, prediction_arr, ref_label, pred_label) for ref_label, pred_label in label_pairs]
    with Pool() as pool:
        iou_values = pool.starmap(_compute_instance_iou, instance_pairs)

    # starmap preserves input order, so results line up with label_pairs.
    iou_pairs = [(iou, (ref_label, pred_label)) for iou, (ref_label, pred_label) in zip(iou_values, label_pairs)]
    return sorted(iou_pairs, key=lambda x: x[0], reverse=True)


def _calc_iou_matrix(prediction_arr: np.ndarray, reference_arr: np.ndarray, ref_labels: tuple[int, ...], pred_labels: tuple[int, ...]):
"""
Calculate the Intersection over Union (IoU) matrix between reference and prediction arrays.
Expand Down Expand Up @@ -92,3 +150,11 @@ def _connected_components(
raise NotImplementedError(cca_backend)

return cc_arr.astype(array.dtype), n_instances


def _get_paired_crop(prediction_arr: np.ndarray, reference_arr: np.ndarray, px_pad: int = 2):
    """Builds one crop (tuple of slices) that encloses the bounding boxes of both arrays.

    Args:
        prediction_arr (np.ndarray): Numpy array containing the prediction labels.
        reference_arr (np.ndarray): Numpy array containing the reference labels.
            Must have the same shape as ``prediction_arr``.
        px_pad (int, optional): Forwarded to ``_get_bbox_nd`` as ``px_dist``
            (presumably the padding around the bounding box -- confirm in
            panoptica.utils.numpy_utils). Defaults to 2.

    Returns:
        tuple[slice, ...]: Per entry of the prediction bounding box, a slice running
        from the smaller ``start`` to the larger ``stop`` of the two bounding boxes.
    """
    assert prediction_arr.shape == reference_arr.shape
    pred_box = _get_bbox_nd(prediction_arr, px_dist=px_pad)
    ref_box = _get_bbox_nd(reference_arr, px_dist=px_pad)

    crop = []
    for axis in range(len(pred_box)):
        pred_slice = pred_box[axis]
        ref_slice = ref_box[axis]
        lower = min(pred_slice.start, ref_slice.start)
        upper = max(pred_slice.stop, ref_slice.stop)
        crop.append(slice(lower, upper))
    return tuple(crop)
3 changes: 3 additions & 0 deletions panoptica/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def panoptic_evaluate(
print("-- Input was Panoptic Result, will just return")
return processing_pair, debug_data

# Crops away unnecessary space of zeroes
processing_pair.crop_data()

if isinstance(processing_pair, SemanticPair):
assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator"
print("-- Got SemanticPair, will approximate instances")
Expand Down
16 changes: 9 additions & 7 deletions panoptica/instance_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> Unmatc
# Call algorithm
instance_pair = self._approximate_instances(semantic_pair, **kwargs)
# Check validity
pred_labels, ref_labels = instance_pair.pred_labels, instance_pair.ref_labels
pred_labels, ref_labels = instance_pair._pred_labels, instance_pair._ref_labels
pred_label_range = (np.min(pred_labels), np.max(pred_labels)) if len(pred_labels) > 0 else (0, 0)
ref_label_range = (np.min(ref_labels), np.max(ref_labels)) if len(ref_labels) > 0 else (0, 0)
#
Expand All @@ -75,8 +75,8 @@ def approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> Unmatc
# Set dtype to smallest fitting uint
max_value = max(np.max(pred_label_range[1]), np.max(ref_label_range[1]))
dtype = _get_smallest_fitting_uint(max_value)
instance_pair.prediction_arr.astype(dtype)
instance_pair.reference_arr.astype(dtype)
instance_pair._prediction_arr.astype(dtype)
instance_pair._reference_arr.astype(dtype)
return instance_pair


Expand Down Expand Up @@ -124,13 +124,15 @@ def _approximate_instances(self, semantic_pair: SemanticPair, **kwargs) -> Unmat
cca_backend = CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy
assert cca_backend is not None

empty_prediction = len(semantic_pair.pred_labels) == 0
empty_reference = len(semantic_pair.ref_labels) == 0
empty_prediction = len(semantic_pair._pred_labels) == 0
empty_reference = len(semantic_pair._ref_labels) == 0
prediction_arr, n_prediction_instance = (
_connected_components(semantic_pair.prediction_arr, cca_backend) if not empty_prediction else (semantic_pair.prediction_arr, 0)
_connected_components(semantic_pair._prediction_arr, cca_backend)
if not empty_prediction
else (semantic_pair._prediction_arr, 0)
)
reference_arr, n_reference_instance = (
_connected_components(semantic_pair.reference_arr, cca_backend) if not empty_reference else (semantic_pair.reference_arr, 0)
_connected_components(semantic_pair._reference_arr, cca_backend) if not empty_reference else (semantic_pair._reference_arr, 0)
)
return UnmatchedInstancePair(
prediction_arr=prediction_arr,
Expand Down
64 changes: 43 additions & 21 deletions panoptica/instance_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from panoptica.metrics import _compute_iou, _compute_dice_coefficient, _average_symmetric_surface_distance
from panoptica.timing import measure_time
import numpy as np
import gc
from multiprocessing import Pool


@measure_time
Expand All @@ -26,29 +28,49 @@ def evaluate_matched_instance(matched_instance_pair: MatchedInstancePair, iou_th
# Initialize variables for True Positives (tp)
tp, dice_list, iou_list, assd_list = 0, [], [], []

reference_arr, prediction_arr = matched_instance_pair.reference_arr, matched_instance_pair.prediction_arr
ref_labels = matched_instance_pair.ref_labels
reference_arr, prediction_arr = matched_instance_pair._reference_arr, matched_instance_pair._prediction_arr
ref_labels = matched_instance_pair._ref_labels

# Use concurrent.futures.ThreadPoolExecutor for parallelization
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(
_evaluate_instance,
reference_arr,
prediction_arr,
ref_idx,
iou_threshold,
)
for ref_idx in ref_labels
]
# instance_pairs = _calc_overlapping_labels(
# prediction_arr=prediction_arr,
# reference_arr=reference_arr,
# ref_labels=ref_labels,
# )
# instance_pairs = [(ra, pa, rl, iou_threshold) for (ra, pa, rl, pl) in instance_pairs]

instance_pairs = [(reference_arr, prediction_arr, ref_idx, iou_threshold) for ref_idx in ref_labels]
with Pool() as pool:
metric_values = pool.starmap(_evaluate_instance, instance_pairs)

for future in concurrent.futures.as_completed(futures):
tp_i, dice_i, iou_i, assd_i = future.result()
tp += tp_i
if dice_i is not None and iou_i is not None and assd_i is not None:
dice_list.append(dice_i)
iou_list.append(iou_i)
assd_list.append(assd_i)
for tp_i, dice_i, iou_i, assd_i in metric_values:
tp += tp_i
if dice_i is not None and iou_i is not None and assd_i is not None:
dice_list.append(dice_i)
iou_list.append(iou_i)
assd_list.append(assd_i)

# Use concurrent.futures.ThreadPoolExecutor for parallelization
# with concurrent.futures.ThreadPoolExecutor() as executor:
# futures = [
# executor.submit(
# _evaluate_instance,
# reference_arr,
# prediction_arr,
# ref_idx,
# iou_threshold,
# )
# for ref_idx in ref_labels
# ]
#
# for future in concurrent.futures.as_completed(futures):
# tp_i, dice_i, iou_i, assd_i = future.result()
# tp += tp_i
# if dice_i is not None and iou_i is not None and assd_i is not None:
# dice_list.append(dice_i)
# iou_list.append(iou_i)
# assd_list.append(assd_i)
# del future
# gc.collect()
# Create and return the PanopticaResult object with computed metrics
return PanopticaResult(
num_ref_instances=matched_instance_pair.n_reference_instance,
Expand Down
38 changes: 17 additions & 21 deletions panoptica/instance_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from panoptica._functionals import _calc_iou_matrix, _map_labels
from panoptica._functionals import _calc_iou_matrix, _map_labels, _calc_iou_of_overlapping_labels
from panoptica.utils.datatypes import (
InstanceLabelMap,
MatchedInstancePair,
Expand Down Expand Up @@ -68,6 +68,9 @@ def match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwar
return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap)


from multiprocessing import Pool


class NaiveThresholdMatching(InstanceMatchingAlgorithm):
"""
Instance matching algorithm that performs one-to-one matching based on IoU values.
Expand Down Expand Up @@ -114,29 +117,22 @@ def _match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwa
Returns:
Instance_Label_Map: The result of the instance matching.
"""
ref_labels = unmatched_instance_pair.ref_labels
pred_labels = unmatched_instance_pair.pred_labels
# TODO bounding boxes first, then only calc iou over bboxes collisions
iou_matrix = _calc_iou_matrix(
unmatched_instance_pair.prediction_arr.flatten(),
unmatched_instance_pair.reference_arr.flatten(),
ref_labels,
pred_labels,
)
# Use linear_sum_assignment to find the best matches
ref_indices, pred_indices = linear_sum_assignment(-iou_matrix)
ref_labels = unmatched_instance_pair._ref_labels
pred_labels = unmatched_instance_pair._pred_labels

# Initialize variables for True Positives (tp) and False Positives (fp)
labelmap = InstanceLabelMap()

pred_arr, ref_arr = unmatched_instance_pair._prediction_arr, unmatched_instance_pair._reference_arr
iou_pairs = _calc_iou_of_overlapping_labels(pred_arr, ref_arr, ref_labels, pred_labels)

# Loop through matched instances to compute PQ components
for ref_idx, pred_idx in zip(ref_indices, pred_indices):
if labelmap.contains_or(pred_labels[pred_idx], ref_labels[ref_idx]) and not self.allow_many_to_one:
for iou, (ref_label, pred_label) in iou_pairs:
if labelmap.contains_or(pred_label, ref_label) and not False:
continue # -> doesnt make speed difference
iou = iou_matrix[ref_idx][pred_idx]
if iou >= self.iou_threshold:
if iou >= 0.5:
# Match found, increment true positive count and collect IoU and Dice values
labelmap.add_labelmap_entry(pred_labels[pred_idx], ref_labels[ref_idx])
labelmap.add_labelmap_entry(pred_label, ref_label)
# map label ref_idx to pred_idx
return labelmap

Expand All @@ -157,10 +153,10 @@ def map_instance_labels(processing_pair: UnmatchedInstancePair, labelmap: Instan
>>> labelmap = [([1, 2], [3, 4]), ([5], [6])]
>>> result = map_instance_labels(unmatched_instance_pair, labelmap)
"""
prediction_arr = processing_pair.prediction_arr
prediction_arr = processing_pair._prediction_arr

ref_labels = processing_pair.ref_labels
pred_labels = processing_pair.pred_labels
ref_labels = processing_pair._ref_labels
pred_labels = processing_pair._pred_labels

ref_matched_labels = []
label_counter = max(ref_labels) + 1
Expand All @@ -185,7 +181,7 @@ def map_instance_labels(processing_pair: UnmatchedInstancePair, labelmap: Instan
# Build a MatchedInstancePair out of the newly derived data
matched_instance_pair = MatchedInstancePair(
prediction_arr=prediction_arr_relabeled,
reference_arr=processing_pair.reference_arr,
reference_arr=processing_pair._reference_arr,
missed_reference_labels=missed_ref_labels,
missed_prediction_labels=missed_pred_labels,
n_prediction_instance=processing_pair.n_prediction_instance,
Expand Down
Loading

0 comments on commit 5df2885

Please sign in to comment.