diff --git a/panoptica/__init__.py b/panoptica/__init__.py
index 4589be4..bb068d6 100644
--- a/panoptica/__init__.py
+++ b/panoptica/__init__.py
@@ -3,7 +3,7 @@
     CCABackend,
 )
 from panoptica.instance_matcher import NaiveThresholdMatching
-from panoptica.panoptica_statistics import Panoptica_Statistic
+from panoptica.panoptica_statistics import Panoptica_Statistic, ValueSummary
 from panoptica.panoptica_aggregator import Panoptica_Aggregator
 from panoptica.panoptica_evaluator import Panoptica_Evaluator
 from panoptica.panoptica_result import PanopticaResult
diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py
index 96489b3..7495054 100644
--- a/panoptica/panoptica_statistics.py
+++ b/panoptica/panoptica_statistics.py
@@ -13,6 +13,35 @@
     print("OPTIONAL PACKAGE MISSING")
 
 
+class ValueSummary:
+    def __init__(self, value_list: list[float]) -> None:
+        self.__value_list = value_list
+        self.__avg = float(np.average(value_list))
+        self.__std = float(np.std(value_list))
+        self.__min = min(value_list)
+        self.__max = max(value_list)
+
+    @property
+    def values(self) -> list[float]:
+        return self.__value_list
+
+    @property
+    def avg(self) -> float:
+        return self.__avg
+
+    @property
+    def std(self) -> float:
+        return self.__std
+
+    @property
+    def min(self) -> float:
+        return self.__min
+
+    @property
+    def max(self) -> float:
+        return self.__max
+
+
 class Panoptica_Statistic:
 
     def __init__(
@@ -26,6 +55,20 @@ def __init__(
 
         self.__groupnames = list(value_dict.keys())
         self.__metricnames = list(value_dict[self.__groupnames[0]].keys())
+        # assert length of everything
+        for g in self.groupnames:
+            assert len(self.metricnames) == len(
+                list(value_dict[g].keys())
+            ), f"Group {g} has an inconsistent number of metrics, got {len(list(value_dict[g].keys()))} but expected {len(self.metricnames)}"
+            for m in self.metricnames:
+                assert len(self.get(g, m)) == len(
+                    self.subjectnames
+                ), f"Group {g}, metric {m} has the wrong number of subjects, got {len(self.get(g, m))}, expected {len(self.subjectnames)}"
+
+    @property
+    def subjectnames(self):
+        return self.__subj_names
+
     @property
     def groupnames(self):
         return self.__groupnames
@@ -76,8 +119,12 @@ def from_file(cls, file: str):
 
                 if len(value) > 0:
                     value = float(value)
-                    if not np.isnan(value) and value != np.inf:
+                    if value is not None and not np.isnan(value) and value != np.inf:
                         value_dict[group_name][metric_name].append(float(value))
+                    else:
+                        value_dict[group_name][metric_name].append(None)
+                else:
+                    value_dict[group_name][metric_name].append(None)
 
         return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict)
@@ -96,7 +143,7 @@ def _assertsubject(self, subjectname):
             subjectname in self.__subj_names
         ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}"
 
-    def get(self, group, metric) -> list[float]:
+    def get(self, group, metric, remove_nones: bool = False) -> list[float]:
         """Returns the list of values for given group and metric
 
         Args:
@@ -112,7 +159,9 @@ def get(self, group, metric) -> list[float]:
         assert (
             group in self.__value_dict and metric in self.__value_dict[group]
         ), f"Values not found for group {group} and metric {metric} evem though they should!"
-        return self.__value_dict[group][metric]
+        if not remove_nones:
+            return self.__value_dict[group][metric]
+        return [i for i in self.__value_dict[group][metric] if i is not None]
 
     def get_one_subject(self, subjectname: str):
         """Gets the values for ONE subject for each group and metric
@@ -130,7 +179,7 @@ def get_one_subject(self, subjectname: str):
             for g in self.__groupnames
         }
 
-    def get_across_groups(self, metric):
+    def get_across_groups(self, metric) -> list[float]:
         """Given metric, gives list of all values (even across groups!) Treat with care!
 
         Args:
@@ -141,41 +190,62 @@ def get_across_groups(self, metric):
         """
         values = []
         for g in self.__groupnames:
-            values.append(self.get(g, metric))
+            values += self.get(g, metric)
         return values
 
-    def get_summary_dict(self):
-        return {
+    def get_summary_across_groups(self) -> dict[str, ValueSummary]:
+        """Calculates the average and std over all groups (so group-wise avg first, then average over those)
+
+        Returns:
+            dict[str, ValueSummary]: metric name mapped to the summary of the group-wise averages
+        """
+        summary_dict = {}
+        for m in self.__metricnames:
+            value_list = [self.get_summary(g, m).avg for g in self.__groupnames]
+            assert len(value_list) == len(self.__groupnames)
+            summary_dict[m] = ValueSummary(value_list)
+        return summary_dict
+
+    def get_summary_dict(
+        self, include_across_group: bool = True
+    ) -> dict[str, dict[str, ValueSummary]]:
+        summary_dict = {
             g: {m: self.get_summary(g, m) for m in self.__metricnames}
             for g in self.__groupnames
         }
+        if include_across_group:
+            summary_dict["across_groups"] = self.get_summary_across_groups()
+        return summary_dict
 
-    def get_summary(self, group, metric):
-        # TODO maybe more here? range, stuff like that
-        return self.avg_std(group, metric)
+    def get_summary(self, group, metric) -> ValueSummary:
+        values = self.get(group, metric, remove_nones=True)
+        return ValueSummary(values)
 
-    def avg_std(self, group, metric) -> tuple[float, float]:
-        values = self.get(group, metric)
-        avg = float(np.average(values))
-        std = float(np.std(values))
-        return (avg, std)
-
-    def print_summary(self, ndigits: int = 3):
-        summary = self.get_summary_dict()
+    def print_summary(
+        self,
+        ndigits: int = 3,
+        only_across_groups: bool = True,
+    ):
+        summary = self.get_summary_dict(include_across_group=only_across_groups)
         print()
-        for g in self.__groupnames:
+        groups = list(summary.keys())
+        if only_across_groups:
+            groups = ["across_groups"]
+        for g in groups:
             print(f"Group {g}:")
             for m in self.__metricnames:
-                avg, std = summary[g][m]
+                avg, std = summary[g][m].avg, summary[g][m].std
                 print(m, ":", round(avg, ndigits), "+-", round(std, ndigits))
             print()
 
     def get_summary_figure(
         self,
         metric: str,
+        manual_metric_range: None | tuple[float, float] = None,
+        name_method: str = "Structure",
         horizontal: bool = True,
         sort: bool = True,
-        # title overwrite?
+        title: str = "",
     ):
         """Returns a figure object that shows the given metric for each group and its std
@@ -187,12 +257,25 @@ def get_summary_figure(
             _type_: _description_
         """
         orientation = "h" if horizontal else "v"
-        data_plot = {g: np.asarray(self.get(g, metric)) for g in self.__groupnames}
+        data_plot = {
+            g: np.asarray(self.get(g, metric, remove_nones=True))
+            for g in self.__groupnames
+        }
+        if manual_metric_range is not None:
+            assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range
+            change = (manual_metric_range[1] - manual_metric_range[0]) / 100
+            manual_metric_range = (
+                manual_metric_range[0] - change,
+                manual_metric_range[1] + change,
+            )
         return plot_box(
             data=data_plot,
             orientation=orientation,
-            score=metric,
+            name_method=name_method,
+            name_metric=metric,
             sort=sort,
+            figure_title=title,
+            manual_metric_range=manual_metric_range,
         )
 
     # groupwise or in total
@@ -247,7 +330,10 @@ def make_curve_over_setups(
     plt.grid("major")
     # Y values are average metric values in that group and metric
    for idx, g in enumerate(groups):
-        Y = [stat.avg_std(g, metric)[0] for stat in statistics_dict.values()]
+        Y = [
+            ValueSummary(stat.get(g, metric, remove_nones=True)).avg
+            for stat in statistics_dict.values()
+        ]
 
         if plot_lines:
             plt.plot(
@@ -274,56 +360,60 @@ def plot_box(
     data: dict[str, np.ndarray],
     sort=True,
     orientation="h",
-    # graph_name: str = "Structure",
-    score: str = "Dice-Score",
+    name_method: str = "Structure",
+    name_metric: str = "Dice-Score",
+    figure_title: str = "",
     width=850,
     height=1200,
-    yaxis_title=None,
-    xaxis_title=None,
+    manual_metric_range: None | tuple[float, float] = None,
 ):
-    graph_name: str = "Structure"
-
-    if xaxis_title is None:
-        xaxis_title = score if orientation == "h" else graph_name
-    if yaxis_title is None:
-        yaxis_title = score if orientation != "h" else graph_name
+    xaxis_title = name_metric if orientation == "h" else name_method
+    yaxis_title = name_metric if orientation != "h" else name_method
 
     data = {e.replace("_", " "): v for e, v in data.items()}
 
     df_data = pd.DataFrame(
         {
-            graph_name: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]),
-            score: np.concatenate([*data.values()], 0),
+            name_method: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]),
+            name_metric: np.concatenate([*data.values()], 0),
         }
     )
     if sort:
-        df_by_spec_count = df_data.groupby(graph_name).mean()
-        df_by_spec_count = dict(df_by_spec_count[score].items())
-        df_data["mean"] = df_data[graph_name].apply(
+        df_by_spec_count = df_data.groupby(name_method).mean()
+        df_by_spec_count = dict(df_by_spec_count[name_metric].items())
+        df_data["mean"] = df_data[name_method].apply(
             lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)
         )
         df_data = df_data.sort_values(by="mean")
     if orientation == "v":
         fig = px.strip(
-            df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation
+            df_data,
+            x=name_method,
+            y=name_metric,
+            stripmode="overlay",
+            orientation=orientation,
         )
         fig.update_traces(marker={"size": 5, "color": "#555555"})
         for e in data.keys():
            fig.add_trace(
                 go.Box(
-                    y=df_data.query(f'{graph_name} == "{e}"')[score],
+                    y=df_data.query(f'{name_method} == "{e}"')[name_metric],
                     name=e,
                     orientation=orientation,
                 )
             )
     else:
         fig = px.strip(
-            df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation
+            df_data,
+            y=name_method,
+            x=name_metric,
+            stripmode="overlay",
+            orientation=orientation,
         )
         fig.update_traces(marker={"size": 5, "color": "#555555"})
         for e in data.keys():
             fig.add_trace(
                 go.Box(
-                    x=df_data.query(f'{graph_name} == "{e}"')[score],
+                    x=df_data.query(f'{name_method} == "{e}"')[name_metric],
                     name=e,
                     orientation=orientation,
                     boxpoints=False,
@@ -337,6 +427,12 @@ def plot_box(
         yaxis_title=yaxis_title,
         xaxis_title=xaxis_title,
         font={"family": "Arial"},
+        title=figure_title,
     )
+    if manual_metric_range is not None:
+        if orientation == "h":
+            fig.update_xaxes(range=[manual_metric_range[0], manual_metric_range[1]])
+        else:
+            fig.update_yaxes(range=[manual_metric_range[0], manual_metric_range[1]])
     fig.update_traces(orientation=orientation)
     return fig
diff --git a/unit_tests/test_panoptic_aggregator.py b/unit_tests/test_panoptic_aggregator.py
index 53b34b1..cbba390 100644
--- a/unit_tests/test_panoptic_aggregator.py
+++ b/unit_tests/test_panoptic_aggregator.py
@@ -96,70 +96,3 @@ def test_simple_evaluation_then_statistic(self):
         statistic_obj.print_summary()
 
         os.remove(str(output_test_dir))
-
-
-class Test_Panoptica_Statistics(unittest.TestCase):
-    def setUp(self) -> None:
-        os.environ["PANOPTICA_CITATION_REMINDER"] = "False"
-        return super().setUp()
-
-    def test_simple_statistic(self):
-        a = np.zeros([50, 50], dtype=np.uint16)
-        b = a.copy().astype(a.dtype)
-        a[20:40, 10:20] = 1
-        b[20:35, 10:20] = 2
-
-        evaluator = Panoptica_Evaluator(
-            expected_input=InputType.SEMANTIC,
-            instance_approximator=ConnectedComponentsInstanceApproximator(),
-            instance_matcher=NaiveThresholdMatching(),
-        )
-
-        aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir)
-
-        aggregator.evaluate(b, a, "test")
-
-        statistic_obj = Panoptica_Statistic.from_file(output_test_dir)
-
-        statistic_obj.print_summary()
-
-        self.assertEqual(statistic_obj.get("ungrouped", "tp"), [1.0])
-        self.assertEqual(statistic_obj.get("ungrouped", "sq"), [0.75])
-        self.assertEqual(statistic_obj.get("ungrouped", "sq_rvd"), [-0.25])
-
-        self.assertEqual(statistic_obj.avg_std("ungrouped", "tp")[0], 1.0)
-        self.assertEqual(statistic_obj.avg_std("ungrouped", "sq")[0], 0.75)
-
-        os.remove(str(output_test_dir))
-
-    def test_multiple_samples_statistic(self):
-        a = np.zeros([50, 50], dtype=np.uint16)
-        b = a.copy().astype(a.dtype)
-        c = a.copy().astype(a.dtype)
-        a[20:40, 10:20] = 1
-        b[20:35, 10:20] = 2
-        c[20:40, 10:20] = 5
-        c[0:10, 0:10] = 3
-
-        evaluator = Panoptica_Evaluator(
-            expected_input=InputType.SEMANTIC,
-            instance_approximator=ConnectedComponentsInstanceApproximator(),
-            instance_matcher=NaiveThresholdMatching(),
-        )
-
-        aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir)
-
-        aggregator.evaluate(b, a, "test")
-        aggregator.evaluate(a, c, "test2")
-
-        statistic_obj = Panoptica_Statistic.from_file(output_test_dir)
-
-        statistic_obj.print_summary()
-
-        self.assertEqual(statistic_obj.avg_std("ungrouped", "tp")[0], 1.0)
-        self.assertEqual(statistic_obj.avg_std("ungrouped", "sq")[0], 0.875)
-        self.assertEqual(statistic_obj.avg_std("ungrouped", "fn")[0], 0.5)
-        self.assertEqual(statistic_obj.avg_std("ungrouped", "rec")[0], 0.75)
-        self.assertEqual(statistic_obj.avg_std("ungrouped", "rec")[1], 0.25)
-
-        os.remove(str(output_test_dir))
diff --git a/unit_tests/test_panoptic_statistics.py b/unit_tests/test_panoptic_statistics.py
new file mode 100644
index 0000000..3aebe71
--- /dev/null
+++ b/unit_tests/test_panoptic_statistics.py
@@ -0,0 +1,112 @@
+# Call 'python -m unittest' on this folder
+# coverage run -m unittest
+# coverage report
+# coverage html
+import os
+import unittest
+
+import numpy as np
+
+from panoptica import InputType, Panoptica_Aggregator, Panoptica_Statistic, ValueSummary
+from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator
+from panoptica.instance_matcher import MaximizeMergeMatching, NaiveThresholdMatching
+from panoptica.metrics import Metric
+from panoptica.panoptica_evaluator import Panoptica_Evaluator
+from panoptica.panoptica_result import MetricCouldNotBeComputedException
+from panoptica.utils.processing_pair import SemanticPair
+from panoptica.utils.segmentation_class import SegmentationClassGroups
+import sys
+from pathlib import Path
+
+
+output_test_dir = Path(__file__).parent.joinpath("unittest_tmp_file.tsv")
+
+input_test_file = Path(__file__).parent.joinpath("test_unittest_file.tsv")
+
+
+class Test_Panoptica_Statistics(unittest.TestCase):
+    def setUp(self) -> None:
+        os.environ["PANOPTICA_CITATION_REMINDER"] = "False"
+        return super().setUp()
+
+    def test_simple_statistic(self):
+        a = np.zeros([50, 50], dtype=np.uint16)
+        b = a.copy().astype(a.dtype)
+        a[20:40, 10:20] = 1
+        b[20:35, 10:20] = 2
+
+        evaluator = Panoptica_Evaluator(
+            expected_input=InputType.SEMANTIC,
+            instance_approximator=ConnectedComponentsInstanceApproximator(),
+            instance_matcher=NaiveThresholdMatching(),
+        )
+
+        aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir)
+
+        aggregator.evaluate(b, a, "test")
+
+        statistic_obj = Panoptica_Statistic.from_file(output_test_dir)
+
+        statistic_obj.print_summary()
+
+        self.assertEqual(statistic_obj.get("ungrouped", "tp"), [1.0])
+        self.assertEqual(statistic_obj.get("ungrouped", "sq"), [0.75])
+        self.assertEqual(statistic_obj.get("ungrouped", "sq_rvd"), [-0.25])
+
+        tp_values = statistic_obj.get("ungrouped", "tp")
+        sq_values = statistic_obj.get("ungrouped", "sq")
+        self.assertEqual(ValueSummary(tp_values).avg, 1.0)
+        self.assertEqual(ValueSummary(sq_values).avg, 0.75)
+
+        os.remove(str(output_test_dir))
+
+    def test_multiple_samples_statistic(self):
+        a = np.zeros([50, 50], dtype=np.uint16)
+        b = a.copy().astype(a.dtype)
+        c = a.copy().astype(a.dtype)
+        a[20:40, 10:20] = 1
+        b[20:35, 10:20] = 2
+        c[20:40, 10:20] = 5
+        c[0:10, 0:10] = 3
+
+        evaluator = Panoptica_Evaluator(
+            expected_input=InputType.SEMANTIC,
+            instance_approximator=ConnectedComponentsInstanceApproximator(),
+            instance_matcher=NaiveThresholdMatching(),
+        )
+
+        aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir)
+
+        aggregator.evaluate(b, a, "test")
+        aggregator.evaluate(a, c, "test2")
+
+        statistic_obj = Panoptica_Statistic.from_file(output_test_dir)
+
+        statistic_obj.print_summary()
+
+        self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "tp")).avg, 1.0)
+        self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "sq")).avg, 0.875)
+        self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "fn")).avg, 0.5)
+        self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "rec")).avg, 0.75)
+        self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "rec")).std, 0.25)
+
+        os.remove(str(output_test_dir))
+
+    def test_statistics_from_file(self):
+        statistic_obj = Panoptica_Statistic.from_file(input_test_file)
+        #
+        test2 = statistic_obj.get_one_subject("test2")  # get one subject
+        print()
+        print("test2", test2)
+        self.assertEqual(test2["ungrouped"]["num_ref_instances"], 2)
+
+        all_num_ref_instances = statistic_obj.get_across_groups("num_ref_instances")
+        print()
+        print("all_num_ref_instances", all_num_ref_instances)
+        self.assertEqual(len(all_num_ref_instances), 2)
+        self.assertEqual(sum(all_num_ref_instances), 3)
+
+        groupwise_summary = statistic_obj.get_summary_across_groups()
+        print()
+        print(groupwise_summary)
+        self.assertEqual(groupwise_summary["num_ref_instances"].avg, 1.5)
diff --git a/unit_tests/test_unittest_file.tsv b/unit_tests/test_unittest_file.tsv
new file mode 100644
index 0000000..68b2144
--- /dev/null
+++ b/unit_tests/test_unittest_file.tsv
@@ -0,0 +1,3 @@
+subject_name ungrouped-num_ref_instances ungrouped-num_pred_instances ungrouped-tp ungrouped-fp ungrouped-fn ungrouped-prec ungrouped-rec ungrouped-rq ungrouped-sq ungrouped-sq_std ungrouped-pq ungrouped-sq_dsc ungrouped-sq_dsc_std ungrouped-pq_dsc ungrouped-sq_assd ungrouped-sq_assd_std ungrouped-sq_rvd ungrouped-sq_rvd_std ungrouped-global_bin_dsc
+test 1 1 1 0 0 1.0 1.0 1.0 0.75 0.0 0.75 0.8571428571428571 0.0 0.8571428571428571 0.842391304347826 0.0 -0.25 0.0 0.8571428571428571
+test2 2 1 1 0 1 1.0 0.5 0.6666666666666666 1.0 0.0 0.6666666666666666 1.0 0.0 0.6666666666666666 0.0 0.0 0.0 0.0 0.8
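
Usage sketch (illustrative, not part of the patch): the snippet below shows how the statistics API introduced in this diff could be used after merging. The results path "results.tsv" is a placeholder; the group name "ungrouped" and the metric "sq" follow the test fixture above, and the figure call assumes the optional plotly dependency is installed.

from panoptica import Panoptica_Statistic, ValueSummary

# Load an aggregator output file (placeholder path; any Panoptica_Aggregator TSV works).
stat = Panoptica_Statistic.from_file("results.tsv")

# Per-group, per-metric summaries are now ValueSummary objects instead of (avg, std) tuples.
sq_summary = stat.get_summary("ungrouped", "sq")
print(sq_summary.avg, sq_summary.std, sq_summary.min, sq_summary.max)

# None entries (values missing in the file) can be filtered out explicitly.
sq_values = stat.get("ungrouped", "sq", remove_nones=True)
print(ValueSummary(sq_values).avg)

# Group-wise averages aggregated once more across all groups.
across = stat.get_summary_across_groups()
print(across["sq"].avg)

# Full summary dict, optionally including the "across_groups" pseudo-group.
summary = stat.get_summary_dict(include_across_group=True)

# Box/strip plot of one metric per group, with an explicit value range and title.
fig = stat.get_summary_figure(
    "sq",
    manual_metric_range=(0.0, 1.0),
    name_method="Structure",
    horizontal=True,
    title="Segmentation quality per group",
)
fig.show()  # requires the optional plotly dependency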