Skip to content

Commit

Permalink
Increased panoptica statistics utility
Browse files Browse the repository at this point in the history
  • Loading branch information
Hendrik-code committed Dec 12, 2024
1 parent 17e53dc commit 03ddcc7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 103 deletions.
55 changes: 17 additions & 38 deletions panoptica/panoptica_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ def from_file(cls, file: str):
rows = [row for row in rd]

header = rows[0]
assert (
header[0] == "subject_name"
), "First column is not subject_names, something wrong with the file?"
assert header[0] == "subject_name", "First column is not subject_names, something wrong with the file?"

keys_in_order = list([tuple(c.split("-")) for c in header[1:]])
metric_names = []
Expand Down Expand Up @@ -129,19 +127,13 @@ def from_file(cls, file: str):
return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict)

def _assertgroup(self, group):
assert (
group in self.__groupnames
), f"group {group} not existent, only got groups {self.__groupnames}"
assert group in self.__groupnames, f"group {group} not existent, only got groups {self.__groupnames}"

def _assertmetric(self, metric):
assert (
metric in self.__metricnames
), f"metric {metric} not existent, only got metrics {self.__metricnames}"
assert metric in self.__metricnames, f"metric {metric} not existent, only got metrics {self.__metricnames}"

def _assertsubject(self, subjectname):
assert (
subjectname in self.__subj_names
), f"subject {subjectname} not in list of subjects, got {self.__subj_names}"
assert subjectname in self.__subj_names, f"subject {subjectname} not in list of subjects, got {self.__subj_names}"

def get(self, group, metric, remove_nones: bool = False) -> list[float]:
"""Returns the list of values for given group and metric
Expand Down Expand Up @@ -174,10 +166,7 @@ def get_one_subject(self, subjectname: str):
"""
self._assertsubject(subjectname)
sidx = self.__subj_names.index(subjectname)
return {
g: {m: self.get(g, m)[sidx] for m in self.__metricnames}
for g in self.__groupnames
}
return {g: {m: self.get(g, m)[sidx] for m in self.__metricnames} for g in self.__groupnames}

def get_across_groups(self, metric) -> list[float]:
"""Given metric, gives list of all values (even across groups!) Treat with care!
Expand Down Expand Up @@ -206,13 +195,8 @@ def get_summary_across_groups(self) -> dict[str, ValueSummary]:
summary_dict[m] = ValueSummary(value_list)
return summary_dict

def get_summary_dict(self, include_across_group: bool = True) -> "dict[str, dict[str, ValueSummary]]":
    """Build a nested ``{group: {metric: ValueSummary}}`` mapping over all groups/metrics.

    Args:
        include_across_group: If True, additionally add an ``"across_groups"``
            entry summarizing each metric pooled over all groups.

    Returns:
        Nested dictionary of value summaries per group and metric.
    """
    # ValueSummary is defined elsewhere in this module; annotation kept as a forward reference.
    summary_dict = {g: {m: self.get_summary(g, m) for m in self.__metricnames} for g in self.__groupnames}
    if include_across_group:
        summary_dict["across_groups"] = self.get_summary_across_groups()
    return summary_dict
Expand Down Expand Up @@ -257,10 +241,7 @@ def get_summary_figure(
_type_: _description_
"""
orientation = "h" if horizontal else "v"
data_plot = {
g: np.asarray(self.get(g, metric, remove_nones=True))
for g in self.__groupnames
}
data_plot = {g: np.asarray(self.get(g, metric, remove_nones=True)) for g in self.__groupnames}
if manual_metric_range is not None:
assert manual_metric_range[0] < manual_metric_range[1], manual_metric_range
change = (manual_metric_range[1] - manual_metric_range[0]) / 100
Expand Down Expand Up @@ -293,6 +274,7 @@ def make_curve_over_setups(
fig: None = None,
plot_dotsize: int | None = None,
plot_lines: bool = True,
plot_std: bool = False,
):
if groups is None:
groups = list(statistics_dict.values())[0].groupnames
Expand All @@ -303,9 +285,7 @@ def make_curve_over_setups(
alternate_groupnames = [alternate_groupnames]
#
for setupname, stat in statistics_dict.items():
assert (
metric in stat.metricnames
), f"metric {metric} not in statistic obj {setupname}"
assert metric in stat.metricnames, f"metric {metric} not in statistic obj {setupname}"

setupnames = list(statistics_dict.keys())
convert_x_to_digit = True
Expand All @@ -330,18 +310,19 @@ def make_curve_over_setups(
plt.grid("major")
# Y values are average metric values in that group and metric
for idx, g in enumerate(groups):
Y = [
ValueSummary(stat.get(g, metric, remove_nones=True)).avg
for stat in statistics_dict.values()
]
Y = [ValueSummary(stat.get(g, metric, remove_nones=True)).avg for stat in statistics_dict.values()]
Ystd = [ValueSummary(stat.get(g, metric, remove_nones=True)).std for stat in statistics_dict.values()]

if plot_lines:
plt.plot(
p = plt.plot(
X,
Y,
label=g if alternate_groupnames is None else alternate_groupnames[idx],
)

if plot_std:
plt.fill_between(X, np.subtract(Y, Ystd), np.add(Y, Ystd), alpha=0.25, edgecolor=p[-1].get_color())

if plot_dotsize is not None:
plt.scatter(X, Y, s=plot_dotsize)

Expand Down Expand Up @@ -380,9 +361,7 @@ def plot_box(
if sort:
df_by_spec_count = df_data.groupby(name_method).mean()
df_by_spec_count = dict(df_by_spec_count[name_metric].items())
df_data["mean"] = df_data[name_method].apply(
lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1)
)
df_data["mean"] = df_data[name_method].apply(lambda x: df_by_spec_count[x] * (1 if orientation == "h" else -1))
df_data = df_data.sort_values(by="mean")
if orientation == "v":
fig = px.strip(
Expand Down
84 changes: 19 additions & 65 deletions panoptica/utils/processing_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ class _ProcessingPair(ABC):
_pred_labels: tuple[int, ...]
n_dim: int

def __init__(
self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None
) -> None:
def __init__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None) -> None:
"""Initializes the processing pair with prediction and reference arrays.
Args:
Expand All @@ -48,12 +46,8 @@ def __init__(
self._reference_arr = reference_arr
self.dtype = dtype
self.n_dim = reference_arr.ndim
self._ref_labels: tuple[int, ...] = tuple(
_unique_without_zeros(reference_arr)
) # type:ignore
self._pred_labels: tuple[int, ...] = tuple(
_unique_without_zeros(prediction_arr)
) # type:ignore
self._ref_labels: tuple[int, ...] = tuple(_unique_without_zeros(reference_arr)) # type:ignore
self._pred_labels: tuple[int, ...] = tuple(_unique_without_zeros(prediction_arr)) # type:ignore
self.crop: tuple[slice, ...] = None
self.is_cropped: bool = False
self.uncropped_shape: tuple[int, ...] = reference_arr.shape
Expand All @@ -75,13 +69,7 @@ def crop_data(self, verbose: bool = False):

self._prediction_arr = self._prediction_arr[self.crop]
self._reference_arr = self._reference_arr[self.crop]
(
print(
f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}"
)
if verbose
else None
)
(print(f"-- Cropped from {self.uncropped_shape} to {self._prediction_arr.shape}") if verbose else None)
self.is_cropped = True

def uncrop_data(self, verbose: bool = False):
    """Restore prediction/reference arrays to their original (uncropped) shape.

    Zero-fills arrays of ``self.uncropped_shape`` and writes the cropped
    content back into the stored ``self.crop`` region. No-op if this pair
    is not currently cropped.

    Args:
        verbose: If True, print the shape change.
    """
    if self.is_cropped == False:
        return
    assert self.uncropped_shape is not None, "Calling uncrop_data() without having cropped first"
    prediction_arr = np.zeros(self.uncropped_shape)
    prediction_arr[self.crop] = self._prediction_arr
    self._prediction_arr = prediction_arr

    reference_arr = np.zeros(self.uncropped_shape)
    reference_arr[self.crop] = self._reference_arr
    # NOTE: printed before reassignment, so this shows the cropped (old) reference shape.
    if verbose:
        print(f"-- Uncropped from {self._reference_arr.shape} to {self.uncropped_shape}")
    self._reference_arr = reference_arr
    self.is_cropped = False

Expand All @@ -117,9 +97,7 @@ def set_dtype(self, type):
Args:
dtype (type): Expected integer type for the arrays.
"""
assert np.issubdtype(
type, int_type
), "set_dtype: tried to set dtype to something other than integers"
assert np.issubdtype(type, int_type), "set_dtype: tried to set dtype to something other than integers"
self._prediction_arr = self._prediction_arr.astype(type)
self._reference_arr = self._reference_arr.astype(type)

Expand Down Expand Up @@ -211,9 +189,7 @@ def copy(self):
) # type:ignore


def _check_array_integrity(
prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None
):
def _check_array_integrity(prediction_arr: np.ndarray, reference_arr: np.ndarray, dtype: type | None = None):
"""Validates integrity between two arrays, checking shape, dtype, and consistency with `dtype`.
Args:
Expand All @@ -234,12 +210,8 @@ def _check_array_integrity(
assert isinstance(prediction_arr, np.ndarray) and isinstance(
reference_arr, np.ndarray
), "prediction and/or reference are not numpy arrays"
assert (
prediction_arr.shape == reference_arr.shape
), f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}"
assert (
prediction_arr.dtype == reference_arr.dtype
), f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}"
assert prediction_arr.shape == reference_arr.shape, f"shape mismatch, got {prediction_arr.shape},{reference_arr.shape}"
# assert prediction_arr.dtype == reference_arr.dtype, f"dtype mismatch, got {prediction_arr.dtype},{reference_arr.dtype}"
if dtype is not None:
assert (
np.issubdtype(prediction_arr.dtype, dtype)
Expand Down Expand Up @@ -331,15 +303,11 @@ def __init__(
self.matched_instances = matched_instances

if missed_reference_labels is None:
missed_reference_labels = list(
[i for i in self._ref_labels if i not in self._pred_labels]
)
missed_reference_labels = list([i for i in self._ref_labels if i not in self._pred_labels])
self.missed_reference_labels = missed_reference_labels

if missed_prediction_labels is None:
missed_prediction_labels = list(
[i for i in self._pred_labels if i not in self._ref_labels]
)
missed_prediction_labels = list([i for i in self._pred_labels if i not in self._ref_labels])
self.missed_prediction_labels = missed_prediction_labels

@property
Expand Down Expand Up @@ -412,9 +380,7 @@ class InputType(_Enum_Compare):
UNMATCHED_INSTANCE = UnmatchedInstancePair
MATCHED_INSTANCE = MatchedInstancePair

def __call__(
self, prediction_arr: np.ndarray, reference_arr: np.ndarray
) -> _ProcessingPair:
def __call__(self, prediction_arr: np.ndarray, reference_arr: np.ndarray) -> _ProcessingPair:
return self.value(prediction_arr, reference_arr)


Expand All @@ -432,9 +398,7 @@ def __init__(self, original_input: _ProcessingPair | None):
self._original_input = original_input
self._intermediatesteps: dict[str, _ProcessingPair] = {}

def add_intermediate_arr_data(self, processing_pair: "_ProcessingPair", inputtype: "InputType"):
    """Store a processing pair as an intermediate step, keyed by the input type's name.

    Args:
        processing_pair: The processing pair to record.
        inputtype: InputType member whose ``.name`` becomes the storage key.
    """
    type_name = inputtype.name
    self.add_intermediate_data(type_name, processing_pair)

Expand All @@ -444,36 +408,26 @@ def add_intermediate_data(self, key, value):

@property
def original_prediction_arr(self):
    """Prediction array of the original (pre-processing) input.

    Raises:
        AssertionError: If no original input was stored (no intermediate steps).
    """
    assert self._original_input is not None, "Original prediction_arr is None, there are no intermediate steps"
    return self._original_input.prediction_arr

@property
def original_reference_arr(self):
    """Reference array of the original (pre-processing) input.

    Raises:
        AssertionError: If no original input was stored (no intermediate steps).
    """
    assert self._original_input is not None, "Original reference_arr is None, there are no intermediate steps"
    return self._original_input.reference_arr

def prediction_arr(self, inputtype: "InputType"):
    """Return the prediction array stored for the given intermediate input type.

    Args:
        inputtype: InputType member whose ``.name`` keys the stored step.

    Raises:
        AssertionError: If the stored step is not a processing pair.
    """
    type_name = inputtype.name
    procpair = self[type_name]
    assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error"
    return procpair.prediction_arr

def reference_arr(self, inputtype: "InputType"):
    """Return the reference array stored for the given intermediate input type.

    Args:
        inputtype: InputType member whose ``.name`` keys the stored step.

    Raises:
        AssertionError: If the stored step is not a processing pair.
    """
    type_name = inputtype.name
    procpair = self[type_name]
    assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error"
    return procpair.reference_arr

def __getitem__(self, key):
assert (
key in self._intermediatesteps
), f"key {key} not in intermediate steps, maybe the step was skipped?"
assert key in self._intermediatesteps, f"key {key} not in intermediate steps, maybe the step was skipped?"
return self._intermediatesteps[key]

0 comments on commit 03ddcc7

Please sign in to comment.