From ac6b27d8dfdaa51bd66e859c8b869cb1a45bc246 Mon Sep 17 00:00:00 2001 From: "brainless-bot[bot]" <153751247+brainless-bot[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:19:01 +0000 Subject: [PATCH] Autoformat with black --- panoptica/panoptica_statistics.py | 69 ++++++++++++++++++++------ unit_tests/test_panoptic_aggregator.py | 1 - 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/panoptica/panoptica_statistics.py b/panoptica/panoptica_statistics.py index 4b0261a..7495054 100644 --- a/panoptica/panoptica_statistics.py +++ b/panoptica/panoptica_statistics.py @@ -86,7 +86,9 @@ def from_file(cls, file: str): rows = [row for row in rd] header = rows[0] - assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?" + assert ( + header[0] == "subject_name" + ), "First column is not subject_names, something wrong with the file?" keys_in_order = list([tuple(c.split("-")) for c in header[1:]]) metric_names = [] @@ -127,13 +129,19 @@ def from_file(cls, file: str): return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict) def _assertgroup(self, group): - assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}" + assert ( + group in self.__groupnames + ), f"group {group} not existent, only got groups {self.__groupnames}" def _assertmetric(self, metric): - assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}" + assert ( + metric in self.__metricnames + ), f"metric {metric} not existent, only got metrics {self.__metricnames}" def _assertsubject(self, subjectname): - assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}" + assert ( + subjectname in self.__subj_names + ), f"subject {subjectname} not in list of subjects, got {self.__subj_names}" def get(self, group, metric, remove_nones: bool = False) -> list[float]: """Returns the list of values for given group and metric @@ -166,7 +174,10 @@ def get_one_subject(self, subjectname: str): """ self._assertsubject(subjectname) sidx = self.__subj_names.index(subjectname) - return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames} + return { + g: {m: self.get(g, m)[sidx] for m in self.__metricnames} + for g in self.__groupnames + } def get_across_groups(self, metric) -> list[float]: """Given metric, gives list of all values (even across groups!) Treat with care! @@ -195,8 +206,13 @@ def get_summary_across_groups(self) -> dict[str, ValueSummary]: summary_dict[m] = ValueSummary(value_list) return summary_dict - def get_summary_dict(self, include_across_group: bool = True) -> dict[str, dict[str, ValueSummary]]: - summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames} + def get_summary_dict( + self, include_across_group: bool = True + ) -> dict[str, dict[str, ValueSummary]]: + summary_dict = { + g: {m: self.get_summary(g, m) for m in self.__metricnames} + for g in self.__groupnames + } if include_across_group: summary_dict["across_groups"] = self.get_summary_across_groups() return summary_dict @@ -241,11 +257,17 @@ def get_summary_figure( _type_: _description_ """ orientation = "h" if horizontal else "v" - data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames} + data_plot = { + g: np.asarray(self.get(g, metric, remove_nones=True)) + for g in self.__groupnames + } if manual_metric_range is not None: assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range change = (manual_metric_range[1] - manual_metric_range[0]) / 100 - manual_metric_range = (manual_metric_range[0] - change, manual_metric_range[1] + change) + manual_metric_range = ( + manual_metric_range[0] - change, + manual_metric_range[1] + change, + ) return plot_box( data=data_plot, orientation=orientation, @@ -281,7 +303,9 @@ def make_curve_over_setups( alternate_groupnames = [alternate_groupnames] # for setupname, stat in statistics_dict.items(): - assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}" + assert ( + metric in stat.metricnames + ), f"metric {metric} not in statistic obj {setupname}" setupnames = list(statistics_dict.keys()) convert_x_to_digit = True @@ -306,7 +330,10 @@ def make_curve_over_setups( plt.grid("major") # Y values are average metric values in that group and metric for idx, g in enumerate(groups): - Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()] + Y = [ + ValueSummary(stat.get(g, metric, remove_nones=True)).avg + for stat in statistics_dict.values() + ] if plot_lines: plt.plot( @@ -353,10 +380,18 @@ def plot_box( if sort: df_by_spec_count = df_data.groupby(name_method).mean() df_by_spec_count = dict(df_by_spec_count[name_metric].items()) - df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)) + df_data["mean"] = df_data[name_method].apply( + lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1) + ) df_data = df_data.sort_values(by="mean") if orientation == "v": - fig = px.strip(df_data, x=name_method, y=name_metric, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, + x=name_method, + y=name_metric, + stripmode="overlay", + orientation=orientation, + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( @@ -367,7 +402,13 @@ def plot_box( ) ) else: - fig = px.strip(df_data, y=name_method, x=name_metric, stripmode="overlay", orientation=orientation) + fig = px.strip( + df_data, + y=name_method, + x=name_metric, + stripmode="overlay", + orientation=orientation, + ) fig.update_traces(marker={"size": 5, "color": "#555555"}) for e in data.keys(): fig.add_trace( diff --git a/unit_tests/test_panoptic_aggregator.py b/unit_tests/test_panoptic_aggregator.py index 3e49376..cbba390 100644 --- a/unit_tests/test_panoptic_aggregator.py +++ b/unit_tests/test_panoptic_aggregator.py @@ -96,4 +96,3 @@ def test_simple_evaluation_then_statistic(self): statistic_obj.print_summary() os.remove(str(output_test_dir)) -