From 80a508838216d9801bef0cf009753bb5b7b827c9 Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 25 Nov 2024 12:17:57 +0000 Subject: [PATCH 1/3] made panoptica_statistics more sophisticated, supports now values across groups, so averaging value in one group and then averaging over these averages. Added new unittests for the new features. Added ValueSummary to make clear what is meant with a summary --- panoptica/__init__.py | 2 +- panoptica/panoptica_statistics.py | 199 ++++++++++++++++--------- unit_tests/test_panoptic_aggregator.py | 66 -------- unit_tests/test_panoptic_statistics.py | 112 ++++++++++++++ 4 files changed, 240 insertions(+), 139 deletions(-) create mode 100644 unit_tests/test_panoptic_statistics.py diff --git a/panoptica/__init__.py b/panoptica/__init__.py index 4589be4..bb068d6 100644 --- a/panoptica/__init__.py +++ b/panoptica/__init__.py @@ -3,7 +3,7 @@ CCABackend, ) from panoptica.instance_matcher import NaiveThresholdMatching -from panoptica.panoptica_statistics import Panoptica_Statistic +from panoptica.panoptica_statistics import Panoptica_Statistic, ValueSummary from panoptica.panoptica_aggregator import Panoptica_Aggregator from panoptica.panoptica_evaluator import Panoptica_Evaluator from panoptica.panoptica_result import PanopticaResult diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index 96489b3..4b0261a 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -13,6 +13,35 @@ print("OPTIONAL PACKAGE MISSING") +class ValueSummary: + def __init__(self, value_list: list[float]) -> None: + self.__value_list = value_list + self.__avg = float(np.average(value_list)) + self.__std = float(np.std(value_list)) + self.__min = min(value_list) + self.__max = max(value_list) + + @property + def values(self) -> list[float]: + return self.__value_list + + @property + def avg(self) -> float: + return self.__avg + + @property + def std(self) -> float: + return self.__std + + @property + def min(self) -> float: + return self.__min + + @property + def max(self) -> float: + return self.__max + + class Panoptica_Statistic: def __init__( @@ -26,6 +55,20 @@ def __init__( self.__groupnames = list(value_dict.keys()) self.__metricnames = list(value_dict[self.__groupnames[0]].keys()) + # assert length of everything + for g in self.groupnames: + assert len(self.metricnames) == len( + list(value_dict[g].keys()) + ), f"Group {g}, has inconsistent number of metrics, got {len(list(value_dict[g].keys()))} but expected {len(self.metricnames)}" + for m in self.metricnames: + assert len(self.get(g, m)) == len( + self.subjectnames + ), f"Group {g}, m {m} has not right subjects, got {len(self.get(g, m))}, expected {len(self.subjectnames)}" + + @property + def subjectnames(self): + return self.__subj_names + @property def groupnames(self): return self.__groupnames @@ -43,9 +86,7 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert ( - header[0] == "subject_name" - ), "First column is not subject_names, something wrong with the file?" + assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -76,27 +117,25 @@ def from_file(cls, file: str): if len(value) > 0: value = float(value) - if not np.isnan(value) and value != np.inf: + if value is not None and not np.isnan(value) and value != np.inf: value_dict[group_name][metric_name].append(float(value)) + else: + value_dict[group_name][metric_name].append(None) + else: + value_dict[group_name][metric_name].append(None) return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert ( - group in self.__groupnames - ), f"group {group} not existent, only got groups {self.__groupnames}" + assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert ( - metric in self.__metricnames - ), f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert ( - subjectname in self.__subj_names - ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" - def get(self, group, metric) -> list[float]: + def get(self, group, metric, remove_nones: bool = False) -> list[float]: """Returns the list of values for given group and metric Args: @@ -112,7 +151,9 @@ def get(self, group, metric) -> list[float]: assert ( group in self.__value_dict and metric in self.__value_dict[group] ), f"Values not found for group {group} and metric {metric} evem though they should!" - return self.__value_dict[group][metric] + if not remove_nones: + return self.__value_dict[group][metric] + return [i for i in self.__value_dict[group][metric] if i is not None] def get_one_subject(self, subjectname: str): """Gets the values for ONE subject for each group and metric @@ -125,12 +166,9 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return { - g: {m: self.get(g, m)[sidx] for m in self.__metricnames} - for g in self.__groupnames - } + return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} - def get_across_groups(self, metric): + def get_across_groups(self, metric) -> list[float]: """Given metric, gives list of all values (even across groups!) Treat with care! Args: @@ -141,41 +179,57 @@ def get_across_groups(self, metric): """ values = [] for g in self.__groupnames: - values.append(self.get(g, metric)) + values += self.get(g, metric) return values - def get_summary_dict(self): - return { - g: {m: self.get_summary(g, m) for m in self.__metricnames} - for g in self.__groupnames - } - - def get_summary(self, group, metric): - # TODO maybe more here? range, stuff like that - return self.avg_std(group, metric) - - def avg_std(self, group, metric) -> tuple[float, float]: - values = self.get(group, metric) - avg = float(np.average(values)) - std = float(np.std(values)) - return (avg, std) + def get_summary_across_groups(self) -> dict[str, ValueSummary]: + """Calculates the average and std over all groups (so group-wise avg first, then average over those) - def print_summary(self, ndigits: int = 3): - summary = self.get_summary_dict() + Returns: + dict[str, tuple[float, float]]: _description_ + """ + summary_dict = {} + for m in self.__metricnames: + value_list = [self.get_summary(g, m).avg for g in self.__groupnames] + assert len(value_list) == len(self.__groupnames) + summary_dict[m] = ValueSummary(value_list) + return summary_dict + + def get_summary_dict(self, include_across_group: bool = True) -> dict[str, dict[str, ValueSummary]]: + summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + if include_across_group: + summary_dict["across_groups"] = self.get_summary_across_groups() + return summary_dict + + def get_summary(self, group, metric) -> ValueSummary: + values = self.get(group, metric, remove_nones=True) + return ValueSummary(values) + + def print_summary( + self, + ndigits: int = 3, + only_across_groups: bool = True, + ): + summary = self.get_summary_dict(include_across_group=only_across_groups) print() - for g in self.__groupnames: + groups = list(summary.keys()) + if only_across_groups: + groups = ["across_groups"] + for g in groups: print(f"Group {g}:") for m in self.__metricnames: - avg, std = summary[g][m] + avg, std = summary[g][m].avg, summary[g][m].std print(m, ":", round(avg, ndigits), "+-", round(std, ndigits)) print() def get_summary_figure( self, metric: str, + manual_metric_range: None | tuple[float, float] = None, + name_method: str = "Structure", horizontal: bool = True, sort: bool = True, - # title overwrite? + title: str = "", ): """Returns a figure object that shows the given metric for each group and its std @@ -187,12 +241,19 @@ def get_summary_figure( _type_: _description_ """ orientation = "h" if horizontal else "v" - data_plot = {g: np.asarray(self.get(g, metric)) for g in self.__groupnames} + data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames} + if manual_metric_range is not None: + assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range + change = (manual_metric_range[1] - manual_metric_range[0]) / 100 + manual_metric_range = (manual_metric_range[0] - change, manual_metric_range[1] + change) return plot_box( data=data_plot, orientation=orientation, - score=metric, + name_method=name_method, + name_metric=metric, sort=sort, + figure_title=title, + manual_metric_range=manual_metric_range, ) # groupwise or in total @@ -220,9 +281,7 @@ def make_curve_over_setups( alternate_groupnames = [alternate_groupnames] # for setupname, stat in statistics_dict.items(): - assert ( - metric in stat.metricnames - ), f"metric {metric} not in statistic obj {setupname}" + assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -247,7 +306,7 @@ def make_curve_over_setups( plt.grid("major") # Y values are average metric values in that group and metric for idx, g in enumerate(groups): - Y = [stat.avg_std(g, metric)[0] for stat in statistics_dict.values()] + Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()] if plot_lines: plt.plot( @@ -274,56 +333,46 @@ def plot_box( data: dict[str, np.ndarray], sort=True, orientation="h", - # graph_name: str = "Structure", - score: str = "Dice-Score", + name_method: str = "Structure", + name_metric: str = "Dice-Score", + figure_title: str = "", width=850, height=1200, - yaxis_title=None, - xaxis_title=None, + manual_metric_range: None | tuple[float, float] = None, ): - graph_name: str = "Structure" - - if xaxis_title is None: - xaxis_title = score if orientation == "h" else graph_name - if yaxis_title is None: - yaxis_title = score if orientation != "h" else graph_name + xaxis_title = name_metric if orientation == "h" else name_method + yaxis_title = name_metric if orientation != "h" else name_method data = {e.replace("_", " "): v for e, v in data.items()} df_data = pd.DataFrame( { - graph_name: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]), - score: np.concatenate([*data.values()], 0), + name_method: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]), + name_metric: np.concatenate([*data.values()], 0), } ) if sort: - df_by_spec_count = df_data.groupby(graph_name).mean() - df_by_spec_count = dict(df_by_spec_count[score].items()) - df_data["mean"] = df_data[graph_name].apply( - lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) - ) + df_by_spec_count = df_data.groupby(name_method).mean() + df_by_spec_count = dict(df_by_spec_count[name_metric].items()) + df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) df_data = df_data.sort_values(by="mean") if orientation == "v": - fig = px.strip( - df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation - ) + fig = px.strip(df_data, x=name_method, y=name_metric, stripmode="overlay", orientation=orientation) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( go.Box( - y=df_data.query(f'{graph_name} == "{e}"')[score], + y=df_data.query(f'{name_method} == "{e}"')[name_metric], name=e, orientation=orientation, ) ) else: - fig = px.strip( - df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation - ) + fig = px.strip(df_data, y=name_method, x=name_metric, stripmode="overlay", orientation=orientation) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( go.Box( - x=df_data.query(f'{graph_name} == "{e}"')[score], + x=df_data.query(f'{name_method} == "{e}"')[name_metric], name=e, orientation=orientation, boxpoints=False, @@ -337,6 +386,12 @@ def plot_box( yaxis_title=yaxis_title, xaxis_title=xaxis_title, font={"family": "Arial"}, + title=figure_title, ) + if manual_metric_range is not None: + if orientation == "h": + fig.update_xaxes(range=[manual_metric_range[0], manual_metric_range[1]]) + else: + fig.update_yaxes(range=[manual_metric_range[0], manual_metric_range[1]]) fig.update_traces(orientation=orientation) return fig diff --git a/unit_tests/test_panoptic_aggregator.py b/unit_tests/test_panoptic_aggregator.py index 53b34b1..3e49376 100644 --- a/unit_tests/test_panoptic_aggregator.py +++ b/unit_tests/test_panoptic_aggregator.py @@ -97,69 +97,3 @@ def test_simple_evaluation_then_statistic(self): os.remove(str(output_test_dir)) - -class Test_Panoptica_Statistics(unittest.TestCase): - def setUp(self) -> None: - os.environ["PANOPTICA_CITATION_REMINDER"] = "False" - return super().setUp() - - def test_simple_statistic(self): - a = np.zeros([50, 50], dtype=np.uint16) - b = a.copy().astype(a.dtype) - a[20:40, 10:20] = 1 - b[20:35, 10:20] = 2 - - evaluator = Panoptica_Evaluator( - expected_input=InputType.SEMANTIC, - instance_approximator=ConnectedComponentsInstanceApproximator(), - instance_matcher=NaiveThresholdMatching(), - ) - - aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir) - - aggregator.evaluate(b, a, "test") - - statistic_obj = Panoptica_Statistic.from_file(output_test_dir) - - statistic_obj.print_summary() - - self.assertEqual(statistic_obj.get("ungrouped", "tp"), [1.0]) - self.assertEqual(statistic_obj.get("ungrouped", "sq"), [0.75]) - self.assertEqual(statistic_obj.get("ungrouped", "sq_rvd"), [-0.25]) - - self.assertEqual(statistic_obj.avg_std("ungrouped", "tp")[0], 1.0) - self.assertEqual(statistic_obj.avg_std("ungrouped", "sq")[0], 0.75) - - os.remove(str(output_test_dir)) - - def test_multiple_samples_statistic(self): - a = np.zeros([50, 50], dtype=np.uint16) - b = a.copy().astype(a.dtype) - c = a.copy().astype(a.dtype) - a[20:40, 10:20] = 1 - b[20:35, 10:20] = 2 - c[20:40, 10:20] = 5 - c[0:10, 0:10] = 3 - - evaluator = Panoptica_Evaluator( - expected_input=InputType.SEMANTIC, - instance_approximator=ConnectedComponentsInstanceApproximator(), - instance_matcher=NaiveThresholdMatching(), - ) - - aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir) - - aggregator.evaluate(b, a, "test") - aggregator.evaluate(a, c, "test2") - - statistic_obj = Panoptica_Statistic.from_file(output_test_dir) - - statistic_obj.print_summary() - - self.assertEqual(statistic_obj.avg_std("ungrouped", "tp")[0], 1.0) - self.assertEqual(statistic_obj.avg_std("ungrouped", "sq")[0], 0.875) - self.assertEqual(statistic_obj.avg_std("ungrouped", "fn")[0], 0.5) - self.assertEqual(statistic_obj.avg_std("ungrouped", "rec")[0], 0.75) - self.assertEqual(statistic_obj.avg_std("ungrouped", "rec")[1], 0.25) - - os.remove(str(output_test_dir)) diff --git a/unit_tests/test_panoptic_statistics.py b/unit_tests/test_panoptic_statistics.py new file mode 100644 index 0000000..3aebe71 --- /dev/null +++ b/unit_tests/test_panoptic_statistics.py @@ -0,0 +1,112 @@ +# Call 'python -m unittest' on this folder +# coverage run -m unittest +# coverage report +# coverage html +import os +import unittest + +import numpy as np + +from panoptica import InputType, Panoptica_Aggregator, Panoptica_Statistic, ValueSummary +from panoptica.instance_approximator import ConnectedComponentsInstanceApproximator +from panoptica.instance_matcher import MaximizeMergeMatching, NaiveThresholdMatching +from panoptica.metrics import Metric +from panoptica.panoptica_evaluator import Panoptica_Evaluator +from panoptica.panoptica_result import MetricCouldNotBeComputedException +from panoptica.utils.processing_pair import SemanticPair +from panoptica.utils.segmentation_class import SegmentationClassGroups +import sys +from pathlib import Path + + +output_test_dir = Path(__file__).parent.joinpath("unittest_tmp_file.tsv") + +input_test_file = Path(__file__).parent.joinpath("test_unittest_file.tsv") + + +class Test_Panoptica_Statistics(unittest.TestCase): + def setUp(self) -> None: + os.environ["PANOPTICA_CITATION_REMINDER"] = "False" + return super().setUp() + + def test_simple_statistic(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + + evaluator = Panoptica_Evaluator( + expected_input=InputType.SEMANTIC, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=NaiveThresholdMatching(), + ) + + aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir) + + aggregator.evaluate(b, a, "test") + + statistic_obj = Panoptica_Statistic.from_file(output_test_dir) + + statistic_obj.print_summary() + + self.assertEqual(statistic_obj.get("ungrouped", "tp"), [1.0]) + self.assertEqual(statistic_obj.get("ungrouped", "sq"), [0.75]) + self.assertEqual(statistic_obj.get("ungrouped", "sq_rvd"), [-0.25]) + + tp_values = statistic_obj.get("ungrouped", "tp") + sq_values = statistic_obj.get("ungrouped", "sq") + self.assertEqual(ValueSummary(tp_values).avg, 1.0) + self.assertEqual(ValueSummary(sq_values).avg, 0.75) + + os.remove(str(output_test_dir)) + + def test_multiple_samples_statistic(self): + a = np.zeros([50, 50], dtype=np.uint16) + b = a.copy().astype(a.dtype) + c = a.copy().astype(a.dtype) + a[20:40, 10:20] = 1 + b[20:35, 10:20] = 2 + c[20:40, 10:20] = 5 + c[0:10, 0:10] = 3 + + evaluator = Panoptica_Evaluator( + expected_input=InputType.SEMANTIC, + instance_approximator=ConnectedComponentsInstanceApproximator(), + instance_matcher=NaiveThresholdMatching(), + ) + + aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir) + + aggregator.evaluate(b, a, "test") + aggregator.evaluate(a, c, "test2") + + statistic_obj = Panoptica_Statistic.from_file(output_test_dir) + + statistic_obj.print_summary() + + self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "tp")).avg, 1.0) + self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "sq")).avg, 0.875) + self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "fn")).avg, 0.5) + self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "rec")).avg, 0.75) + self.assertEqual(ValueSummary(statistic_obj.get("ungrouped", "rec")).std, 0.25) + + os.remove(str(output_test_dir)) + + def test_statistics_from_file(self): + statistic_obj = Panoptica_Statistic.from_file(input_test_file) + # + test2 = statistic_obj.get_one_subject("test2") # get one subject + print() + print("test2", test2) + self.assertEqual(test2["ungrouped"]["num_ref_instances"], 2) + + all_num_ref_instances = statistic_obj.get_across_groups("num_ref_instances") + print() + print("all_num_ref_instances", all_num_ref_instances) + self.assertEqual(len(all_num_ref_instances), 2) + self.assertEqual(sum(all_num_ref_instances), 3) + + groupwise_summary = statistic_obj.get_summary_across_groups() + print() + print(groupwise_summary) + self.assertEqual(groupwise_summary["num_ref_instances"].avg, 1.5) From ac6b27d8dfdaa51bd66e859c8b869cb1a45bc246 Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:19:01 +0000 Subject: [PATCH 2/3] Autoformat with black --- panoptica/panoptica_statistics.py | 69 ++++++++++++++++++++------ unit_tests/test_panoptic_aggregator.py | 1 - 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index 4b0261a..7495054 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -86,7 +86,9 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" + assert ( + header[0] == "subject_name" + ), "First column is not subject_names, something wrong with the file?" keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -127,13 +129,19 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" + assert ( + group in self.__groupnames + ), f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert ( + metric in self.__metricnames + ), f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert ( + subjectname in self.__subj_names + ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric, remove_nones: bool = False) -> list[float]: """Returns the list of values for given group and metric @@ -166,7 +174,10 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get(g, m)[sidx] for m in self.__metricnames} + for g in self.__groupnames + } def get_across_groups(self, metric) -> list[float]: """Given metric, gives list of all values (even across groups!) Treat with care! @@ -195,8 +206,13 @@ def get_summary_across_groups(self) -> dict[str, ValueSummary]: summary_dict[m] = ValueSummary(value_list) return summary_dict - def get_summary_dict(self, include_across_group: bool = True) -> dict[str, dict[str, ValueSummary]]: - summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + def get_summary_dict( + self, include_across_group: bool = True + ) -> dict[str, dict[str, ValueSummary]]: + summary_dict = { + g: {m: self.get_summary(g, m) for m in self.__metricnames} + for g in self.__groupnames + } if include_across_group: summary_dict["across_groups"] = self.get_summary_across_groups() return summary_dict @@ -241,11 +257,17 @@ def get_summary_figure( _type_: _description_ """ orientation = "h" if horizontal else "v" - data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames} + data_plot = { + g: np.asarray(self.get(g, metric, remove_nones=True)) + for g in self.__groupnames + } if manual_metric_range is not None: assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range change = (manual_metric_range[1] - manual_metric_range[0]) / 100 - manual_metric_range = (manual_metric_range[0] - change, manual_metric_range[1] + change) + manual_metric_range = ( + manual_metric_range[0] - change, + manual_metric_range[1] + change, + ) return plot_box( data=data_plot, orientation=orientation, @@ -281,7 +303,9 @@ def make_curve_over_setups( alternate_groupnames = [alternate_groupnames] # for setupname, stat in statistics_dict.items(): - assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" + assert ( + metric in stat.metricnames + ), f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -306,7 +330,10 @@ def make_curve_over_setups( plt.grid("major") # Y values are average metric values in that group and metric for idx, g in enumerate(groups): - Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()] + Y = [ + ValueSummary(stat.get(g, metric, remove_nones=True)).avg + for stat in statistics_dict.values() + ] if plot_lines: plt.plot( @@ -353,10 +380,18 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(name_method).mean() df_by_spec_count = dict(df_by_spec_count[name_metric].items()) - df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) + df_data["mean"] = df_data[name_method].apply( + lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) + ) df_data = df_data.sort_values(by="mean") if orientation == "v": - fig = px.strip(df_data, x=name_method, y=name_metric, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, + x=name_method, + y=name_metric, + stripmode="overlay", + orientation=orientation, + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( @@ -367,7 +402,13 @@ def plot_box( ) ) else: - fig = px.strip(df_data, y=name_method, x=name_metric, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, + y=name_method, + x=name_metric, + stripmode="overlay", + orientation=orientation, + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( diff --git a/unit_tests/test_panoptic_aggregator.py b/unit_tests/test_panoptic_aggregator.py index 3e49376..cbba390 100644 --- a/unit_tests/test_panoptic_aggregator.py +++ b/unit_tests/test_panoptic_aggregator.py @@ -96,4 +96,3 @@ def test_simple_evaluation_then_statistic(self): statistic_obj.print_summary() os.remove(str(output_test_dir)) - From 2d1ffc27859aa4fd7a52c0df50b19286d05c319f Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 25 Nov 2024 12:22:07 +0000 Subject: [PATCH 3/3] added test file --- unit_tests/test_unittest_file.tsv | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 unit_tests/test_unittest_file.tsv diff --git a/unit_tests/test_unittest_file.tsv b/unit_tests/test_unittest_file.tsv new file mode 100644 index 0000000..68b2144 --- /dev/null +++ b/unit_tests/test_unittest_file.tsv @@ -0,0 +1,3 @@ +subject_name ungrouped-num_ref_instances ungrouped-num_pred_instances ungrouped-tp ungrouped-fp ungrouped-fn ungrouped-prec ungrouped-rec ungrouped-rq ungrouped-sq ungrouped-sq_std ungrouped-pq ungrouped-sq_dsc ungrouped-sq_dsc_std ungrouped-pq_dsc ungrouped-sq_assd ungrouped-sq_assd_std ungrouped-sq_rvd ungrouped-sq_rvd_std ungrouped-global_bin_dsc +test 1 1 1 0 0 1.0 1.0 1.0 0.75 0.0 0.75 0.8571428571428571 0.0 0.8571428571428571 0.842391304347826 0.0 -0.25 0.0 0.8571428571428571 +test2 2 1 1 0 1 1.0 0.5 0.6666666666666666 1.0 0.0 0.6666666666666666 1.0 0.0 0.6666666666666666 0.0 0.0 0.0 0.0 0.8