Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maximize matcher #85

Merged
merged 8 commits on Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion panoptica/_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def _calc_matching_metric_of_overlapping_labels(
(i, (instance_pairs[idx][2], instance_pairs[idx][3]))
for idx, i in enumerate(mm_values)
]
mm_pairs = sorted(mm_pairs, key=lambda x: x[0], reverse=matching_metric.decreasing)
mm_pairs = sorted(
mm_pairs, key=lambda x: x[0], reverse=not matching_metric.decreasing
)

return mm_pairs

Expand Down
2 changes: 1 addition & 1 deletion panoptica/instance_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import numpy as np

from panoptica.metrics import (
Metrics,
_MatchingMetric,
)
from panoptica.panoptic_result import PanopticaResult
from panoptica.timing import measure_time
from panoptica.utils import EdgeCaseHandler
from panoptica.utils.processing_pair import MatchedInstancePair
from panoptica.metrics import Metrics


def evaluate_matched_instance(
Expand Down
96 changes: 86 additions & 10 deletions panoptica/instance_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def map_instance_labels(
pred_labelmap = labelmap.get_one_to_one_dictionary()
ref_matched_labels = list([r for r in ref_labels if r in pred_labelmap.values()])

n_matched_instances = len(ref_matched_labels)

# assign missed instances to next unused labels sequentially
missed_ref_labels = list([r for r in ref_labels if r not in ref_matched_labels])
missed_pred_labels = list([p for p in pred_labels if p not in pred_labelmap])
Expand All @@ -127,11 +125,6 @@ def map_instance_labels(
matched_instance_pair = MatchedInstancePair(
prediction_arr=prediction_arr_relabeled,
reference_arr=processing_pair._reference_arr,
missed_reference_labels=missed_ref_labels,
missed_prediction_labels=missed_pred_labels,
n_prediction_instance=processing_pair.n_prediction_instance,
n_reference_instance=processing_pair.n_reference_instance,
matched_instances=ref_matched_labels,
)
return matched_instance_pair

Expand Down Expand Up @@ -223,18 +216,101 @@ def _match_instances(

class MaximizeMergeMatching(InstanceMatchingAlgorithm):
    """
    Instance matching algorithm that performs many-to-one matching based on a matching metric.

    Multiple prediction instances may be merged onto a single reference instance, but a
    prediction is only merged in if the combined instance metric beats the score of the
    current match. A reference instance is only matched in the first place if at least a
    single prediction instance exceeds the matching threshold.

    Methods:
        _match_instances(self, unmatched_instance_pair: UnmatchedInstancePair, **kwargs) -> InstanceLabelMap:
            Perform many-to-one instance matching based on the configured metric.
    """

    def __init__(
        self,
        matching_metric: _MatchingMetric = Metrics.IOU,
        matching_threshold: float = 0.5,
    ) -> None:
        """
        Initialize the MaximizeMergeMatching instance.

        Args:
            matching_metric (_MatchingMetric): The metric to be used for matching.
            matching_threshold (float, optional): The metric threshold for matching
                instances. Defaults to 0.5. Not validated here; the valid range
                depends on the chosen metric.
        """
        self.matching_metric = matching_metric
        self.matching_threshold = matching_threshold

    def _match_instances(
        self,
        unmatched_instance_pair: UnmatchedInstancePair,
        **kwargs,
    ) -> InstanceLabelMap:
        """
        Perform many-to-one instance matching based on the configured matching metric.

        Args:
            unmatched_instance_pair (UnmatchedInstancePair): The unmatched instance pair to be matched.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            InstanceLabelMap: Mapping from prediction labels to reference labels.
        """
        ref_labels = unmatched_instance_pair._ref_labels

        labelmap = InstanceLabelMap()
        # Best score achieved so far for each matched reference label.
        score_ref: dict[int, float] = {}

        pred_arr, ref_arr = (
            unmatched_instance_pair._prediction_arr,
            unmatched_instance_pair._reference_arr,
        )
        # (score, (ref_label, pred_label)) pairs for all overlapping instances,
        # ordered best-first according to the matching metric.
        mm_pairs = _calc_matching_metric_of_overlapping_labels(
            pred_arr, ref_arr, ref_labels, matching_metric=self.matching_metric
        )

        for matching_score, (ref_label, pred_label) in mm_pairs:
            if labelmap.contains_pred(pred_label=pred_label):
                # Skip: this prediction label is already matched.
                continue
            if labelmap.contains_ref(ref_label):
                # Reference already matched: merge this prediction in only if the
                # combined score improves on the current best.
                pred_labels_ = labelmap.get_pred_labels_matched_to_ref(ref_label)
                new_score = self.new_combination_score(
                    pred_labels_, pred_label, ref_label, unmatched_instance_pair
                )
                if new_score > score_ref[ref_label]:
                    labelmap.add_labelmap_entry(pred_label, ref_label)
                    score_ref[ref_label] = new_score
            elif self.matching_metric.score_beats_threshold(
                matching_score, self.matching_threshold
            ):
                # First match for this reference label: record it and its score.
                labelmap.add_labelmap_entry(pred_label, ref_label)
                score_ref[ref_label] = matching_score
        # Maps each matched pred label to its ref label (many-to-one).
        return labelmap

    def new_combination_score(
        self,
        pred_labels: list[int],
        new_pred_label: int,
        ref_label: int,
        unmatched_instance_pair: UnmatchedInstancePair,
    ) -> float:
        """
        Compute the metric score of a reference instance against the union of the
        already-matched prediction labels plus one candidate prediction label.

        Args:
            pred_labels (list[int]): Prediction labels currently matched to ref_label.
            new_pred_label (int): Candidate prediction label to merge in.
            ref_label (int): The reference label being scored against.
            unmatched_instance_pair (UnmatchedInstancePair): Provides the arrays.

        Returns:
            float: The combined matching score.
        """
        # Build a new list instead of appending so the caller's list is not mutated.
        combined_labels = pred_labels + [new_pred_label]
        score = self.matching_metric(
            unmatched_instance_pair.reference_arr,
            prediction_arr=unmatched_instance_pair.prediction_arr,
            ref_instance_idx=ref_label,
            pred_instance_idx=combined_labels,
        )
        return score


class MatchUntilConvergenceMatching(InstanceMatchingAlgorithm):
Expand Down
4 changes: 2 additions & 2 deletions panoptica/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
)
from panoptica.metrics.iou import _compute_instance_iou, _compute_iou
from panoptica.metrics.metrics import (
EvalMetric,
Metrics,
ListMetric,
EvalMetric,
MetricDict,
Metrics,
_MatchingMetric,
)
6 changes: 4 additions & 2 deletions panoptica/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ def __call__(
reference_arr: np.ndarray,
prediction_arr: np.ndarray,
ref_instance_idx: int | None = None,
pred_instance_idx: int | None = None,
pred_instance_idx: int | list[int] | None = None,
*args,
**kwargs,
):
if ref_instance_idx is not None and pred_instance_idx is not None:
reference_arr = reference_arr.copy() == ref_instance_idx
prediction_arr = prediction_arr.copy() == pred_instance_idx
if isinstance(pred_instance_idx, int):
pred_instance_idx = [pred_instance_idx]
prediction_arr = np.isin(prediction_arr.copy(), pred_instance_idx)
return self._metric_function(reference_arr, prediction_arr, *args, **kwargs)

def __eq__(self, __value: object) -> bool:
Expand Down
9 changes: 9 additions & 0 deletions panoptica/utils/processing_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,15 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int):
)
self.labelmap[p] = ref_label

def get_pred_labels_matched_to_ref(self, ref_label: int):
    """Return all prediction labels currently mapped to the given reference label."""
    matched_preds = []
    for pred_label, mapped_ref in self.labelmap.items():
        if mapped_ref == ref_label:
            matched_preds.append(pred_label)
    return matched_preds

def contains_pred(self, pred_label: int):
return pred_label in self.labelmap

def contains_ref(self, ref_label: int):
return ref_label in self.labelmap.values()

def contains_and(
self, pred_label: int | None = None, ref_label: int | None = None
) -> bool:
Expand Down
73 changes: 71 additions & 2 deletions unit_tests/test_panoptic_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from panoptica.panoptic_evaluator import Panoptic_Evaluator
from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator
from panoptica.instance_matcher import NaiveThresholdMatching
from panoptica.metrics import Metrics
from panoptica.instance_matcher import NaiveThresholdMatching, MaximizeMergeMatching
from panoptica.metrics import _MatchingMetric, Metrics
from panoptica.utils.processing_pair import SemanticPair


Expand Down Expand Up @@ -238,3 +238,72 @@ def test_dtype_evaluation(self):
self.assertEqual(result.fp, 0)
self.assertEqual(result.sq, 0.75)
self.assertEqual(result.pq, 0.75)

def test_simple_evaluation_maximize_matcher(self):
    """MaximizeMergeMatching on a single partially-overlapping instance pair."""
    reference = np.zeros([50, 50], dtype=np.uint16)
    prediction = reference.copy().astype(reference.dtype)
    reference[20:40, 10:20] = 1
    prediction[20:35, 10:20] = 2

    sample = SemanticPair(prediction, reference)

    evaluator = Panoptic_Evaluator(
        expected_input=SemanticPair,
        instance_approximator=ConnectedComponentsInstanceApproximator(),
        instance_matcher=MaximizeMergeMatching(),
    )

    result, debug_data = evaluator.evaluate(sample)
    print(result)
    for attribute, expected in (("tp", 1), ("fp", 0), ("sq", 0.75), ("pq", 0.75)):
        self.assertEqual(getattr(result, attribute), expected)

def test_simple_evaluation_maximize_matcher_overlaptwo(self):
    """Two prediction instances overlapping one reference get merged into one match."""
    reference = np.zeros([50, 50], dtype=np.uint16)
    prediction = reference.copy().astype(reference.dtype)
    reference[20:40, 10:20] = 1
    prediction[20:35, 10:20] = 2
    prediction[36:38, 10:20] = 3

    sample = SemanticPair(prediction, reference)

    evaluator = Panoptic_Evaluator(
        expected_input=SemanticPair,
        instance_approximator=ConnectedComponentsInstanceApproximator(),
        instance_matcher=MaximizeMergeMatching(),
    )

    result, debug_data = evaluator.evaluate(sample)
    print(result)
    for attribute, expected in (("tp", 1), ("fp", 0), ("sq", 0.85), ("pq", 0.85)):
        self.assertEqual(getattr(result, attribute), expected)

def test_simple_evaluation_maximize_matcher_overlap(self):
    """Two predictions merge onto the reference; a third stays unmatched (FP)."""
    reference = np.zeros([50, 50], dtype=np.uint16)
    prediction = reference.copy().astype(reference.dtype)
    reference[20:40, 10:20] = 1
    prediction[20:35, 10:20] = 2
    prediction[36:38, 10:20] = 3
    # Labels 2 and 3 should merge onto reference 1; label 4 matches nothing (FP).
    prediction[39:47, 10:20] = 4

    sample = SemanticPair(prediction, reference)

    evaluator = Panoptic_Evaluator(
        expected_input=SemanticPair,
        instance_approximator=ConnectedComponentsInstanceApproximator(),
        instance_matcher=MaximizeMergeMatching(),
    )

    result, debug_data = evaluator.evaluate(sample)
    print(result)
    for attribute, expected in (("tp", 1), ("fp", 1), ("sq", 0.85)):
        self.assertEqual(getattr(result, attribute), expected)
    for attribute, expected in (
        ("pq", 0.56666666),
        ("rq", 0.66666666),
        ("sq_dsc", 0.9189189189189),
    ):
        self.assertAlmostEqual(getattr(result, attribute), expected)