diff --git a/panoptica/panoptica_result.py b/panoptica/panoptica_result.py index 2123981..0db17c2 100644 --- a/panoptica/panoptica_result.py +++ b/panoptica/panoptica_result.py @@ -563,29 +563,3 @@ def sq_rvd_std(res: PanopticaResult): # endregion - - -def _build_global_bin_metric_function(metric: Metric): - - def function_template(res: PanopticaResult): - if metric not in res._global_metrics: - raise MetricCouldNotBeComputedException(f"Global Metric {metric} not set") - if res.tp == 0: - is_edgecase, result = res._edge_case_handler.handle_zero_tp( - metric, res.tp, res.num_pred_instances, res.num_ref_instances - ) - if is_edgecase: - return result - pred_binary = res._prediction_arr.copy() - ref_binary = res._reference_arr.copy() - pred_binary[pred_binary != 0] = 1 - ref_binary[ref_binary != 0] = 1 - return metric( - reference_arr=res._reference_arr, - prediction_arr=res._prediction_arr, - ) - - return lambda x: function_template(x) - - -# endregion diff --git a/panoptica/utils/config.py b/panoptica/utils/config.py index 241bc0a..a3d5157 100644 --- a/panoptica/utils/config.py +++ b/panoptica/utils/config.py @@ -43,67 +43,6 @@ def _save_yaml(data_dict: dict | object, out_file: str | Path, registered_class= print(f"Saved config into {out_file}") -#################### -# TODO Merge into SupportsConfig -class Configuration: - """General Configuration class that handles yaml""" - - _data_dict: dict - _registered_class = None - - def __init__(self, data_dict: dict, registered_class=None) -> None: - assert isinstance(data_dict, dict) - self._data_dict = data_dict - if registered_class is not None: - self.register_to_class(registered_class) - - def register_to_class(self, cls): - global supported_helper_classes - if cls not in supported_helper_classes: - supported_helper_classes.append(cls) - self._registered_class = cls - return self - - @classmethod - def save_from_object(cls, obj: object, file: str | Path): - _save_yaml(obj, file, registered_class=type(obj)) - # return Configuration.load(file, registered_class=type(obj)) - - @classmethod - def load(cls, file: str | Path, registered_class=None): - data = _load_yaml(file, registered_class) - assert isinstance( - data, dict - ), f"The config at {file} is registered to a class. Use load_as_object() instead" - return Configuration(data, registered_class=registered_class) - - @classmethod - def load_as_object(cls, file: str | Path, registered_class=None): - data = _load_yaml(file, registered_class) - assert not isinstance( - data, dict - ), f"The config at {file} is not registered to a class. Use load() instead" - return data - - def save(self, out_file: str | Path): - _save_yaml(self._data_dict, out_file) - - def cls_object_from_this(self): - assert self._registered_class is not None - return self._registered_class(**self._data_dict) - - @property - def data_dict(self): - return self._data_dict - - @property - def cls(self): - return self._registered_class - - def __str__(self) -> str: - return f"Config({self.cls.__name__ if self.cls is not None else 'NoClass'} = {self.data_dict})" # type: ignore - - ######### # Universal Functions ######### @@ -118,7 +57,7 @@ def _load_from_config(cls, path: str | Path): if isinstance(path, str): path = Path(path) assert path.exists(), f"load_from_config: {path} does not exist" - obj = Configuration.load_as_object(path, registered_class=cls) + obj = _load_yaml(path, registered_class=cls) assert isinstance(obj, cls), f"Loaded config was not for class {cls.__name__}" return obj @@ -132,7 +71,7 @@ def _load_from_config_name(cls, name: str): def _save_to_config(obj, path: str | Path): if isinstance(path, str): path = Path(path) - Configuration.save_from_object(obj, path) + _save_yaml(obj, path, registered_class=type(obj)) def _save_to_config_by_name(obj, name: str): diff --git a/panoptica/utils/numpy_utils.py b/panoptica/utils/numpy_utils.py index d28f686..70180a1 100644 --- a/panoptica/utils/numpy_utils.py +++ b/panoptica/utils/numpy_utils.py @@ -17,7 +17,7 @@ def _unique_without_zeros(arr: np.ndarray) -> np.ndarray: Issues a warning if negative values are present. """ if np.any(arr < 0): - warnings.warn("Negative values are present in the input array.") + warnings.warn("Negative values are present in the input array.", UserWarning) return np.unique(arr[arr != 0]) @@ -33,7 +33,7 @@ def _count_unique_without_zeros(arr: np.ndarray) -> int: int: Number of unique elements excluding zeros. """ if np.any(arr < 0): - warnings.warn("Negative values are present in the input array.") + warnings.warn("Negative values are present in the input array.", UserWarning) return len(_unique_without_zeros(arr)) diff --git a/unit_tests/test_metrics.py b/unit_tests/test_metrics.py index 9e945d2..e26203c 100644 --- a/unit_tests/test_metrics.py +++ b/unit_tests/test_metrics.py @@ -96,6 +96,12 @@ def setUp(self) -> None: def test_rvd_case_simple_identical(self): + pred_arr, ref_arr = case_simple_identical() + rvd = Metric.RVD(reference_arr=ref_arr, prediction_arr=pred_arr, ref_instance_idx=1, pred_instance_idx=1) + self.assertEqual(rvd, 0.0) + + def test_rvd_case_simple_identical_idx(self): + pred_arr, ref_arr = case_simple_identical() rvd = Metric.RVD(reference_arr=ref_arr, prediction_arr=pred_arr) self.assertEqual(rvd, 0.0) @@ -130,6 +136,12 @@ def test_dsc_case_simple_identical(self): dsc = Metric.DSC(reference_arr=ref_arr, prediction_arr=pred_arr) self.assertEqual(dsc, 1.0) + def test_dsc_case_simple_identical_idx(self): + + pred_arr, ref_arr = case_simple_identical() + dsc = Metric.DSC(reference_arr=ref_arr, prediction_arr=pred_arr, ref_instance_idx=1, pred_instance_idx=1) + self.assertEqual(dsc, 1.0) + def test_dsc_case_simple_nooverlap(self): pred_arr, ref_arr = case_simple_nooverlap() @@ -159,6 +171,11 @@ def test_st_case_simple_identical(self): st = Metric.ASSD(reference_arr=ref_arr, prediction_arr=pred_arr) self.assertEqual(st, 0.0) + def test_st_case_simple_identical_idx(self): + pred_arr, ref_arr = case_simple_identical() + st = Metric.ASSD(reference_arr=ref_arr, prediction_arr=pred_arr, ref_instance_idx=1, pred_instance_idx=1) + self.assertEqual(st, 0.0) + def test_st_case_simple_nooverlap(self): pred_arr, ref_arr = case_simple_nooverlap() st = Metric.ASSD(reference_arr=ref_arr, prediction_arr=pred_arr) @@ -185,6 +202,11 @@ def test_st_case_simple_identical(self): st = Metric.clDSC(reference_arr=ref_arr, prediction_arr=pred_arr) self.assertEqual(st, 1.0) + def test_st_case_simple_identical_idx(self): + pred_arr, ref_arr = case_simple_identical() + st = Metric.clDSC(reference_arr=ref_arr, prediction_arr=pred_arr, ref_instance_idx=1, pred_instance_idx=1) + self.assertEqual(st, 1.0) + def test_st_case_simple_nooverlap(self): pred_arr, ref_arr = case_simple_nooverlap() st = Metric.clDSC(reference_arr=ref_arr, prediction_arr=pred_arr) diff --git a/unit_tests/test_panoptic_pipeline.py b/unit_tests/test_panoptic_pipeline.py index 0b6ad29..2d75a95 100644 --- a/unit_tests/test_panoptic_pipeline.py +++ b/unit_tests/test_panoptic_pipeline.py @@ -9,7 +9,7 @@ from panoptica import InputType from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator -from panoptica.instance_matcher import MaximizeMergeMatching, NaiveThresholdMatching +from panoptica.instance_matcher import MaximizeMergeMatching, NaiveThresholdMatching, InstanceLabelMap from panoptica.metrics import Metric from panoptica.instance_evaluator import ( evaluate_matched_instance, @@ -33,6 +33,26 @@ def setUp(self) -> None: os.environ["PANOPTICA_CITATION_REMINDER"] = "False" return super().setUp() + def test_labelmap(self): + labelmap = InstanceLabelMap() + + labelmap.add_labelmap_entry(1, 1) + labelmap.add_labelmap_entry([2, 3], 2) + + with self.assertRaises(Exception): + labelmap.add_labelmap_entry(1, 1) + labelmap.add_labelmap_entry(1, 2) + + self.assertTrue(labelmap.contains_and(None, None)) + self.assertTrue(labelmap.contains_and(1, 1)) + self.assertTrue(not labelmap.contains_and(1, 3)) + self.assertTrue(not labelmap.contains_and(4, 1)) + + print(labelmap) + + with self.assertRaises(Exception): + labelmap.labelmap = {} + class Test_Panoptica_Instance_Evaluation(unittest.TestCase): def setUp(self) -> None: diff --git a/unit_tests/test_utils.py b/unit_tests/test_utils.py new file mode 100644 index 0000000..8433e67 --- /dev/null +++ b/unit_tests/test_utils.py @@ -0,0 +1,52 @@ +# Call 'python -m unittest' on this folder +# coverage run -m unittest +# coverage report +# coverage html +import os +import unittest +import numpy as np +from panoptica.utils.numpy_utils import _unique_without_zeros, _count_unique_without_zeros, _get_smallest_fitting_uint +from panoptica.utils.citation_reminder import citation_reminder + + +class Test_Citation_Reminder(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "True" + return super().setUp() + + def test_citation_code(self): + + @citation_reminder + def foo(): + return "bar" + + foo() + + +class Test_Numpy_Utils(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_np_unique(self): + a = np.array([0, 1, 2, 3, 6]) + b = _unique_without_zeros(a) + + self.assertTrue(b[0] == 1) + self.assertTrue(b[1] == 2) + self.assertTrue(b[2] == 3) + self.assertTrue(b[3] == 6) + self.assertEqual(b.shape[0], 4) + + with self.assertWarns(UserWarning): + a = np.array([0, 1, -2, 3, -6]) + b = _unique_without_zeros(a) + + def test_np_count_unique(self): + a = np.array([0, 1, 2, 3, 6]) + b = _count_unique_without_zeros(a) + self.assertEqual(b, 4) + # + with self.assertWarns(UserWarning): + a = np.array([0, 1, -2, 3, -6]) + b = _count_unique_without_zeros(a)