Skip to content

Commit

Permalink
Autoformat with black
Browse files Browse the repository at this point in the history
  • Loading branch information
brainless-bot[bot] committed Aug 5, 2024
1 parent d25fabc commit 3e89dd8
Show file tree
Hide file tree
Showing 15 changed files with 184 additions and 55 deletions.
4 changes: 3 additions & 1 deletion examples/example_spine_instance_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
reference_mask = read_nifti(directory + "/spine_seg/matched_instance/ref.nii.gz")
prediction_mask = read_nifti(directory + "/spine_seg/matched_instance/pred.nii.gz")

evaluator = Panoptica_Evaluator.load_from_config_name("panoptica_evaluator_unmatched_instance")
evaluator = Panoptica_Evaluator.load_from_config_name(
"panoptica_evaluator_unmatched_instance"
)


with cProfile.Profile() as pr:
Expand Down
4 changes: 3 additions & 1 deletion examples/example_spine_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@

with cProfile.Profile() as pr:
if __name__ == "__main__":
result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
result, debug_data = evaluator.evaluate(prediction_mask, reference_mask)[
"ungrouped"
]
print(result)

pr.dump_stats(directory + "/semantic_example.log")
8 changes: 6 additions & 2 deletions panoptica/instance_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def _approximate_instances(
pass

def _yaml_repr(cls, node) -> dict:
raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}")
raise NotImplementedError(
f"Tried to get yaml representation of abstract class {cls.__name__}"
)
return {}

def approximate_instances(
Expand Down Expand Up @@ -146,7 +148,9 @@ def _approximate_instances(
"""
cca_backend = self.cca_backend
if cca_backend is None:
cca_backend = CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy
cca_backend = (
CCABackend.cc3d if semantic_pair.n_dim >= 3 else CCABackend.scipy
)
assert cca_backend is not None

empty_prediction = len(semantic_pair._pred_labels) == 0
Expand Down
21 changes: 16 additions & 5 deletions panoptica/instance_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def match_instances(
return map_instance_labels(unmatched_instance_pair.copy(), instance_labelmap)

def _yaml_repr(cls, node) -> dict:
raise NotImplementedError(f"Tried to get yaml representation of abstract class {cls.__name__}")
raise NotImplementedError(
f"Tried to get yaml representation of abstract class {cls.__name__}"
)
return {}


Expand Down Expand Up @@ -199,13 +201,20 @@ def _match_instances(
unmatched_instance_pair.prediction_arr,
unmatched_instance_pair.reference_arr,
)
mm_pairs = _calc_matching_metric_of_overlapping_labels(pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric)
mm_pairs = _calc_matching_metric_of_overlapping_labels(
pred_arr, ref_arr, ref_labels, matching_metric=self._matching_metric
)

# Loop through matched instances to compute PQ components
for matching_score, (ref_label, pred_label) in mm_pairs:
if labelmap.contains_or(pred_label, ref_label) and not self._allow_many_to_one:
if (
labelmap.contains_or(pred_label, ref_label)
and not self._allow_many_to_one
):
continue # -> doesnt make speed difference
if self._matching_metric.score_beats_threshold(matching_score, self._matching_threshold):
if self._matching_metric.score_beats_threshold(
matching_score, self._matching_threshold
):
# Match found, increment true positive count and collect IoU and Dice values
labelmap.add_labelmap_entry(pred_label, ref_label)
# map label ref_idx to pred_idx
Expand Down Expand Up @@ -296,7 +305,9 @@ def _match_instances(
if new_score > score_ref[ref_label]:
labelmap.add_labelmap_entry(pred_label, ref_label)
score_ref[ref_label] = new_score
elif self._matching_metric.score_beats_threshold(matching_score, self._matching_threshold):
elif self._matching_metric.score_beats_threshold(
matching_score, self._matching_threshold
):
# Match found, increment true positive count and collect IoU and Dice values
labelmap.add_labelmap_entry(pred_label, ref_label)
score_ref[ref_label] = matching_score
Expand Down
36 changes: 27 additions & 9 deletions panoptica/panoptica_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@ def __init__(

self.__segmentation_class_groups = segmentation_class_groups

self.__edge_case_handler = edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
self.__edge_case_handler = (
edge_case_handler if edge_case_handler is not None else EdgeCaseHandler()
)
if self.__decision_metric is not None:
assert self.__decision_threshold is not None, "decision metric set but no decision threshold for it"
assert (
self.__decision_threshold is not None
), "decision metric set but no decision threshold for it"
#
self.__log_times = log_times
self.__verbose = verbose
Expand Down Expand Up @@ -86,7 +90,9 @@ def evaluate(
verbose: bool | None = None,
) -> dict[str, tuple[PanopticaResult, dict[str, _ProcessingPair]]]:
processing_pair = self.__expected_input(prediction_arr, reference_arr)
assert isinstance(processing_pair, self.__expected_input.value), f"input not of expected type {self.__expected_input}"
assert isinstance(
processing_pair, self.__expected_input.value
), f"input not of expected type {self.__expected_input}"

if self.__segmentation_class_groups is None:
return {
Expand All @@ -105,8 +111,12 @@ def evaluate(
)
}

self.__segmentation_class_groups.has_defined_labels_for(processing_pair.prediction_arr, raise_error=True)
self.__segmentation_class_groups.has_defined_labels_for(processing_pair.reference_arr, raise_error=True)
self.__segmentation_class_groups.has_defined_labels_for(
processing_pair.prediction_arr, raise_error=True
)
self.__segmentation_class_groups.has_defined_labels_for(
processing_pair.reference_arr, raise_error=True
)

result_grouped = {}
for group_name, label_group in self.__segmentation_class_groups.items():
Expand All @@ -118,7 +128,9 @@ def evaluate(
single_instance_mode = label_group.single_instance
processing_pair_grouped = processing_pair.__class__(prediction_arr=prediction_arr_grouped, reference_arr=reference_arr_grouped) # type: ignore
decision_threshold = self.__decision_threshold
if single_instance_mode and not isinstance(processing_pair, MatchedInstancePair):
if single_instance_mode and not isinstance(
processing_pair, MatchedInstancePair
):
processing_pair_grouped = MatchedInstancePair(
prediction_arr=processing_pair_grouped.prediction_arr,
reference_arr=processing_pair_grouped.reference_arr,
Expand All @@ -142,7 +154,9 @@ def evaluate(


def panoptic_evaluate(
processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult,
processing_pair: (
SemanticPair | UnmatchedInstancePair | MatchedInstancePair | PanopticaResult
),
instance_approximator: InstanceApproximator | None = None,
instance_matcher: InstanceMatchingAlgorithm | None = None,
eval_metrics: list[Metric] = [Metric.DSC, Metric.IOU, Metric.ASSD],
Expand Down Expand Up @@ -198,7 +212,9 @@ def panoptic_evaluate(
processing_pair.crop_data()

if isinstance(processing_pair, SemanticPair):
assert instance_approximator is not None, "Got SemanticPair but not InstanceApproximator"
assert (
instance_approximator is not None
), "Got SemanticPair but not InstanceApproximator"
if verbose:
print("-- Got SemanticPair, will approximate instances")
start = perf_counter()
Expand All @@ -218,7 +234,9 @@ def panoptic_evaluate(
if isinstance(processing_pair, UnmatchedInstancePair):
if verbose:
print("-- Got UnmatchedInstancePair, will match instances")
assert instance_matcher is not None, "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
assert (
instance_matcher is not None
), "Got UnmatchedInstancePair but not InstanceMatchingAlgorithm"
start = perf_counter()
processing_pair = instance_matcher.match_instances(
processing_pair,
Expand Down
16 changes: 12 additions & 4 deletions panoptica/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,17 @@ def save_from_object(cls, obj: object, file: str | Path):
@classmethod
def load(cls, file: str | Path, registered_class=None):
data = _load_yaml(file, registered_class)
assert isinstance(data, dict), f"The config at {file} is registered to a class. Use load_as_object() instead"
assert isinstance(
data, dict
), f"The config at {file} is registered to a class. Use load_as_object() instead"
return Configuration(data, registered_class=registered_class)

@classmethod
def load_as_object(cls, file: str | Path, registered_class=None):
data = _load_yaml(file, registered_class)
assert not isinstance(data, dict), f"The config at {file} is not registered to a class. Use load() instead"
assert not isinstance(
data, dict
), f"The config at {file} is not registered to a class. Use load() instead"
return data

def save(self, out_file: str | Path):
Expand Down Expand Up @@ -148,7 +152,9 @@ def _register_permanently(cls):
@classmethod
def load_from_config(cls, path: str | Path):
obj = _load_from_config(cls, path)
assert isinstance(obj, cls), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}"
assert isinstance(
obj, cls
), f"loaded object was not of the correct class, expected {cls.__name__} but got {type(obj)}"
return obj

@classmethod
Expand All @@ -163,7 +169,9 @@ def save_to_config(self, path: str | Path):
@classmethod
def to_yaml(cls, representer, node):
# cls._register_permanently()
assert hasattr(cls, "_yaml_repr"), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined"
assert hasattr(
cls, "_yaml_repr"
), f"Class {cls.__name__} has no _yaml_repr(cls, node) defined"
return representer.represent_mapping("!" + cls.__name__, cls._yaml_repr(node))

@classmethod
Expand Down
7 changes: 6 additions & 1 deletion panoptica/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from enum import Enum, auto
from panoptica.utils.config import _register_class_to_yaml, _load_from_config, _load_from_config_name, _save_to_config
from panoptica.utils.config import (
_register_class_to_yaml,
_load_from_config,
_load_from_config_name,
_save_to_config,
)
from pathlib import Path
import numpy as np

Expand Down
36 changes: 28 additions & 8 deletions panoptica/utils/edge_case_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,26 @@ def __init__(

self._default_result = default_result
self._edgecase_dict: dict[EdgeCaseZeroTP, EdgeCaseResult] = {}
self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = empty_prediction_result if empty_prediction_result is not None else default_result
self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = empty_reference_result if empty_reference_result is not None else default_result
self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = no_instances_result if no_instances_result is not None else default_result
self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = normal if normal is not None else default_result
self._edgecase_dict[EdgeCaseZeroTP.EMPTY_PRED] = (
empty_prediction_result
if empty_prediction_result is not None
else default_result
)
self._edgecase_dict[EdgeCaseZeroTP.EMPTY_REF] = (
empty_reference_result
if empty_reference_result is not None
else default_result
)
self._edgecase_dict[EdgeCaseZeroTP.NO_INSTANCES] = (
no_instances_result if no_instances_result is not None else default_result
)
self._edgecase_dict[EdgeCaseZeroTP.NORMAL] = (
normal if normal is not None else default_result
)

def __call__(self, tp: int, num_pred_instances, num_ref_instances) -> tuple[bool, float | None]:
def __call__(
self, tp: int, num_pred_instances, num_ref_instances
) -> tuple[bool, float | None]:
if tp != 0:
return False, EdgeCaseResult.NONE.value
#
Expand Down Expand Up @@ -117,7 +131,9 @@ def __init__(
},
empty_list_std: EdgeCaseResult = EdgeCaseResult.NAN,
) -> None:
self.__listmetric_zeroTP_handling: dict[Metric, MetricZeroTPEdgeCaseHandling] = listmetric_zeroTP_handling
self.__listmetric_zeroTP_handling: dict[
Metric, MetricZeroTPEdgeCaseHandling
] = listmetric_zeroTP_handling
self.__empty_list_std: EdgeCaseResult = empty_list_std

def handle_zero_tp(
Expand All @@ -130,7 +146,9 @@ def handle_zero_tp(
if tp != 0:
return False, EdgeCaseResult.NONE.value
if metric not in self.__listmetric_zeroTP_handling:
raise NotImplementedError(f"Metric {metric} encountered zero TP, but no edge handling available")
raise NotImplementedError(
f"Metric {metric} encountered zero TP, but no edge handling available"
)

return self.__listmetric_zeroTP_handling[metric](
tp=tp,
Expand Down Expand Up @@ -167,7 +185,9 @@ def _yaml_repr(cls, node) -> dict:

print()
# print(handler.get_metric_zero_tp_handle(ListMetric.IOU))
r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1)
r = handler.handle_zero_tp(
Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1
)
print(r)

iou_test = MetricZeroTPEdgeCaseHandling(
Expand Down
19 changes: 15 additions & 4 deletions panoptica/utils/filepath.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from pathlib import Path


def search_path(basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False) -> list[Path]:
def search_path(
basepath: str | Path, query: str, verbose: bool = False, suppress: bool = False
) -> list[Path]:
"""Searches from basepath with query
Args:
basepath: ground path to look into
Expand All @@ -16,7 +18,9 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres
All found paths
"""
basepath = str(basepath)
assert os.path.exists(basepath), f"basepath for search_path() doesnt exist, got {basepath}"
assert os.path.exists(
basepath
), f"basepath for search_path() doesnt exist, got {basepath}"
if not basepath.endswith("/"):
basepath += "/"
print(f"search_path: in {basepath}{query}") if verbose else None
Expand All @@ -28,9 +32,16 @@ def search_path(basepath: str | Path, query: str, verbose: bool = False, suppres

# Find config path
def config_by_name(name: str) -> Path:
directory = Path(__file__.replace("////", "/").replace("\\\\", "/").replace("//", "/").replace("\\", "/")).parent.parent
directory = Path(
__file__.replace("////", "/")
.replace("\\\\", "/")
.replace("//", "/")
.replace("\\", "/")
).parent.parent
if not name.endswith(".yaml"):
name += ".yaml"
p = search_path(directory, query=f"**/{name}", suppress=True)
assert len(p) == 1, f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}"
assert (
len(p) == 1
), f"Did not find exactly one config yaml with name {name} in directory {directory}, got {p}"
return p[0]
16 changes: 12 additions & 4 deletions panoptica/utils/instancelabelmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ def add_labelmap_entry(self, pred_labels: list[int] | int, ref_label: int):
if not isinstance(pred_labels, list):
pred_labels = [pred_labels]
assert isinstance(ref_label, int), "add_labelmap_entry: got no int as ref_label"
assert np.all([isinstance(r, int) for r in pred_labels]), "add_labelmap_entry: got no int as pred_label"
assert np.all(
[isinstance(r, int) for r in pred_labels]
), "add_labelmap_entry: got no int as pred_label"
for p in pred_labels:
if p in self.labelmap and self.labelmap[p] != ref_label:
raise Exception(
Expand All @@ -30,12 +32,16 @@ def contains_pred(self, pred_label: int):
def contains_ref(self, ref_label: int):
return ref_label in self.labelmap.values()

def contains_and(self, pred_label: int | None = None, ref_label: int | None = None) -> bool:
def contains_and(
self, pred_label: int | None = None, ref_label: int | None = None
) -> bool:
pred_in = True if pred_label is None else pred_label in self.labelmap
ref_in = True if ref_label is None else ref_label in self.labelmap.values()
return pred_in and ref_in

def contains_or(self, pred_label: int | None = None, ref_label: int | None = None) -> bool:
def contains_or(
self, pred_label: int | None = None, ref_label: int | None = None
) -> bool:
pred_in = True if pred_label is None else pred_label in self.labelmap
ref_in = True if ref_label is None else ref_label in self.labelmap.values()
return pred_in or ref_in
Expand All @@ -47,7 +53,9 @@ def __str__(self) -> str:
return str(
list(
[
str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v)) + " -> " + str(v)
str(tuple(k for k in self.labelmap.keys() if self.labelmap[k] == v))
+ " -> "
+ str(v)
for v in set(self.labelmap.values())
]
)
Expand Down
Loading

0 comments on commit 3e89dd8

Please sign in to comment.