Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Panoptica Statistics Beta 2 #155

Merged
merged 3 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion panoptica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
CCABackend,
)
from panoptica.instance_matcher import NaiveThresholdMatching
from panoptica.panoptica_statistics import Panoptica_Statistic
from panoptica.panoptica_statistics import Panoptica_Statistic, ValueSummary
from panoptica.panoptica_aggregator import Panoptica_Aggregator
from panoptica.panoptica_evaluator import Panoptica_Evaluator
from panoptica.panoptica_result import PanopticaResult
Expand Down
182 changes: 139 additions & 43 deletions panoptica/panoptica_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,35 @@
print("OPTIONAL PACKAGE MISSING")


class ValueSummary:
    """Eagerly-computed summary statistics (average, std, min, max) over a list of values.

    The statistics are computed once at construction time and exposed through
    read-only properties; the original value list is kept as-is.
    """

    def __init__(self, value_list: list[float]) -> None:
        # Keep a reference to the raw values; statistics below are computed
        # in the same order as before so failure modes (e.g. empty input
        # raising from np.average before min/max) are unchanged.
        self.__value_list = value_list
        self.__avg, self.__std = float(np.average(value_list)), float(np.std(value_list))
        self.__min, self.__max = min(value_list), max(value_list)

    @property
    def values(self) -> list[float]:
        """The raw list of values this summary was built from."""
        return self.__value_list

    @property
    def avg(self) -> float:
        """Arithmetic mean of the values."""
        return self.__avg

    @property
    def std(self) -> float:
        """Population standard deviation (np.std default, ddof=0)."""
        return self.__std

    @property
    def min(self) -> float:
        """Smallest value."""
        return self.__min

    @property
    def max(self) -> float:
        """Largest value."""
        return self.__max


class Panoptica_Statistic:

def __init__(
Expand All @@ -26,6 +55,20 @@ def __init__(
self.__groupnames = list(value_dict.keys())
self.__metricnames = list(value_dict[self.__groupnames[0]].keys())

# assert length of everything
for g in self.groupnames:
assert len(self.metricnames) == len(
list(value_dict[g].keys())
), f"Group {g}, has inconsistent number of metrics, got {len(list(value_dict[g].keys()))} but expected {len(self.metricnames)}"
for m in self.metricnames:
assert len(self.get(g, m)) == len(
self.subjectnames
), f"Group {g}, m {m} has not right subjects, got {len(self.get(g, m))}, expected {len(self.subjectnames)}"

@property
def subjectnames(self):
    """List of subject names, in the order they were read from the statistics file."""
    return self.__subj_names

@property
def groupnames(self):
    """List of group names found in the statistics (keys of the value dict)."""
    return self.__groupnames
Expand Down Expand Up @@ -76,8 +119,12 @@ def from_file(cls, file: str):

if len(value) > 0:
value = float(value)
if not np.isnan(value) and value != np.inf:
if value is not None and not np.isnan(value) and value != np.inf:
value_dict[group_name][metric_name].append(float(value))
else:
value_dict[group_name][metric_name].append(None)
else:
value_dict[group_name][metric_name].append(None)

return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict)

Expand All @@ -96,7 +143,7 @@ def _assertsubject(self, subjectname):
subjectname in self.__subj_names
), f"subject {subjectname} not in list of subjects, got {self.__subj_names}"

def get(self, group, metric) -> list[float]:
def get(self, group, metric, remove_nones: bool = False) -> list[float]:
"""Returns the list of values for given group and metric

Args:
Expand All @@ -112,7 +159,9 @@ def get(self, group, metric) -> list[float]:
assert (
group in self.__value_dict and metric in self.__value_dict[group]
), f"Values not found for group {group} and metric {metric} evem though they should!"
return self.__value_dict[group][metric]
if not remove_nones:
return self.__value_dict[group][metric]
return [i for i in self.__value_dict[group][metric] if i is not None]

def get_one_subject(self, subjectname: str):
"""Gets the values for ONE subject for each group and metric
Expand All @@ -130,7 +179,7 @@ def get_one_subject(self, subjectname: str):
for g in self.__groupnames
}

def get_across_groups(self, metric):
def get_across_groups(self, metric) -> list[float]:
"""Given metric, gives list of all values (even across groups!) Treat with care!

Args:
Expand All @@ -141,41 +190,62 @@ def get_across_groups(self, metric):
"""
values = []
for g in self.__groupnames:
values.append(self.get(g, metric))
values += self.get(g, metric)
return values

def get_summary_dict(self):
return {
def get_summary_across_groups(self) -> dict[str, ValueSummary]:
    """Calculates the average and std over all groups (so group-wise avg first, then average over those)

    Returns:
        dict[str, ValueSummary]: maps each metric name to a ValueSummary
            built from the per-group averages of that metric.
    """
    summary_dict = {}
    for m in self.__metricnames:
        # One avg per group; the ValueSummary then summarizes these averages.
        value_list = [self.get_summary(g, m).avg for g in self.__groupnames]
        assert len(value_list) == len(self.__groupnames)
        summary_dict[m] = ValueSummary(value_list)
    return summary_dict

def get_summary_dict(
    self, include_across_group: bool = True
) -> dict[str, dict[str, ValueSummary]]:
    """Builds a nested summary mapping: group name -> metric name -> ValueSummary.

    Args:
        include_across_group (bool): if True, an extra "across_groups" entry
            is added, summarizing the per-group averages of each metric.

    Returns:
        dict[str, dict[str, ValueSummary]]: per-group, per-metric summaries.
    """
    summary_dict = {}
    for g in self.__groupnames:
        group_summary = {}
        for m in self.__metricnames:
            group_summary[m] = self.get_summary(g, m)
        summary_dict[g] = group_summary
    if include_across_group:
        summary_dict["across_groups"] = self.get_summary_across_groups()
    return summary_dict

def get_summary(self, group, metric):
# TODO maybe more here? range, stuff like that
return self.avg_std(group, metric)
def get_summary(self, group, metric) -> ValueSummary:
    """Summarizes (avg/std/min/max) the non-None values of one metric within one group."""
    return ValueSummary(self.get(group, metric, remove_nones=True))

def avg_std(self, group, metric) -> tuple[float, float]:
    """Returns (average, std) for the given group and metric.

    Kept for backward compatibility; get_summary() exposes the same
    statistics (and more) as a ValueSummary.

    Args:
        group: group name.
        metric: metric name.

    Returns:
        tuple[float, float]: (average, standard deviation) over the
            non-None values of that metric in that group.
    """
    # remove_nones=True: from_file() stores None for missing/invalid entries,
    # and np.average / np.std would raise TypeError on a list containing None.
    values = self.get(group, metric, remove_nones=True)
    avg = float(np.average(values))
    std = float(np.std(values))
    return (avg, std)

def print_summary(self, ndigits: int = 3):
summary = self.get_summary_dict()
def print_summary(
    self,
    ndigits: int = 3,
    only_across_groups: bool = True,
):
    """Pretty-prints 'metric : avg +- std' lines for each group.

    Args:
        ndigits (int): number of digits to round avg/std to.
        only_across_groups (bool): if True, prints only the aggregated
            "across_groups" summary; otherwise prints every group
            (without the across-groups entry, since the summary dict is
            built with include_across_group=only_across_groups).
    """
    summary = self.get_summary_dict(include_across_group=only_across_groups)
    print()
    groups = list(summary.keys())
    if only_across_groups:
        groups = ["across_groups"]
    for g in groups:
        print(f"Group {g}:")
        for m in self.__metricnames:
            avg, std = summary[g][m].avg, summary[g][m].std
            print(m, ":", round(avg, ndigits), "+-", round(std, ndigits))
        print()

def get_summary_figure(
self,
metric: str,
manual_metric_range: None | tuple[float, float] = None,
name_method: str = "Structure",
horizontal: bool = True,
sort: bool = True,
# title overwrite?
title: str = "",
):
"""Returns a figure object that shows the given metric for each group and its std

Expand All @@ -187,12 +257,25 @@ def get_summary_figure(
_type_: _description_
"""
orientation = "h" if horizontal else "v"
data_plot = {g: np.asarray(self.get(g, metric)) for g in self.__groupnames}
data_plot = {
g: np.asarray(self.get(g, metric, remove_nones=True))
for g in self.__groupnames
}
if manual_metric_range is not None:
assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range
change = (manual_metric_range[1] - manual_metric_range[0]) / 100
manual_metric_range = (
manual_metric_range[0] - change,
manual_metric_range[1] + change,
)
return plot_box(
data=data_plot,
orientation=orientation,
score=metric,
name_method=name_method,
name_metric=metric,
sort=sort,
figure_title=title,
manual_metric_range=manual_metric_range,
)

# groupwise or in total
Expand Down Expand Up @@ -247,7 +330,10 @@ def make_curve_over_setups(
plt.grid("major")
# Y values are average metric values in that group and metric
for idx, g in enumerate(groups):
Y = [stat.avg_std(g, metric)[0] for stat in statistics_dict.values()]
Y = [
ValueSummary(stat.get(g, metric, remove_nones=True)).avg
for stat in statistics_dict.values()
]

if plot_lines:
plt.plot(
Expand All @@ -274,56 +360,60 @@ def plot_box(
data: dict[str, np.ndarray],
sort=True,
orientation="h",
# graph_name: str = "Structure",
score: str = "Dice-Score",
name_method: str = "Structure",
name_metric: str = "Dice-Score",
figure_title: str = "",
width=850,
height=1200,
yaxis_title=None,
xaxis_title=None,
manual_metric_range: None | tuple[float, float] = None,
):
graph_name: str = "Structure"

if xaxis_title is None:
xaxis_title = score if orientation == "h" else graph_name
if yaxis_title is None:
yaxis_title = score if orientation != "h" else graph_name
xaxis_title = name_metric if orientation == "h" else name_method
yaxis_title = name_metric if orientation != "h" else name_method

data = {e.replace("_", " "): v for e, v in data.items()}
df_data = pd.DataFrame(
{
graph_name: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]),
score: np.concatenate([*data.values()], 0),
name_method: _flatten_extend([([e] * len(y0)) for e, y0 in data.items()]),
name_metric: np.concatenate([*data.values()], 0),
}
)
if sort:
df_by_spec_count = df_data.groupby(graph_name).mean()
df_by_spec_count = dict(df_by_spec_count[score].items())
df_data["mean"] = df_data[graph_name].apply(
df_by_spec_count = df_data.groupby(name_method).mean()
df_by_spec_count = dict(df_by_spec_count[name_metric].items())
df_data["mean"] = df_data[name_method].apply(
lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)
)
df_data = df_data.sort_values(by="mean")
if orientation == "v":
fig = px.strip(
df_data, x=graph_name, y=score, stripmode="overlay", orientation=orientation
df_data,
x=name_method,
y=name_metric,
stripmode="overlay",
orientation=orientation,
)
fig.update_traces(marker={"size": 5, "color": "#555555"})
for e in data.keys():
fig.add_trace(
go.Box(
y=df_data.query(f'{graph_name} == "{e}"')[score],
y=df_data.query(f'{name_method} == "{e}"')[name_metric],
name=e,
orientation=orientation,
)
)
else:
fig = px.strip(
df_data, y=graph_name, x=score, stripmode="overlay", orientation=orientation
df_data,
y=name_method,
x=name_metric,
stripmode="overlay",
orientation=orientation,
)
fig.update_traces(marker={"size": 5, "color": "#555555"})
for e in data.keys():
fig.add_trace(
go.Box(
x=df_data.query(f'{graph_name} == "{e}"')[score],
x=df_data.query(f'{name_method} == "{e}"')[name_metric],
name=e,
orientation=orientation,
boxpoints=False,
Expand All @@ -337,6 +427,12 @@ def plot_box(
yaxis_title=yaxis_title,
xaxis_title=xaxis_title,
font={"family": "Arial"},
title=figure_title,
)
if manual_metric_range is not None:
if orientation == "h":
fig.update_xaxes(range=[manual_metric_range[0], manual_metric_range[1]])
else:
fig.update_yaxes(range=[manual_metric_range[0], manual_metric_range[1]])
fig.update_traces(orientation=orientation)
return fig
67 changes: 0 additions & 67 deletions unit_tests/test_panoptic_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,70 +96,3 @@ def test_simple_evaluation_then_statistic(self):
statistic_obj.print_summary()

os.remove(str(output_test_dir))


class Test_Panoptica_Statistics(unittest.TestCase):
    """End-to-end tests: evaluate masks, aggregate to a file, reload as Panoptica_Statistic."""

    def setUp(self) -> None:
        # Silence the citation reminder printed by the package during tests.
        os.environ["PANOPTICA_CITATION_REMINDER"] = "False"
        return super().setUp()

    def test_simple_statistic(self):
        # Two 50x50 semantic masks: prediction b covers 15/20 rows of reference a.
        a = np.zeros([50, 50], dtype=np.uint16)
        b = a.copy().astype(a.dtype)
        a[20:40, 10:20] = 1
        b[20:35, 10:20] = 2

        evaluator = Panoptica_Evaluator(
            expected_input=InputType.SEMANTIC,
            instance_approximator=ConnectedComponentsInstanceApproximator(),
            instance_matcher=NaiveThresholdMatching(),
        )

        # Aggregator writes per-subject results to output_test_dir.
        aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir)

        aggregator.evaluate(b, a, "test")

        # Round-trip: read the aggregated file back into a statistics object.
        statistic_obj = Panoptica_Statistic.from_file(output_test_dir)

        statistic_obj.print_summary()

        # Single subject: one matched instance (tp), sq = 15/20 overlap = 0.75.
        self.assertEqual(statistic_obj.get("ungrouped", "tp"), [1.0])
        self.assertEqual(statistic_obj.get("ungrouped", "sq"), [0.75])
        self.assertEqual(statistic_obj.get("ungrouped", "sq_rvd"), [-0.25])

        # avg_std returns (avg, std); with one subject avg equals the value.
        self.assertEqual(statistic_obj.avg_std("ungrouped", "tp")[0], 1.0)
        self.assertEqual(statistic_obj.avg_std("ungrouped", "sq")[0], 0.75)

        # Clean up the aggregator output file.
        os.remove(str(output_test_dir))

    def test_multiple_samples_statistic(self):
        # Three masks; c has an extra structure at [0:10, 0:10] that a misses.
        a = np.zeros([50, 50], dtype=np.uint16)
        b = a.copy().astype(a.dtype)
        c = a.copy().astype(a.dtype)
        a[20:40, 10:20] = 1
        b[20:35, 10:20] = 2
        c[20:40, 10:20] = 5
        c[0:10, 0:10] = 3

        evaluator = Panoptica_Evaluator(
            expected_input=InputType.SEMANTIC,
            instance_approximator=ConnectedComponentsInstanceApproximator(),
            instance_matcher=NaiveThresholdMatching(),
        )

        aggregator = Panoptica_Aggregator(evaluator, output_file=output_test_dir)

        # Two subjects aggregated into the same statistics file.
        aggregator.evaluate(b, a, "test")
        aggregator.evaluate(a, c, "test2")

        statistic_obj = Panoptica_Statistic.from_file(output_test_dir)

        statistic_obj.print_summary()

        # Averages across the two subjects; std checked for "rec" only.
        self.assertEqual(statistic_obj.avg_std("ungrouped", "tp")[0], 1.0)
        self.assertEqual(statistic_obj.avg_std("ungrouped", "sq")[0], 0.875)
        self.assertEqual(statistic_obj.avg_std("ungrouped", "fn")[0], 0.5)
        self.assertEqual(statistic_obj.avg_std("ungrouped", "rec")[0], 0.75)
        self.assertEqual(statistic_obj.avg_std("ungrouped", "rec")[1], 0.25)

        os.remove(str(output_test_dir))
Loading