Skip to content

Commit

Permalink
Autoformat with black
Browse files Browse the repository at this point in the history
  • Loading branch information
brainless-bot[bot] committed Nov 25, 2024
1 parent 80a5088 commit ac6b27d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 15 deletions.
69 changes: 55 additions & 14 deletions panoptica/panoptica_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def from_file(cls, file: str):
rows = [row for row in rd]

header = rows[0]
assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?"
assert (
header[0] == "subject_name"
), "First column is not subject_names, something wrong with the file?"

keys_in_order = list([tuple(c.split("-")) for c in header[1:]])
metric_names = []
Expand Down Expand Up @@ -127,13 +129,19 @@ def from_file(cls, file: str):
return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict)

def _assertgroup(self, group):
assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}"
assert (
group in self.__groupnames
), f"group {group} not existent, only got groups {self.__groupnames}"

def _assertmetric(self, metric):
assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}"
assert (
metric in self.__metricnames
), f"metric {metric} not existent, only got metrics {self.__metricnames}"

def _assertsubject(self, subjectname):
assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}"
assert (
subjectname in self.__subj_names
), f"subject {subjectname} not in list of subjects, got {self.__subj_names}"

def get(self, group, metric, remove_nones: bool = False) -> list[float]:
"""Returns the list of values for given group and metric
Expand Down Expand Up @@ -166,7 +174,10 @@ def get_one_subject(self, subjectname: str):
"""
self._assertsubject(subjectname)
sidx = self.__subj_names.index(subjectname)
return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames}
return {
g: {m: self.get(g, m)[sidx] for m in self.__metricnames}
for g in self.__groupnames
}

def get_across_groups(self, metric) -> list[float]:
"""Given metric, gives list of all values (even across groups!) Treat with care!
Expand Down Expand Up @@ -195,8 +206,13 @@ def get_summary_across_groups(self) -> dict[str, ValueSummary]:
summary_dict[m] = ValueSummary(value_list)
return summary_dict

def get_summary_dict(self, include_across_group: bool = True) -> dict[str, dict[str, ValueSummary]]:
summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames}
def get_summary_dict(
self, include_across_group: bool = True
) -> dict[str, dict[str, ValueSummary]]:
summary_dict = {
g: {m: self.get_summary(g, m) for m in self.__metricnames}
for g in self.__groupnames
}
if include_across_group:
summary_dict["across_groups"] = self.get_summary_across_groups()
return summary_dict
Expand Down Expand Up @@ -241,11 +257,17 @@ def get_summary_figure(
_type_: _description_
"""
orientation = "h" if horizontal else "v"
data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames}
data_plot = {
g: np.asarray(self.get(g, metric, remove_nones=True))
for g in self.__groupnames
}
if manual_metric_range is not None:
assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range
change = (manual_metric_range[1] - manual_metric_range[0]) / 100
manual_metric_range = (manual_metric_range[0] - change, manual_metric_range[1] + change)
manual_metric_range = (
manual_metric_range[0] - change,
manual_metric_range[1] + change,
)
return plot_box(
data=data_plot,
orientation=orientation,
Expand Down Expand Up @@ -281,7 +303,9 @@ def make_curve_over_setups(
alternate_groupnames = [alternate_groupnames]
#
for setupname, stat in statistics_dict.items():
assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}"
assert (
metric in stat.metricnames
), f"metric {metric} not in statistic obj {setupname}"

setupnames = list(statistics_dict.keys())
convert_x_to_digit = True
Expand All @@ -306,7 +330,10 @@ def make_curve_over_setups(
plt.grid("major")
# Y values are average metric values in that group and metric
for idx, g in enumerate(groups):
Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()]
Y = [
ValueSummary(stat.get(g, metric, remove_nones=True)).avg
for stat in statistics_dict.values()
]

if plot_lines:
plt.plot(
Expand Down Expand Up @@ -353,10 +380,18 @@ def plot_box(
if sort:
df_by_spec_count = df_data.groupby(name_method).mean()
df_by_spec_count = dict(df_by_spec_count[name_metric].items())
df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1))
df_data["mean"] = df_data[name_method].apply(
lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)
)
df_data = df_data.sort_values(by="mean")
if orientation == "v":
fig = px.strip(df_data, x=name_method, y=name_metric, stripmode="overlay", orientation=orientation)
fig = px.strip(
df_data,
x=name_method,
y=name_metric,
stripmode="overlay",
orientation=orientation,
)
fig.update_traces(marker={"size": 5, "color": "#555555"})
for e in data.keys():
fig.add_trace(
Expand All @@ -367,7 +402,13 @@ def plot_box(
)
)
else:
fig = px.strip(df_data, y=name_method, x=name_metric, stripmode="overlay", orientation=orientation)
fig = px.strip(
df_data,
y=name_method,
x=name_metric,
stripmode="overlay",
orientation=orientation,
)
fig.update_traces(marker={"size": 5, "color": "#555555"})
for e in data.keys():
fig.add_trace(
Expand Down
1 change: 0 additions & 1 deletion unit_tests/test_panoptic_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,3 @@ def test_simple_evaluation_then_statistic(self):
statistic_obj.print_summary()

os.remove(str(output_test_dir))

0 comments on commit ac6b27d

Please sign in to comment.