Commit
Merge pull request #153 from BrainLesion/122
added at least GPT docstrings for everything
Hendrik-code authored Oct 25, 2024
2 parents e151b86 + dbbb54b commit ef36e5a
Showing 15 changed files with 792 additions and 72 deletions.
30 changes: 30 additions & 0 deletions panoptica/_functionals.py
@@ -141,6 +141,23 @@ def _get_paired_crop(
reference_arr: np.ndarray,
px_pad: int = 2,
):
"""Calculates a bounding box based on paired prediction and reference arrays.
This function combines the prediction and reference arrays, checks if they are identical,
and computes a bounding box around the non-zero regions. If both arrays are completely zero,
a small value is added to ensure the bounding box is valid.
Args:
prediction_arr (np.ndarray): The predicted segmentation array.
reference_arr (np.ndarray): The ground truth segmentation array.
px_pad (int, optional): Padding to apply around the bounding box. Defaults to 2.
Returns:
np.ndarray: The bounding box coordinates around the combined non-zero regions.
Raises:
AssertionError: If the prediction and reference arrays do not have the same shape.
"""
assert prediction_arr.shape == reference_arr.shape

combined = prediction_arr + reference_arr
@@ -150,6 +167,19 @@ def _get_paired_crop(


def _round_to_n(value: float | int, n_significant_digits: int = 2):
"""Rounds a number to a specified number of significant digits.
This function rounds the given value to the specified number of significant digits.
If the value is zero, it is returned unchanged.
Args:
value (float | int): The number to be rounded.
n_significant_digits (int, optional): The number of significant digits to round to.
Defaults to 2.
Returns:
float: The rounded value.
"""
return (
value
if value == 0
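Note: the diff truncates both function bodies above. A hedged, self-contained sketch of the cropping behavior the _get_paired_crop docstring describes follows; the helper name paired_crop_bbox and the plain-NumPy slicing are illustrative assumptions, not code from this commit.

import numpy as np

def paired_crop_bbox(prediction_arr: np.ndarray, reference_arr: np.ndarray, px_pad: int = 2):
    # Bounding box over the union of both masks, padded by px_pad per axis.
    assert prediction_arr.shape == reference_arr.shape
    combined = prediction_arr + reference_arr
    if not combined.any():
        combined = combined.copy()
        combined.flat[0] = 1  # both arrays all zero: force a valid, non-empty box
    slices = []
    for axis, idx in enumerate(np.nonzero(combined)):
        lo = max(int(idx.min()) - px_pad, 0)
        hi = min(int(idx.max()) + px_pad + 1, combined.shape[axis])
        slices.append(slice(lo, hi))
    return tuple(slices)

pred = np.zeros((8, 8), dtype=np.uint8)
pred[2:4, 2:4] = 1
ref = np.zeros_like(pred)
ref[3:5, 3:5] = 1
print(paired_crop_bbox(pred, ref))  # (slice(0, 7), slice(0, 7))

Likewise, the visible return in _round_to_n is cut off; a standard significant-digit rounding that matches the docstring (the log10 formulation is an assumption based on the visible expression) looks like:

import math

def round_to_n(value: float, n_significant_digits: int = 2) -> float:
    if value == 0:
        return value  # zero is returned unchanged, as documented
    shift = -int(math.floor(math.log10(abs(value)))) + (n_significant_digits - 1)
    return round(value, shift)

print(round_to_n(0.004367))   # 0.0044
print(round_to_n(1234.5, 3))  # 1230.0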
23 changes: 23 additions & 0 deletions panoptica/metrics/assd.py
@@ -107,6 +107,29 @@ def _distance_transform_edt(
return_distances=True,
return_indices=False,
):
"""Computes the Euclidean distance transform and/or feature transform of a binary array.
This function calculates the Euclidean distance transform (EDT) of a binary array,
which gives the distance from each non-zero point to the nearest zero point. It can
also return the feature transform, which provides indices to the nearest non-zero point.
Args:
input_array (np.ndarray): The input binary array where non-zero values are considered
foreground.
sampling (optional): A sequence or array that specifies the spacing along each dimension.
If provided, scales the distances by the sampling value along each axis.
return_distances (bool, optional): If True, returns the distance transform. Default is True.
return_indices (bool, optional): If True, returns the feature transform with indices to
the nearest foreground points. Default is False.
Returns:
np.ndarray or tuple[np.ndarray, ...]: If `return_distances` is True, returns the distance
transform as an array. If `return_indices` is True, returns the feature transform. If both
are True, returns a tuple with the distance and feature transforms.
Raises:
ValueError: If the input array is empty or has unsupported dimensions.
"""
# calculate the feature transform
# input = np.atleast_1d(np.where(input, 1, 0).astype(np.int8))
# if sampling is not None:
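The helper's signature mirrors scipy.ndimage.distance_transform_edt, so the documented semantics can be sanity-checked against SciPy's public function; the arrays below are made up for illustration.

import numpy as np
from scipy.ndimage import distance_transform_edt

mask = np.array([[0, 1, 1, 1],
                 [0, 0, 1, 1]], dtype=np.uint8)

# Distance from each foreground (non-zero) voxel to the nearest zero,
# with anisotropic spacing: 1.0 along rows, 2.0 along columns.
dist = distance_transform_edt(mask, sampling=(1.0, 2.0))

# Both outputs: distances plus the feature transform, i.e. for every
# voxel the index of the nearest background (zero) element.
dist, indices = distance_transform_edt(mask, return_indices=True)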
72 changes: 63 additions & 9 deletions panoptica/metrics/metrics.py
@@ -20,7 +20,25 @@

@dataclass
class _Metric:
"""A Metric class containing a name, whether higher or lower values is better, and a function to calculate that metric between two instances in an array"""
"""Represents a metric with a name, direction (increasing or decreasing), and a calculation function.
This class provides a framework for defining and calculating metrics, which can be used
to evaluate the similarity or performance between reference and prediction arrays.
The metric direction indicates whether higher or lower values are better.
Attributes:
name (str): Short name of the metric.
long_name (str): Full descriptive name of the metric.
decreasing (bool): If True, lower metric values are better; otherwise, higher values are preferred.
_metric_function (Callable): A callable function that computes the metric
between two input arrays.
Example:
>>> my_metric = _Metric(name="accuracy", long_name="Accuracy", decreasing=False, _metric_function=accuracy_function)
>>> score = my_metric(reference_array, prediction_array)
>>> print(score)
"""

name: str
long_name: str
@@ -36,6 +54,20 @@ def __call__(
*args,
**kwargs,
) -> int | float:
"""Calculates the metric between reference and prediction arrays.
Args:
reference_arr (np.ndarray): The reference array.
prediction_arr (np.ndarray): The prediction array.
ref_instance_idx (int, optional): The instance index to filter in the reference array.
pred_instance_idx (int | list[int], optional): Instance index or indices to filter in
the prediction array.
*args: Additional positional arguments for the metric function.
**kwargs: Additional keyword arguments for the metric function.
Returns:
int | float: The computed metric value.
"""
if ref_instance_idx is not None and pred_instance_idx is not None:
reference_arr = reference_arr.copy() == ref_instance_idx
if isinstance(pred_instance_idx, int):
@@ -60,15 +92,35 @@ def __repr__(self) -> str:
return str(self)

def __hash__(self) -> int:
"""Hash based on metric name, constrained to fit within 8 digits.
Returns:
int: The hash value of the metric.
"""
return abs(hash(self.name)) % (10**8)

@property
def increasing(self):
"""Indicates if higher values of the metric are better.
Returns:
bool: True if increasing values are preferred, otherwise False.
"""
return not self.decreasing

def score_beats_threshold(
self, matching_score: float, matching_threshold: float
) -> bool:
"""Determines if a matching score meets a specified threshold.
Args:
matching_score (float): The score to evaluate.
matching_threshold (float): The threshold value to compare against.
Returns:
bool: True if the score meets the threshold, taking into account the
metric's preferred direction.
"""
return (self.increasing and matching_score >= matching_threshold) or (
self.decreasing and matching_score <= matching_threshold
)
@@ -206,6 +258,16 @@ def __init__(self, *args: object) -> None:


class Evaluation_Metric:
"""This represents a metric in the evaluation derived from other metrics or list metrics (no circular dependancies!)
Args:
name_id (str): code-name of this metric, must be same as the member variable of PanopticResult
calc_func (Callable): the function to calculate this metric based on the PanopticResult object
long_name (str | None, optional): A longer descriptive name for printing/logging purposes. Defaults to None.
was_calculated (bool, optional): Whether this metric has been calculated or not. Defaults to False.
error (bool, optional): If true, means the metric could not have been calculated (because dependancies do not exist or have this flag set to True). Defaults to False.
"""

def __init__(
self,
name_id: str,
@@ -215,15 +277,7 @@ def __init__(
was_calculated: bool = False,
error: bool = False,
):
"""This represents a metric in the evaluation derived from other metrics or list metrics (no circular dependancies!)

Args:
name_id (str): code-name of this metric, must be same as the member variable of PanopticResult
calc_func (Callable): the function to calculate this metric based on the PanopticResult object
long_name (str | None, optional): A longer descriptive name for printing/logging purposes. Defaults to None.
was_calculated (bool, optional): Whether this metric has been calculated or not. Defaults to False.
error (bool, optional): If true, means the metric could not have been calculated (because dependancies do not exist or have this flag set to True). Defaults to False.
"""
self.id = name_id
self.metric_type = metric_type
self._calc_func = calc_func
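To make the documented direction handling concrete, here is a hedged sketch of how a Dice-style metric would plug into this framework; the dice function and sample arrays are invented for illustration, and only the final check mirrors the score_beats_threshold logic above.

import numpy as np

def dice(reference_arr: np.ndarray, prediction_arr: np.ndarray, **kwargs) -> float:
    # Dice coefficient: an increasing metric, higher values are better.
    intersection = np.logical_and(reference_arr, prediction_arr).sum()
    denom = reference_arr.sum() + prediction_arr.sum()
    return 2.0 * intersection / denom if denom > 0 else 0.0

reference = np.zeros((4, 4), dtype=np.uint8)
reference[1:3, 1:3] = 1  # instance with label 1
prediction = np.zeros_like(reference)
prediction[1:3, 1:4] = 1

# Instance filtering as in _Metric.__call__: compare binary masks of one label.
score = dice(reference == 1, prediction == 1)

# Direction-aware threshold check, as in score_beats_threshold: an increasing
# metric must be >= the threshold, a decreasing one (e.g. ASSD) must be <=.
decreasing = False
matched = (not decreasing and score >= 0.5) or (decreasing and score <= 0.5)
print(round(score, 2), matched)  # 0.8 True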
80 changes: 66 additions & 14 deletions panoptica/panoptica_aggregator.py
@@ -30,9 +30,11 @@

#
class Panoptica_Aggregator:
# internal_list_lock = Lock()
#
"""Aggregator that calls evaluations and saves the resulting metrics per sample. Can be used to create statistics, ..."""
"""Aggregator that manages evaluations and saves resulting metrics per sample.
This class interfaces with the `Panoptica_Evaluator` to perform evaluations,
store results, and manage file outputs for statistical analysis.
"""

def __init__(
self,
@@ -41,10 +43,18 @@ def __init__(
log_times: bool = False,
continue_file: bool = True,
):
"""
"""Initializes the Panoptica_Aggregator.
Args:
panoptica_evaluator (Panoptica_Evaluator): The Panoptica_Evaluator used for the pipeline.
output_file (Path | None, optional): If given, will stream the sample results into this file. If the file is existent, will append results if not already there. Defaults to None.
panoptica_evaluator (Panoptica_Evaluator): The evaluator used for performing evaluations.
output_file (Path | str): Path to the output file for storing results. If the file exists,
results will be appended. If it doesn't, a new file will be created.
log_times (bool, optional): If True, computation times will be logged. Defaults to False.
continue_file (bool, optional): If True, results will continue from existing entries in the file.
Defaults to True.
Raises:
AssertionError: If the output directory does not exist or if the file extension is not `.tsv`.
"""
self.__panoptica_evaluator = panoptica_evaluator
self.__class_group_names = panoptica_evaluator.segmentation_class_groups_names
@@ -115,10 +125,16 @@ def __init__(
atexit.register(self.__exist_handler)

def __exist_handler(self):
"""Handles cleanup upon program exit by removing the temporary output buffer file."""
if self.__output_buffer_file is not None and self.__output_buffer_file.exists():
os.remove(str(self.__output_buffer_file))

def make_statistic(self) -> Panoptica_Statistic:
"""Generates statistics from the aggregated evaluation results.
Returns:
Panoptica_Statistic: The statistics object containing the results.
"""
with filelock:
obj = Panoptica_Statistic.from_file(self.__output_file)
return obj
@@ -129,14 +145,16 @@ def evaluate(
reference_arr: np.ndarray,
subject_name: str,
):
"""Evaluates one case
"""Evaluates a single case using the provided prediction and reference arrays.
Args:
prediction_arr (np.ndarray): Prediction array
reference_arr (np.ndarray): reference array
subject_name (str | None, optional): Unique name of the sample. If none, will give it a name based on count. Defaults to None.
skip_already_existent (bool): If true, will skip subjects which were already evaluated instead of crashing. Defaults to False.
verbose (bool | None, optional): Verbose. Defaults to None.
prediction_arr (np.ndarray): The array containing the predicted segmentation.
reference_arr (np.ndarray): The array containing the ground truth segmentation.
subject_name (str): A unique name for the sample being evaluated. If none is provided,
a name will be generated based on the count.
Raises:
ValueError: If the subject name has already been evaluated or is in process.
"""
# Read tmp file to see which sample names are blocked
with inevalfilelock:
@@ -164,6 +182,12 @@ def evaluate(
self._save_one_subject(subject_name, res)

def _save_one_subject(self, subject_name, result_grouped):
"""Saves the evaluation results for a single subject.
Args:
subject_name (str): The name of the subject whose results are being saved.
result_grouped (dict): A dictionary of grouped results from the evaluation.
"""
with filelock:
#
content = [subject_name]
@@ -186,9 +210,19 @@ def panoptica_evaluator(self):


def _read_first_row(file: str | Path):
"""Reads the first row of a TSV file.
NOT THREAD SAFE BY ITSELF!
Args:
file (str | Path): The path to the file from which to read the first row.
Returns:
list: The first row of the file as a list of strings.
"""
if isinstance(file, Path):
file = str(file)
# NOT THREAD SAFE BY ITSELF!
#
with open(str(file), "r", encoding="utf8", newline="") as tsvfile:
rd = csv.reader(tsvfile, delimiter="\t", lineterminator="\n")

@@ -202,7 +236,19 @@


def _load_first_column_entries(file: str | Path):
# NOT THREAD SAFE BY ITSELF!
"""Loads the entries from the first column of a TSV file.
NOT THREAD SAFE BY ITSELF!
Args:
file (str | Path): The path to the file from which to load entries.
Returns:
list: A list of entries from the first column of the file.
Raises:
AssertionError: If the file contains duplicate entries.
"""
if isinstance(file, Path):
file = str(file)
with open(str(file), "r", encoding="utf8", newline="") as tsvfile:
@@ -221,6 +267,12 @@


def _write_content(file: str | Path, content: list[list[str]]):
"""Writes content to a TSV file.
Args:
file (str | Path): The path to the file where content will be written.
content (list[list[str]]): A list of lists containing the rows of data to write.
"""
if isinstance(file, Path):
file = str(file)
# NOT THREAD SAFE BY ITSELF!
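Putting the aggregator docstrings together, a hedged end-to-end usage sketch might look as follows; the import path, the default-constructed evaluator, and the file name are assumptions rather than code from this commit.

import numpy as np
from panoptica import Panoptica_Aggregator, Panoptica_Evaluator  # import path assumed

evaluator = Panoptica_Evaluator()  # project-specific configuration omitted

aggregator = Panoptica_Aggregator(
    panoptica_evaluator=evaluator,
    output_file="results.tsv",  # must be a .tsv in an existing directory
    log_times=True,
    continue_file=True,  # resume: keep rows already present in the file
)

prediction = np.random.randint(0, 3, size=(32, 32))
reference = np.random.randint(0, 3, size=(32, 32))

# Evaluates one case and streams its row into results.tsv; reusing a
# subject_name raises, as documented above.
aggregator.evaluate(prediction, reference, subject_name="case_001")

stats = aggregator.make_statistic()  # Panoptica_Statistic over all saved rows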
