Merge pull request #155 from BrainLesion/panoptica_statistics_ugrade
Panoptica Statistics Beta 2
Hendrik-code authored Nov 25, 2024
2 parents ef36e5a + 2d1ffc2 commit 17e53dc
Showing 5 changed files with 255 additions and 111 deletions.
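
In short: Panoptica_Statistic now returns ValueSummary objects (average, std, min, max) instead of (avg, std) tuples, tolerates NaN/inf entries read from file by storing them as None, and gains an across-groups summary plus axis-range control for figures. A minimal usage sketch of the new surface, assuming a statistics file previously written by a Panoptica_Aggregator (the file path and names are illustrative):

    from panoptica import Panoptica_Statistic, ValueSummary

    stats = Panoptica_Statistic.from_file("panoptica_results.tsv")  # illustrative path
    summary = stats.get_summary("ungrouped", "sq")  # now a ValueSummary, not a tuple
    print(summary.avg, summary.std, summary.min, summary.max)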
2 changes: 1 addition & 1 deletion panoptica/__init__.py
@@ -3,7 +3,7 @@
CCABackend,
)
from panoptica.instance_matcher import NaiveThresholdMatching
from panoptica.panoptica_statistics import Panoptica_Statistic
from panoptica.panoptica_statistics import Panoptica_Statistic, ValueSummary
from panoptica.panoptica_aggregator import Panoptica_Aggregator
from panoptica.panoptica_evaluator import Panoptica_Evaluator
from panoptica.panoptica_result import PanopticaResult
182 changes: 139 additions & 43 deletions panoptica/panoptica_statistics.py
@@ -13,6 +13,35 @@
print("OPTIONAL PACKAGE MISSING")


class ValueSummary:
def __init__(self, value_list: list[float]) -> None:
self.__value_list = value_list
self.__avg = float(np.average(value_list))
self.__std = float(np.std(value_list))
self.__min = min(value_list)
self.__max = max(value_list)

@property
def values(self) -> list[float]:
return self.__value_list

@property
def avg(self) -> float:
return self.__avg

@property
def std(self) -> float:
return self.__std

@property
def min(self) -> float:
return self.__min

@property
def max(self) -> float:
return self.__max
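
For orientation, ValueSummary simply precomputes basic statistics over a list of values; a quick sketch with illustrative numbers:

    vs = ValueSummary([0.5, 0.75, 1.0])
    vs.avg            # 0.75 (np.average)
    vs.std            # ~0.204 (population std, np.std default)
    vs.min, vs.max    # (0.5, 1.0)

Note the input is not filtered: a NaN in the list would propagate into avg and std, so callers are expected to strip None/NaN beforehand (see get(..., remove_nones=True) below).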


class Panoptica_Statistic:

def __init__(
@@ -26,6 +55,20 @@ def __init__(
self.__groupnames = list(value_dict.keys())
self.__metricnames = list(value_dict[self.__groupnames[0]].keys())

# assert length of everything
for g in self.groupnames:
assert len(self.metricnames) == len(
list(value_dict[g].keys())
), f"Group {g}, has inconsistent number of metrics, got {len(list(value_dict[g].keys()))} but expected {len(self.metricnames)}"
for m in self.metricnames:
assert len(self.get(g, m)) == len(
self.subjectnames
), f"Group {g}, m {m} has not right subjects, got {len(self.get(g, m))}, expected {len(self.subjectnames)}"

@property
def subjectnames(self):
return self.__subj_names

@property
def groupnames(self):
return self.__groupnames
@@ -76,8 +119,12 @@ def from_file(cls, file: str):

if len(value) > 0:
value = float(value)
if not np.isnan(value) and value != np.inf:
if value is not None and not np.isnan(value) and value != np.inf:
value_dict[group_name][metric_name].append(float(value))
else:
value_dict[group_name][metric_name].append(None)
else:
value_dict[group_name][metric_name].append(None)

return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict)

@@ -96,7 +143,7 @@ def _assertsubject(self, subjectname):
subjectname in self.__subj_names
), f"subject {subjectname} not in list of subjects, got {self.__subj_names}"

def get(self, group, metric) -> list[float]:
def get(self, group, metric, remove_nones: bool = False) -> list[float]:
"""Returns the list of values for given group and metric
Args:
@@ -112,7 +159,9 @@ def get(self, group, metric) -> list[float]:
assert (
group in self.__value_dict and metric in self.__value_dict[group]
), f"Values not found for group {group} and metric {metric} evem though they should!"
return self.__value_dict[group][metric]
if not remove_nones:
return self.__value_dict[group][metric]
return [i for i in self.__value_dict[group][metric] if i is not None]
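
With the new remove_nones flag, entries that from_file stored as None (NaN or inf in the source file) can be dropped at read time; a short sketch, reusing the hypothetical stats object from above:

    raw = stats.get("ungrouped", "sq")                        # may contain None entries
    clean = stats.get("ungrouped", "sq", remove_nones=True)   # Nones filtered out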

def get_one_subject(self, subjectname: str):
"""Gets the values for ONE subject for each group and metric
@@ -130,7 +179,7 @@ def get_one_subject(self, subjectname: str):
for g in self.__groupnames
}

def get_across_groups(self, metric):
def get_across_groups(self, metric) -> list[float]:
"""Given metric, gives list of all values (even across groups!) Treat with care!
Args:
@@ -141,41 +190,62 @@
"""
values = []
for g in self.__groupnames:
values.append(self.get(g, metric))
values += self.get(g, metric)
return values

def get_summary_dict(self):
return {
def get_summary_across_groups(self) -> dict[str, ValueSummary]:
"""Calculates the average and std over all groups (so group-wise avg first, then average over those)
Returns:
dict[str, ValueSummary]: per metric, a ValueSummary over the group-wise averages
"""
summary_dict = {}
for m in self.__metricnames:
value_list = [self.get_summary(g, m).avg for g in self.__groupnames]
assert len(value_list) == len(self.__groupnames)
summary_dict[m] = ValueSummary(value_list)
return summary_dict

def get_summary_dict(
self, include_across_group: bool = True
) -> dict[str, dict[str, ValueSummary]]:
summary_dict = {
g: {m: self.get_summary(g, m) for m in self.__metricnames}
for g in self.__groupnames
}
if include_across_group:
summary_dict["across_groups"] = self.get_summary_across_groups()
return summary_dict
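
The summary dict is now nested group -> metric -> ValueSummary, with an extra "across_groups" key holding a macro average: each group's mean weighs equally, regardless of how many subjects the group has. A sketch with the hypothetical stats object:

    summary = stats.get_summary_dict()
    summary["ungrouped"]["sq"].avg       # average within one group
    summary["across_groups"]["sq"].avg   # mean of the group-wise averages
    summary["across_groups"]["sq"].std   # spread between groups, not subjects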

def get_summary(self, group, metric):
# TODO maybe more here? range, stuff like that
return self.avg_std(group, metric)
def get_summary(self, group, metric) -> ValueSummary:
values = self.get(group, metric, remove_nones=True)
return ValueSummary(values)

def avg_std(self, group, metric) -> tuple[float, float]:
values = self.get(group, metric)
avg = float(np.average(values))
std = float(np.std(values))
return (avg, std)

def print_summary(self, ndigits: int = 3):
summary = self.get_summary_dict()
def print_summary(
self,
ndigits: int = 3,
only_across_groups: bool = True,
):
summary = self.get_summary_dict(include_across_group=only_across_groups)
print()
for g in self.__groupnames:
groups = list(summary.keys())
if only_across_groups:
groups = ["across_groups"]
for g in groups:
print(f"Group {g}:")
for m in self.__metricnames:
avg, std = summary[g][m]
avg, std = summary[g][m].avg, summary[g][m].std
print(m, ":", round(avg, ndigits), "+-", round(std, ndigits))
print()
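
print_summary now defaults to the across-groups view; pass only_across_groups=False for the per-group breakdown. Illustrative call and output shape (numbers made up):

    stats.print_summary(ndigits=3, only_across_groups=False)
    # Group ungrouped:
    # sq : 0.875 +- 0.125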

def get_summary_figure(
self,
metric: str,
manual_metric_range: None | tuple[float, float] = None,
name_method: str = "Structure",
horizontal: bool = True,
sort: bool = True,
# optional title override
title: str = "",
):
"""Returns a figure object that shows the given metric for each group and its std
@@ -187,12 +257,25 @@ def get_summary_figure(
_type_: plotly figure showing the given metric per group
"""
orientation = "h" if horizontal else "v"
data_plot = {g: np.asarray(self.get(g, metric)) for g in self.__groupnames}
data_plot = {
g: np.asarray(self.get(g, metric, remove_nones=True))
for g in self.__groupnames
}
if manual_metric_range is not None:
assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range
change = (manual_metric_range[1] - manual_metric_range[0]) / 100
manual_metric_range = (
manual_metric_range[0] - change,
manual_metric_range[1] + change,
)
return plot_box(
data=data_plot,
orientation=orientation,
score=metric,
name_method=name_method,
name_metric=metric,
sort=sort,
figure_title=title,
manual_metric_range=manual_metric_range,
)
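
manual_metric_range fixes the plot axis; the given range is widened by 1% on each side so points sitting exactly on a limit remain visible. A sketch, again with the hypothetical stats object:

    fig = stats.get_summary_figure(
        "sq",
        manual_metric_range=(0.0, 1.0),  # padded internally to roughly (-0.01, 1.01)
        title="SQ per group",
    )
    fig.write_html("sq_boxplot.html")  # illustrative output path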

# groupwise or in total
@@ -247,7 +330,10 @@ def make_curve_over_setups(
plt.grid("major")
# Y values are average metric values in that group and metric
for idx, g in enumerate(groups):
Y = [stat.avg_std(g, metric)[0] for stat in statistics_dict.values()]
Y = [
ValueSummary(stat.get(g, metric, remove_nones=True)).avg
for stat in statistics_dict.values()
]

if plot_lines:
plt.plot(
@@ -274,56 +360,60 @@ def plot_box(
data: dict[str, np.ndarray],
sort=True,
orientation="h",
# graph_name: str = "Structure",
score: str = "Dice-Score",
name_method: str = "Structure",
name_metric: str = "Dice-Score",
figure_title: str = "",
width=850,
height=1200,
yaxis_title=None,
xaxis_title=None,
manual_metric_range: None | tuple[float, float] = None,
):
graph_name: str = "Structure"

if xaxis_title is None:
xaxis_title = score if orientation == "h" else graph_name
if yaxis_title is None:
yaxis_title = score if orientation != "h" else graph_name
xaxis_title = name_metric if orientation == "h" else name_method
yaxis_title = name_metric if orientation != "h" else name_method

data = {e.replace("_", " "): v for e, v in data.items()}
df_data = pd.DataFrame(
{
graph_name: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]),
score: np.concatenate([*data.values()], 0),
name_method: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]),
name_metric: np.concatenate([*data.values()], 0),
}
)
if sort:
df_by_spec_count = df_data.groupby(graph_name).mean()
df_by_spec_count = dict(df_by_spec_count[score].items())
df_data["mean"] = df_data[graph_name].apply(
df_by_spec_count = df_data.groupby(name_method).mean()
df_by_spec_count = dict(df_by_spec_count[name_metric].items())
df_data["mean"] = df_data[name_method].apply(
lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)
)
df_data = df_data.sort_values(by="mean")
if orientation == "v":
fig = px.strip(
df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation
df_data,
x=name_method,
y=name_metric,
stripmode="overlay",
orientation=orientation,
)
fig.update_traces(marker={"size": 5, "color": "#555555"})
for e in data.keys():
fig.add_trace(
go.Box(
y=df_data.query(f'{graph_name} == "{e}"')[score],
y=df_data.query(f'{name_method} == "{e}"')[name_metric],
name=e,
orientation=orientation,
)
)
else:
fig = px.strip(
df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation
df_data,
y=name_method,
x=name_metric,
stripmode="overlay",
orientation=orientation,
)
fig.update_traces(marker={"size": 5, "color": "#555555"})
for e in data.keys():
fig.add_trace(
go.Box(
x=df_data.query(f'{graph_name} == "{e}"')[score],
x=df_data.query(f'{name_method} == "{e}"')[name_metric],
name=e,
orientation=orientation,
boxpoints=False,
@@ -337,6 +427,12 @@
yaxis_title=yaxis_title,
xaxis_title=xaxis_title,
font={"family": "Arial"},
title=figure_title,
)
if manual_metric_range is not None:
if orientation == "h":
fig.update_xaxes(range=[manual_metric_range[0], manual_metric_range[1]])
else:
fig.update_yaxes(range=[manual_metric_range[0], manual_metric_range[1]])
fig.update_traces(orientation=orientation)
return fig
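
plot_box also works standalone; the former score argument (and the hardcoded graph_name local) are now the name_metric and name_method parameters. A minimal sketch with made-up data:

    import numpy as np

    fig = plot_box(
        data={"lesion": np.array([0.7, 0.8, 0.9]), "edema": np.array([0.6, 0.9, 0.7])},
        name_method="Structure",
        name_metric="Dice-Score",
        manual_metric_range=(0.0, 1.0),
    )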
67 changes: 0 additions & 67 deletions unit_tests/test_panoptic_aggregator.py
@@ -96,70 +96,3 @@ def test_simple_evaluation_then_statistic(self):
statistic_obj.print_summary()

os.remove(str(output_test_dir))


class Test_Panoptica_Statistics(unittest.TestCase):
def setUp(self) -> None:
os.environ["PANOPTICA_CITATION_REMINDER"] = "False"
return super().setUp()

def test_simple_statistic(self):
a = np.zeros([50, 50], dtype=np.uint16)
b = a.copy().astype(a.dtype)
a[20:40, 10:20] = 1
b[20:35, 10:20] = 2

evaluator = Panoptica_Evaluator(
expected_input=InputType.SEMANTIC,
instance_approximator=ConnectedComponentsInstanceApproximator(),
instance_matcher=NaiveThresholdMatching(),
)

aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir)

aggregator.evaluate(b, a, "test")

statistic_obj = Panoptica_Statistic.from_file(output_test_dir)

statistic_obj.print_summary()

self.assertEqual(statistic_obj.get("ungrouped", "tp"), [1.0])
self.assertEqual(statistic_obj.get("ungrouped", "sq"), [0.75])
self.assertEqual(statistic_obj.get("ungrouped", "sq_rvd"), [-0.25])

self.assertEqual(statistic_obj.avg_std("ungrouped", "tp")[0], 1.0)
self.assertEqual(statistic_obj.avg_std("ungrouped", "sq")[0], 0.75)

os.remove(str(output_test_dir))

def test_multiple_samples_statistic(self):
a = np.zeros([50, 50], dtype=np.uint16)
b = a.copy().astype(a.dtype)
c = a.copy().astype(a.dtype)
a[20:40, 10:20] = 1
b[20:35, 10:20] = 2
c[20:40, 10:20] = 5
c[0:10, 0:10] = 3

evaluator = Panoptica_Evaluator(
expected_input=InputType.SEMANTIC,
instance_approximator=ConnectedComponentsInstanceApproximator(),
instance_matcher=NaiveThresholdMatching(),
)

aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir)

aggregator.evaluate(b, a, "test")
aggregator.evaluate(a, c, "test2")

statistic_obj = Panoptica_Statistic.from_file(output_test_dir)

statistic_obj.print_summary()

self.assertEqual(statistic_obj.avg_std("ungrouped", "tp")[0], 1.0)
self.assertEqual(statistic_obj.avg_std("ungrouped", "sq")[0], 0.875)
self.assertEqual(statistic_obj.avg_std("ungrouped", "fn")[0], 0.5)
self.assertEqual(statistic_obj.avg_std("ungrouped", "rec")[0], 0.75)
self.assertEqual(statistic_obj.avg_std("ungrouped", "rec")[1], 0.25)

os.remove(str(output_test_dir))
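
The removed statistics tests exercised the now-deleted avg_std; under the new API the equivalent would be, as a migration sketch:

    # before: avg, std = statistic_obj.avg_std("ungrouped", "sq")
    summary = statistic_obj.get_summary("ungrouped", "sq")
    avg, std = summary.avg, summary.std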