From af24af89ff042c21a816574a673d8f83681c5750 Mon Sep 17 00:00:00 2001 From: Julien Marabotto Date: Thu, 11 Jul 2024 12:10:09 +0200 Subject: [PATCH] enh: Further additions to test_vis, output path for figures added to conftest --- .circleci/config.yml | 3 +- nitransforms/conftest.py | 8 +++ nitransforms/io/afni.py | 2 +- nitransforms/io/fsl.py | 2 +- nitransforms/io/itk.py | 2 +- nitransforms/tests/test_vis.py | 127 +++++++++++++++++++++++++++------ nitransforms/vis.py | 44 +++++++----- 7 files changed, 144 insertions(+), 44 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 47b0e00e..dbbf54db 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -142,7 +142,8 @@ jobs: -v /tmp/data/nitransforms-tests:/data -e TEST_DATA_HOME=/data \ -e COVERAGE_FILE=/tmp/summaries/.pytest.coverage \ -v /tmp/fslicense/license.txt:/opt/freesurfer/license.txt:ro \ - -v /tmp/tests:/tmp nitransforms:latest \ + -v /tmp/tests:/tmp -e TEST_OUTPUT_DIR=/tmp/artifacts \ + nitransforms:latest \ pytest --junit-xml=/tmp/summaries/pytest.xml \ --cov nitransforms --cov-report xml:/tmp/summaries/unittests.xml \ nitransforms/ diff --git a/nitransforms/conftest.py b/nitransforms/conftest.py index 854cac43..3bdab782 100644 --- a/nitransforms/conftest.py +++ b/nitransforms/conftest.py @@ -9,6 +9,8 @@ _data = None _brainmask = None _testdir = Path(os.getenv("TEST_DATA_HOME", "~/.nitransforms/testdata")).expanduser() +_outdir = os.getenv("TEST_OUTPUT_DIR", None) + _datadir = Path(__file__).parent / "tests" / "data" @@ -45,6 +47,12 @@ def testdata_path(): return _testdir +@pytest.fixture +def output_path(): + """Return an output folder.""" + return Path(_outdir) if _outdir is not None else None + + @pytest.fixture def get_testdata(): """Generate data in the requested orientation.""" diff --git a/nitransforms/io/afni.py b/nitransforms/io/afni.py index 7c66d434..fb27eda6 100644 --- a/nitransforms/io/afni.py +++ b/nitransforms/io/afni.py @@ -198,7 +198,7 @@ def from_image(cls, imgobj): hdr = imgobj.header.copy() shape = hdr.get_data_shape() - if len(shape) != 5 or shape[-2] != 1 or not shape[-1] in (2, 3): + if len(shape) != 5 or shape[-2] != 1 or shape[-1] not in (2, 3): raise TransformFileError( 'Displacements field "%s" does not come from AFNI.' % imgobj.file_map["image"].filename diff --git a/nitransforms/io/fsl.py b/nitransforms/io/fsl.py index f454227e..3425f5d0 100644 --- a/nitransforms/io/fsl.py +++ b/nitransforms/io/fsl.py @@ -180,7 +180,7 @@ def from_image(cls, imgobj): hdr = imgobj.header.copy() shape = hdr.get_data_shape() - if len(shape) != 4 or not shape[-1] in (2, 3): + if len(shape) != 4 or shape[-1] not in (2, 3): raise TransformFileError( 'Displacements field "%s" does not come from FSL.' % imgobj.file_map['image'].filename) diff --git a/nitransforms/io/itk.py b/nitransforms/io/itk.py index afabfd98..02fd9fe9 100644 --- a/nitransforms/io/itk.py +++ b/nitransforms/io/itk.py @@ -337,7 +337,7 @@ def from_image(cls, imgobj): hdr = imgobj.header.copy() shape = hdr.get_data_shape() - if len(shape) != 5 or shape[-2] != 1 or not shape[-1] in (2, 3): + if len(shape) != 5 or shape[-2] != 1 or shape[-1] not in (2, 3): raise TransformFileError( 'Displacements field "%s" does not come from ITK.' % imgobj.file_map["image"].filename diff --git a/nitransforms/tests/test_vis.py b/nitransforms/tests/test_vis.py index cd4cdb30..57565ada 100644 --- a/nitransforms/tests/test_vis.py +++ b/nitransforms/tests/test_vis.py @@ -1,37 +1,120 @@ +import numpy as np import matplotlib.pyplot as plt -import pytest, unittest -from pathlib import Path +import pytest +import nibabel as nb +from nitransforms.nonlinear import DenseFieldTransform from nitransforms.vis import PlotDenseField -test_dir = Path("tests/data/") -test_file = Path("ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz") -xfm = Path(test_dir/test_file) +def test_read_path(data_path): + """Check that filepaths are a supported method for loading and reading transforms with PlotDenseField""" + PlotDenseField(transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz") -def test_slice_values(xfm, xslice, yslice, zslice, is_deltas=True): + +def test_slice_values(): + """Check that ValueError is issued if negative slices are provided""" + with pytest.raises(ValueError): + PlotDenseField( + transform = np.zeros((10, 10, 10, 3)), + reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None), + ).test_slices( + xslice=-1, + yslice=-1, + zslice=-1, + ) + + """Check that IndexError is issued if provided slices are beyond range of transform dimensions""" + with pytest.raises(IndexError): + xfm = DenseFieldTransform( + field=np.zeros((10, 10, 10, 3)), + reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None), + ) + PlotDenseField( + transform=xfm._field, + reference=xfm._reference, + ).test_slices( + xslice=xfm._field.shape[0]+1, + yslice=xfm._field.shape[1]+1, + zslice=xfm._field.shape[2]+1, + ) + + +def test_show_transform(data_path, output_path): PlotDenseField( - path_to_file=Path(xfm), - is_deltas=is_deltas, - ).test_slices( - xslice=xslice, - yslice=yslice, - zslice=zslice, + transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).show_transform( + xslice=50, + yslice=50, + zslice=50, ) + if output_path is not None: + plt.savefig(output_path / "show_transform.svg", bbox_inches="tight") + else: + plt.show() + + +def test_plot_distortion(data_path, output_path): + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + PlotDenseField( + transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).plot_distortion( + axes=axes, + xslice=50, + yslice=50, + zslice=50, + ) + if output_path is not None: + plt.savefig(output_path / "show_transform.svg", bbox_inches="tight") + else: + plt.show() + -def test_show_transform(xfm, xslice=50, yslice=50, zslice=50, is_deltas=True): +def test_plot_quiverdsm(data_path, output_path): fig, axes = plt.subplots(1, 3, figsize=(12, 4)) PlotDenseField( - path_to_file=Path(xfm), - is_deltas=is_deltas, + transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" ).plot_quiverdsm( axes=axes, - xslice=xslice, - yslice=yslice, - zslice=zslice, + xslice=50, + yslice=50, + zslice=50, ) - plt.show() + if output_path is not None: + plt.savefig(output_path / "show_transform.svg", bbox_inches="tight") + else: + plt.show() + + +def test_3dquiver(data_path, output_path): + with pytest.raises(NotImplementedError): + fig = plt.figure() + axes = plt.subplots(projection='3d') + PlotDenseField( + transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).plot_quiverdsm(axes=axes, three_D=True) + + if output_path is not None: + plt.savefig(output_path / "show_transform.svg", bbox_inches="tight") + else: + plt.show() + + +def test_plot_jacobian(data_path, output_path): + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + PlotDenseField( + transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + ).plot_jacobian( + axes=axes, + xslice=50, + yslice=50, + zslice=50, + ) + if output_path is not None: + plt.savefig(output_path / "show_transform.svg", bbox_inches="tight") + else: + plt.show() + -test_slice_values(xfm, 50, -50, 50) #should raise ValueError -test_slice_values(xfm, 500, 50, 50) #should raise IndexError -test_show_transform(Path("tests/data/ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz")) +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/nitransforms/vis.py b/nitransforms/vis.py index 3beb5793..357bfe55 100644 --- a/nitransforms/vis.py +++ b/nitransforms/vis.py @@ -9,7 +9,6 @@ from itertools import product -from nitransforms.base import TransformError from nitransforms.nonlinear import DenseFieldTransform class PlotDenseField(): @@ -27,26 +26,35 @@ class PlotDenseField(): Parameters ---------- - path_to_file: :obj:`str` + transform: :obj:`str` Path from which the trasnformation file should be read. is_deltas: :obj:`bool` Whether the field is a displacement field or a deformations field. Default = True """ - __slots__ = ('_path_to_file', '_xfm', '_voxel_size') + __slots__ = ('_transform', '_xfm', '_voxel_size') - def __init__(self, path_to_file, is_deltas=True): - self._path_to_file = path_to_file + def __init__(self, transform, is_deltas=True, reference=None): + self._transform = transform self._xfm = DenseFieldTransform( - self._path_to_file, + field=self._transform, is_deltas=is_deltas, + reference=reference ) - self._voxel_size = nb.load(path_to_file).header.get_zooms() + try: + """if field provided by path""" + self._voxel_size = nb.load(transform).header.get_zooms() + except: + """if field provided by numpy array (eg tests)""" + deltas = [] + for i in range(self._xfm.ndim): + deltas.append((np.max(self._xfm._field[i]) - np.min(self._xfm._field[i])) / len(self._xfm._field[i])) + + assert np.all(deltas == deltas[0]) + assert np.mean(deltas) == deltas[0] + + deltas.append(0) + self._voxel_size = deltas - if self._xfm._field.shape[-1] != self._xfm.ndim: - raise TransformError( - "The number of components of the field (%d) does not match " - "the number of dimensions (%d)" % (self._xfm._field.shape[-1], self._xfm.ndim) - ) def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, show_grid=True, lw=0.1, save_to_path=None): """ @@ -73,7 +81,7 @@ def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, sho >>> plt.show() >>> PlotDenseField( - ... path_to_file = "test_dir/someones-displacement-field.nii.gz", + ... transform = "test_dir/someones-displacement-field.nii.gz", ... is_deltas = True, ... ).show_transform( ... xslice = 70, @@ -92,7 +100,7 @@ def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, sho figsize=(9,9), gs_rows=3, gs_cols=3, - suptitle="Dense Field Transform \n" + os.path.basename(self._path_to_file), + suptitle="Dense Field Transform \n" + os.path.basename(self._transform), ) fig.subplots_adjust(bottom=0.15) @@ -143,7 +151,7 @@ def plot_distortion(self, axes, xslice, yslice, zslice, show_brain=True, show_gr Example: fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) PlotDenseField( - path_to_file="test_dir/someones-displacement-field.nii.gz", + transform="test_dir/someones-displacement-field.nii.gz", is_deltas=True, ).plot_distortion( axes=[axes[2], axes[1], axes[0]], @@ -215,7 +223,7 @@ def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False, Example: fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) PlotDenseField( - path_to_file="test_dir/someones-displacement-field.nii.gz", + transform="test_dir/someones-displacement-field.nii.gz", is_deltas=True, ).plot_quiverdsm( axes=[axes[2], axes[1], axes[0]], @@ -230,7 +238,7 @@ def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False, #Example 2: 3D quiver fig = plt.figure() ax = fig.add_subplot(projection='3d') - PlotDenseField(path_to_file, is_deltas=True).plot_quiverdsm( + PlotDenseField(transform, is_deltas=True).plot_quiverdsm( ax, xslice=None, yslice=None, @@ -310,7 +318,7 @@ def plot_jacobian(self, axes, xslice, yslice, zslice, show_titles=True): Example: fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) PlotDenseField( - path_to_file="test_dir/someones-displacement-field.nii.gz", + transform="test_dir/someones-displacement-field.nii.gz", is_deltas=True, ).plot_jacobian( axes=[axes[2], axes[1], axes[0]],