Skip to content

Commit

Permalink
Merge pull request #32 from BrainLesion/29-add-unit-tests-for-core-fu…
Browse files Browse the repository at this point in the history
…nctionality

29 add unit tests for core functionality
  • Loading branch information
MarcelRosier authored Aug 29, 2024
2 parents 6bd0c3c + 7f05c38 commit 10410ba
Show file tree
Hide file tree
Showing 16 changed files with 1,047 additions and 164 deletions.
6 changes: 3 additions & 3 deletions brats/core/brats_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from loguru import logger

from brats.core.docker import run_docker
from brats.core.docker import run_container
from brats.utils.algorithm_config import load_algorithms
from brats.utils.constants import OUTPUT_NAME_SCHEMA, Algorithms, Task
from brats.utils.data_handling import InferenceSetup
Expand Down Expand Up @@ -129,7 +129,7 @@ def _infer_single(
inputs=inputs,
)

run_docker(
run_container(
algorithm=self.algorithm,
data_path=tmp_data_folder,
output_path=tmp_output_folder,
Expand Down Expand Up @@ -172,7 +172,7 @@ def _infer_batch(
logger.info(f"Standardized input names to match algorithm requirements.")

# run inference in container
run_docker(
run_container(
algorithm=self.algorithm,
data_path=tmp_data_folder,
output_path=tmp_output_folder,
Expand Down
15 changes: 6 additions & 9 deletions brats/core/docker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import os
import shutil
import subprocess
import time
from pathlib import Path
Expand All @@ -19,7 +18,7 @@
AlgorithmNotCPUCompatibleException,
BraTSContainerException,
)
from brats.utils.zenodo import check_model_weights, get_dummy_weights_path
from brats.utils.zenodo import check_additional_files_path, get_dummy_path

try:
client = docker.from_env()
Expand All @@ -41,7 +40,7 @@ def _show_docker_pull_progress(tasks: Dict, progress: Progress, line: Dict):
if line["status"] == "Downloading":
task_key = f'[Download {line["id"]}]'
elif line["status"] == "Extracting":
task_key = f'[Extract {line["id"]}]'
task_key = f'[Extract {line["id"]}]'
else:
return

Expand Down Expand Up @@ -102,7 +101,6 @@ def _handle_device_requests(
if cuda_available
else "No Cuda installation/ GPU was found and"
)
# TODO add reference to table of cpu capable algos as help!
raise AlgorithmNotCPUCompatibleException(
f"{cause} the chosen algorithm is not CPU-compatible. Aborting..."
)
Expand All @@ -125,12 +123,11 @@ def _get_additional_files_path(algorithm: AlgorithmData) -> Path:
Path to the additional files
"""
# ensure weights are present and get path
# TODO refactor this rename weights to additional files
if algorithm.weights is not None:
return check_model_weights(record_id=algorithm.weights.record_id)
return check_additional_files_path(record_id=algorithm.weights.record_id)
else:
# if no weights are directly specified a dummy weights folder will be mounted
return get_dummy_weights_path()
return get_dummy_path()


def _get_volume_mappings(
Expand Down Expand Up @@ -195,7 +192,7 @@ def _build_args(
command_args, extra_args (Tuple): The command arguments and extra arguments
"""
# Build command that will be run in the docker container
command_args = f"--data_path=/mlcube_io0 --output_path=/mlcube_io2"
command_args = f"--data_path=/mlcube_io0 --output_path=/mlcube_io2"
if algorithm.weights is not None:
weights_arg = f"--{algorithm.weights.param_name}=/mlcube_io1"
if algorithm.weights.checkpoint_path:
Expand Down Expand Up @@ -285,7 +282,7 @@ def _log_algorithm_info(algorithm: AlgorithmData):
logger.debug(f"Docker image: {algorithm.run_args.docker_image}")


def run_docker(
def run_container(
algorithm: AlgorithmData,
data_path: Path,
output_path: Path,
Expand Down
2 changes: 1 addition & 1 deletion brats/data/meta/africa.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ algorithms:
BraTS23_1:
meta:
authors: Andriy Myronenko, et al.
paper: TODO
paper: N/A
challenge: BraTS23 BraTS-Africa Segmentation
rank: 1st
year: 2023
Expand Down
8 changes: 4 additions & 4 deletions brats/utils/data_handling.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

from contextlib import contextmanager
import shutil
import sys
from pathlib import Path
import tempfile
from typing import Dict, Generator, List, Optional, Tuple
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, Tuple

import nibabel as nib
from loguru import logger
Expand Down Expand Up @@ -90,6 +89,7 @@ def input_sanity_check(
t2w (Path | str, optional): T2w image path (required for segmentation)
mask (Path | str, optional): Mask image path (required for inpainting)
"""

# Filter out None values to only include provided images
images = {
"t1n": t1n,
Expand Down
45 changes: 26 additions & 19 deletions brats/utils/zenodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,20 @@
from brats.utils.constants import ADDITIONAL_FILES_FOLDER, ZENODO_RECORD_BASE_URL


def get_dummy_weights_path() -> Path:
def get_dummy_path() -> Path:
dummy = ADDITIONAL_FILES_FOLDER / "dummy"
dummy.mkdir(exist_ok=True, parents=True)
return dummy


def check_model_weights(record_id: str) -> Path:
"""Check if latest model weights are present and download them otherwise.
def check_additional_files_path(record_id: str) -> Path:
"""Check if latest additional files are present and download them otherwise.
Args:
record_id (str): Zenodo record ID.
Returns:
Path: Path to the model weights folder.
Path: Path to the additional files folder.
"""

zenodo_metadata, archive_url = _get_zenodo_metadata_and_archive_url(
Expand All @@ -44,7 +47,7 @@ def check_model_weights(record_id: str) -> Path:
sys.exit()
logger.info(f"Model weights not found locally")

return _download_model_weights(
return _download_additional_files(
zenodo_metadata=zenodo_metadata,
record_id=record_id,
archive_url=archive_url,
Expand Down Expand Up @@ -75,7 +78,7 @@ def check_model_weights(record_id: str) -> Path:
f"Failed to delete {path}: {excinfo}"
),
)
return _download_model_weights(
return _download_additional_files(
zenodo_metadata=zenodo_metadata, record_id=record_id, archive_url=archive_url
)

Expand Down Expand Up @@ -123,24 +126,24 @@ def _get_zenodo_metadata_and_archive_url(record_id: str) -> Dict | None:
return None


def _download_model_weights(
def _download_additional_files(
zenodo_metadata: Dict, record_id: str, archive_url: str
) -> Path:
"""Download the latest model weights from Zenodo for the requested record and extract them to the target folder.
"""Download the latest additional files from Zenodo for the requested record and extract them to the target folder.
Args:
ADDITIONAL_FILES_FOLDER (Path): General weights folder path in which the requested model weights will be stored.
zenodo_metadata (Dict): Metadata for the Zenodo record.
record_id (str): Zenodo record ID.
archive_url (str): URL to the archive file.
Returns:
Path: Path to the model weights folder for the requested record.
Path: Path to the additional files folder for the requested record.
"""
record_ADDITIONAL_FILES_FOLDER = (
record_folder = (
ADDITIONAL_FILES_FOLDER / f"{record_id}_v{zenodo_metadata['version']}"
)
# ensure folder exists
record_ADDITIONAL_FILES_FOLDER.mkdir(parents=True, exist_ok=True)
record_folder.mkdir(parents=True, exist_ok=True)

logger.info(f"Downloading model weights from Zenodo. This might take a while...")
# Make a GET request to the URL
Expand All @@ -152,14 +155,21 @@ def _download_model_weights(
)
return

_extract_archive(response=response, record_folder=record_folder)

logger.info(f"Zip file extracted successfully to {record_folder}")
return record_folder


def _extract_archive(response: requests.Response, record_folder: Path):
# Download with progress bar
chunk_size = 1024 # 1KB
bytes_io = BytesIO()

with Progress(
SpinnerColumn(),
TextColumn("[cyan]Downloading weights..."),
TextColumn("{task.completed:.2f} MB"),
TextColumn("[cyan]{task.completed:.2f} MB"),
transient=True,
) as progress:
task = progress.add_task("", total=None) # Indeterminate progress
Expand All @@ -172,10 +182,10 @@ def _download_model_weights(

# Extract the downloaded zip file to the target folder
with zipfile.ZipFile(bytes_io) as zip_ref:
zip_ref.extractall(record_ADDITIONAL_FILES_FOLDER)
zip_ref.extractall(record_folder)

# check if the extracted file is still a zip
for f in record_ADDITIONAL_FILES_FOLDER.iterdir():
for f in record_folder.iterdir():
if f.is_file() and f.suffix == ".zip":
with zipfile.ZipFile(f) as zip_ref:
files = zip_ref.namelist()
Expand All @@ -185,10 +195,7 @@ def _download_model_weights(
)
# Iterate over the files and extract them
for i, file in enumerate(files):
zip_ref.extract(file, record_ADDITIONAL_FILES_FOLDER)
zip_ref.extract(file, record_folder)
# Update the progress bar
progress.update(task, completed=i + 1)
f.unlink() # remove zip after extraction

logger.info(f"Zip file extracted successfully to {record_ADDITIONAL_FILES_FOLDER}")
return record_ADDITIONAL_FILES_FOLDER
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ nibabel = ">=5.0.0"

[tool.poetry.dev-dependencies]
pytest = ">=8.0.0"
pytest-cov = ">=5.0.0"

[tool.poetry.group.docs]
optional = true
Expand Down
Empty file added tests/core/__init__.py
Empty file.
109 changes: 109 additions & 0 deletions tests/core/test_brats_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import unittest
from unittest.mock import MagicMock, patch
from pathlib import Path
import tempfile
import shutil

from brats import AdultGliomaSegmenter
from brats.utils.constants import OUTPUT_NAME_SCHEMA


class TestBraTSAlgorithm(unittest.TestCase):

def setUp(self):
# Create a temporary directory for testing
self.test_dir = Path(tempfile.mkdtemp())
self.data_folder = self.test_dir / "data"
self.data_folder.mkdir(parents=True, exist_ok=True)
self.output_folder = self.test_dir / "output"
self.output_folder.mkdir(parents=True, exist_ok=True)

self.subject_A_folder = self.data_folder / "A"
self.subject_A_folder.mkdir(parents=True, exist_ok=True)
# Create mock file paths
self.input_files = {
"t1c": self.subject_A_folder / "A-t1c.nii.gz",
"t1n": self.subject_A_folder / "A-t1n.nii.gz",
"t2f": self.subject_A_folder / "A-t2f.nii.gz",
"t2w": self.subject_A_folder / "A-t2w.nii.gz",
}
for file in self.input_files.values():
file.touch()

# the core inference method is the same for all segmentation and inpainting algorithms, we use AdultGliomaSegmenter as an example during testing
self.segmenter = AdultGliomaSegmenter()

def tearDown(self):
# Remove the temporary directory after the test
shutil.rmtree(self.test_dir)

@patch("brats.core.brats_algorithm.run_container")
@patch("brats.core.segmentation_algorithms.input_sanity_check")
@patch("brats.core.brats_algorithm.InferenceSetup")
def test_infer_single(
self, mock_inference_setup, mock_input_sanity_check, mock_run_container
):

# Mock InferenceSetup context manager
mock_inference_setup_ret = mock_inference_setup.return_value
mock_inference_setup_ret.__enter__.return_value = (
self.data_folder,
self.output_folder,
)

def create_output_file(*args, **kwargs):
subject_id = self.segmenter.algorithm.run_args.input_name_schema.format(
id=0
)
alg_output_file = self.output_folder / OUTPUT_NAME_SCHEMA[
self.segmenter.task
].format(subject_id=subject_id)
alg_output_file.touch()

mock_run_container.side_effect = create_output_file

output_file = self.output_folder / "output.nii.gz"
self.segmenter.infer_single(
t1c=self.input_files["t1c"],
t1n=self.input_files["t1n"],
t2f=self.input_files["t2f"],
t2w=self.input_files["t2w"],
output_file=output_file,
)
mock_input_sanity_check.assert_called_once()
mock_run_container.assert_called_once()

self.assertTrue(output_file.exists())

@patch("brats.core.brats_algorithm.run_container")
@patch("brats.core.segmentation_algorithms.input_sanity_check")
@patch("brats.core.brats_algorithm.InferenceSetup")
def test_infer_batch(
self, mock_inference_setup, mock_input_sanity_check, mock_run_container
):

# Mock InferenceSetup context manager
mock_inference_setup_ret = mock_inference_setup.return_value
mock_inference_setup_ret.__enter__.return_value = (
self.data_folder,
self.output_folder,
)

def create_output_file(*args, **kwargs):
subject_id = self.segmenter.algorithm.run_args.input_name_schema.format(
id=0
)
alg_output_file = self.output_folder / OUTPUT_NAME_SCHEMA[
self.segmenter.task
].format(subject_id=subject_id)
alg_output_file.touch()

mock_run_container.side_effect = create_output_file

self.segmenter.infer_batch(
data_folder=self.data_folder, output_folder=self.output_folder
)
mock_input_sanity_check.assert_called_once()
mock_run_container.assert_called_once()
output_file = self.output_folder / "A.nii.gz"
self.assertTrue(output_file.exists())
Loading

0 comments on commit 10410ba

Please sign in to comment.