Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing registrator #44

Merged
merged 9 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions brainles_preprocessing/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def register(
registered_log = os.path.join(registration_dir, f"{moving_image_name}.log")

registrator.register(
fixed_image=fixed_image_path,
moving_image=self.current,
fixed_image_path=fixed_image_path,
moving_image_path=self.current,
transformed_image=registered,
matrix=registered_matrix,
log_file=registered_log,
matrix_path=registered_matrix,
log_file_path=registered_log,
)
self.current = registered
return registered_matrix
Expand Down Expand Up @@ -133,11 +133,11 @@ def transform(
transformed_log = os.path.join(registration_dir, f"{moving_image_name}.log")

registrator.transform(
fixed_image=fixed_image_path,
moving_image=self.current,
transformed_image=transformed,
matrix=transformation_matrix,
log_file=transformed_log,
fixed_image_path=fixed_image_path,
moving_image_path=self.current,
transformed_image_path=transformed,
matrix_path=transformation_matrix,
log_file_path=transformed_log,
)
self.current = transformed

Expand Down
60 changes: 30 additions & 30 deletions brainles_preprocessing/registration/niftyreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,25 @@ def __init__(

def register(
self,
fixed_image,
moving_image,
transformed_image,
matrix,
log_file,
fixed_image_path: str,
moving_image_path: str,
transformed_image_path: str,
matrix_path: str,
log_file_path: str,
):
"""
Register images using NiftyReg.

Args:
fixed_image (str): Path to the fixed image.
moving_image (str): Path to the moving image.
transformed_image (str): Path to the transformed image (output).
matrix (str): Path to the transformation matrix (output).
log_file (str): Path to the log file.
fixed_image_path (str): Path to the fixed image.
moving_image_path (str): Path to the moving image.
transformed_image_path (str): Path to the transformed image (output).
matrix_path (str): Path to the transformation matrix (output).
log_file_path (str): Path to the log file.
"""
runner = ScriptRunner(
script_path=self.registration_script,
log_path=log_file,
log_path=log_file_path,
)

niftyreg_executable = str(
Expand All @@ -69,10 +69,10 @@ def register(

input_params = [
turbopath(niftyreg_executable),
turbopath(fixed_image),
turbopath(moving_image),
turbopath(transformed_image),
turbopath(matrix),
turbopath(fixed_image_path),
turbopath(moving_image_path),
turbopath(transformed_image_path),
turbopath(matrix_path),
]

# Call the run method to execute the script and capture the output in the log file
Expand All @@ -85,25 +85,25 @@ def register(

def transform(
self,
fixed_image,
moving_image,
transformed_image,
matrix,
log_file,
fixed_image_path: str,
moving_image_path: str,
transformed_image_path: str,
matrix_path: str,
log_file_path: str,
):
"""
Apply a transformation using NiftyReg.

Args:
fixed_image (str): Path to the fixed image.
moving_image (str): Path to the moving image.
transformed_image (str): Path to the transformed image (output).
matrix (str): Path to the transformation matrix.
log_file (str): Path to the log file.
fixed_image_path (str): Path to the fixed image.
moving_image_path (str): Path to the moving image.
transformed_image_path (str): Path to the transformed image (output).
matrix_path (str): Path to the transformation matrix.
log_file_path (str): Path to the log file.
"""
runner = ScriptRunner(
script_path=self.transformation_script,
log_path=log_file,
log_path=log_file_path,
)

niftyreg_executable = str(
Expand All @@ -112,10 +112,10 @@ def transform(

input_params = [
turbopath(niftyreg_executable),
turbopath(fixed_image),
turbopath(moving_image),
turbopath(transformed_image),
turbopath(matrix),
turbopath(fixed_image_path),
turbopath(moving_image_path),
turbopath(transformed_image_path),
turbopath(matrix_path),
]

# Call the run method to execute the script and capture the output in the log file
Expand Down
57 changes: 42 additions & 15 deletions brainles_preprocessing/registration/registrator.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,56 @@
# TODO add typing and docs
from abc import ABC, abstractmethod
from typing import Any


class Registrator(ABC):
def __init__(self, backend):
self.backend = backend
# TODO probably the init here should be removed?
# def __init__(self, backend):
# self.backend = backend

@abstractmethod
def register(
self,
fixed_image,
moving_image,
transformed_image,
matrix,
log_file,
):
fixed_image_path: Any,
moving_image_path: Any,
transformed_image_path: Any,
matrix_path: Any,
log_file_path: str,
) -> None:
"""
Abstract method for registering images.

Args:
fixed_image_path (Any): The fixed image for registration.
moving_image_path (Any): The moving image to be registered.
transformed_image_path (Any): The resulting transformed image after registration.
matrix_path (Any): The transformation matrix applied during registration.
log_file_path (str): The path to the log file for recording registration details.

Returns:
None
"""
pass

@abstractmethod
def transform(
self,
fixed_image,
moving_image,
transformed_image,
matrix,
log_file,
):
fixed_image_path: Any,
moving_image_path: Any,
transformed_image_path: Any,
matrix: Any,
log_file: str,
) -> None:
"""
Abstract method for transforming images.

Args:
fixed_image_path (Any): The fixed image to be transformed.
moving_image_path (Any): The moving image to be transformed.
transformed_image_path (Any): The resulting transformed image.
matrix_path (Any): The transformation matrix applied during transformation.
log_file_path (str): The path to the log file for recording transformation details.

Returns:
None
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ class TestHDBetExtractor(unittest.TestCase):
def setUp(self):
test_data_dir = turbopath(__file__).parent + "/test_data"
input_dir = test_data_dir + "/input"
self.output_dir = test_data_dir + "/temp_output"
self.output_dir = test_data_dir + "/temp_output_hdbet"
os.makedirs(self.output_dir, exist_ok=True)

self.brain_extractor = HDBetExtractor()

self.input_image_path = input_dir + "/tcia_example_t1c.nii.gz"
self.input_brain_mask_path = input_dir + "/bet_tcia_example_t1c_mask.nii.gz"

self.masked_image_path = self.output_dir + "/bet_tcia_example_t1c.nii.gz"
self.brain_mask_path = self.output_dir + "/bet_tcia_example_t1c_mask.nii.gz"
self.masked_again_image_path = (
Expand All @@ -30,9 +32,8 @@ def tearDown(self):
# Clean up created files if they exist
shutil.rmtree(self.output_dir)


def test_extract_creates_output_files(self):
# we try to run the fastest possible skullstripping on GPU
# we try to run the fastest possible skullstripping on CPU
self.brain_extractor.extract(
input_image_path=self.input_image_path,
masked_image_path=self.masked_image_path,
Expand Down
56 changes: 56 additions & 0 deletions tests/test_niftyreg_registrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import shutil
import unittest

from auxiliary.turbopath import turbopath

from brainles_preprocessing.registration.niftyreg import NiftyRegRegistrator


class TestNiftyRegRegistrator(unittest.TestCase):
def setUp(self):
test_data_dir = turbopath(__file__).parent + "/test_data"
input_dir = test_data_dir + "/input"
self.output_dir = test_data_dir + "/temp_output_niftyreg"
os.makedirs(self.output_dir, exist_ok=True)

self.registrator = NiftyRegRegistrator()

self.fixed_image = input_dir + "/tcia_example_t1c.nii.gz"
self.moving_image = input_dir + "/bet_tcia_example_t1c_mask.nii.gz"

self.transformed_image = self.output_dir + "/transformed_image.nii.gz"
self.matrix = self.output_dir + "/matrix.txt"
self.log_file = self.output_dir + "/registration.log"

def tearDown(self):
# Clean up created files if they exist
shutil.rmtree(self.output_dir)

def test_register_creates_output_files(self):
# we try to run the fastest possible skullstripping on GPU
self.registrator.register(
fixed_image_path=self.fixed_image,
moving_image_path=self.moving_image,
transformed_image_path=self.transformed_image,
matrix_path=self.matrix,
log_file_path=self.log_file,
)

self.assertTrue(
os.path.exists(self.transformed_image),
"transformed file was not created.",
)

self.assertTrue(
os.path.exists(self.matrix),
"matrix file was not created.",
)

self.assertTrue(
os.path.exists(self.log_file),
"log file was not created.",
)


# TODO also test transform