Skip to content

Commit

Permalink
enh: Further additions to test_vis, output path for figures added to …
Browse files Browse the repository at this point in the history
…conftest
  • Loading branch information
Julien Marabotto authored and Julien Marabotto committed Jul 11, 2024
1 parent 6b14738 commit af24af8
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 44 deletions.
3 changes: 2 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ jobs:
-v /tmp/data/nitransforms-tests:/data -e TEST_DATA_HOME=/data \
-e COVERAGE_FILE=/tmp/summaries/.pytest.coverage \
-v /tmp/fslicense/license.txt:/opt/freesurfer/license.txt:ro \
-v /tmp/tests:/tmp nitransforms:latest \
-v /tmp/tests:/tmp -e TEST_OUTPUT_DIR=/tmp/artifacts \
nitransforms:latest \
pytest --junit-xml=/tmp/summaries/pytest.xml \
--cov nitransforms --cov-report xml:/tmp/summaries/unittests.xml \
nitransforms/
Expand Down
8 changes: 8 additions & 0 deletions nitransforms/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
_data = None
_brainmask = None
_testdir = Path(os.getenv("TEST_DATA_HOME", "~/.nitransforms/testdata")).expanduser()
_outdir = os.getenv("TEST_OUTPUT_DIR", None)

_datadir = Path(__file__).parent / "tests" / "data"


Expand Down Expand Up @@ -45,6 +47,12 @@ def testdata_path():
return _testdir


@pytest.fixture
def output_path():
"""Return an output folder."""
return Path(_outdir) if _outdir is not None else None


@pytest.fixture
def get_testdata():
"""Generate data in the requested orientation."""
Expand Down
2 changes: 1 addition & 1 deletion nitransforms/io/afni.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def from_image(cls, imgobj):
hdr = imgobj.header.copy()
shape = hdr.get_data_shape()

if len(shape) != 5 or shape[-2] != 1 or not shape[-1] in (2, 3):
if len(shape) != 5 or shape[-2] != 1 or shape[-1] not in (2, 3):
raise TransformFileError(
'Displacements field "%s" does not come from AFNI.'
% imgobj.file_map["image"].filename
Expand Down
2 changes: 1 addition & 1 deletion nitransforms/io/fsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def from_image(cls, imgobj):
hdr = imgobj.header.copy()
shape = hdr.get_data_shape()

if len(shape) != 4 or not shape[-1] in (2, 3):
if len(shape) != 4 or shape[-1] not in (2, 3):
raise TransformFileError(
'Displacements field "%s" does not come from FSL.' %
imgobj.file_map['image'].filename)
Expand Down
2 changes: 1 addition & 1 deletion nitransforms/io/itk.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def from_image(cls, imgobj):
hdr = imgobj.header.copy()
shape = hdr.get_data_shape()

if len(shape) != 5 or shape[-2] != 1 or not shape[-1] in (2, 3):
if len(shape) != 5 or shape[-2] != 1 or shape[-1] not in (2, 3):
raise TransformFileError(
'Displacements field "%s" does not come from ITK.'
% imgobj.file_map["image"].filename
Expand Down
127 changes: 105 additions & 22 deletions nitransforms/tests/test_vis.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,120 @@
import numpy as np
import matplotlib.pyplot as plt
import pytest, unittest
from pathlib import Path
import pytest

import nibabel as nb
from nitransforms.nonlinear import DenseFieldTransform
from nitransforms.vis import PlotDenseField

test_dir = Path("tests/data/")
test_file = Path("ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz")

xfm = Path(test_dir/test_file)
def test_read_path(data_path):
"""Check that filepaths are a supported method for loading and reading transforms with PlotDenseField"""
PlotDenseField(transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz")

def test_slice_values(xfm, xslice, yslice, zslice, is_deltas=True):

def test_slice_values():
"""Check that ValueError is issued if negative slices are provided"""
with pytest.raises(ValueError):
PlotDenseField(
transform = np.zeros((10, 10, 10, 3)),
reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None),
).test_slices(
xslice=-1,
yslice=-1,
zslice=-1,
)

"""Check that IndexError is issued if provided slices are beyond range of transform dimensions"""
with pytest.raises(IndexError):
xfm = DenseFieldTransform(
field=np.zeros((10, 10, 10, 3)),
reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None),
)
PlotDenseField(
transform=xfm._field,
reference=xfm._reference,
).test_slices(
xslice=xfm._field.shape[0]+1,
yslice=xfm._field.shape[1]+1,
zslice=xfm._field.shape[2]+1,
)


def test_show_transform(data_path, output_path):
PlotDenseField(
path_to_file=Path(xfm),
is_deltas=is_deltas,
).test_slices(
xslice=xslice,
yslice=yslice,
zslice=zslice,
transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).show_transform(
xslice=50,
yslice=50,
zslice=50,
)
if output_path is not None:
plt.savefig(output_path / "show_transform.svg", bbox_inches="tight")
else:
plt.show()


def test_plot_distortion(data_path, output_path):
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
PlotDenseField(
transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).plot_distortion(
axes=axes,
xslice=50,
yslice=50,
zslice=50,
)
if output_path is not None:
plt.savefig(output_path / "show_transform.svg", bbox_inches="tight")
else:
plt.show()


def test_show_transform(xfm, xslice=50, yslice=50, zslice=50, is_deltas=True):
def test_plot_quiverdsm(data_path, output_path):
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
PlotDenseField(
path_to_file=Path(xfm),
is_deltas=is_deltas,
transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).plot_quiverdsm(
axes=axes,
xslice=xslice,
yslice=yslice,
zslice=zslice,
xslice=50,
yslice=50,
zslice=50,
)
plt.show()
if output_path is not None:
plt.savefig(output_path / "show_transform.svg", bbox_inches="tight")
else:
plt.show()


def test_3dquiver(data_path, output_path):
with pytest.raises(NotImplementedError):
fig = plt.figure()
axes = plt.subplots(projection='3d')
PlotDenseField(
transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).plot_quiverdsm(axes=axes, three_D=True)

if output_path is not None:
plt.savefig(output_path / "show_transform.svg", bbox_inches="tight")
else:
plt.show()


def test_plot_jacobian(data_path, output_path):
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
PlotDenseField(
transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).plot_jacobian(
axes=axes,
xslice=50,
yslice=50,
zslice=50,
)
if output_path is not None:
plt.savefig(output_path / "show_transform.svg", bbox_inches="tight")
else:
plt.show()


test_slice_values(xfm, 50, -50, 50) #should raise ValueError
test_slice_values(xfm, 500, 50, 50) #should raise IndexError
test_show_transform(Path("tests/data/ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"))
if __name__ == "__main__":
pytest.main([__file__])
44 changes: 26 additions & 18 deletions nitransforms/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from itertools import product

from nitransforms.base import TransformError
from nitransforms.nonlinear import DenseFieldTransform

class PlotDenseField():
Expand All @@ -27,26 +26,35 @@ class PlotDenseField():
Parameters
----------
path_to_file: :obj:`str`
transform: :obj:`str`
Path from which the trasnformation file should be read.
is_deltas: :obj:`bool`
Whether the field is a displacement field or a deformations field. Default = True
"""
__slots__ = ('_path_to_file', '_xfm', '_voxel_size')
__slots__ = ('_transform', '_xfm', '_voxel_size')

def __init__(self, path_to_file, is_deltas=True):
self._path_to_file = path_to_file
def __init__(self, transform, is_deltas=True, reference=None):
self._transform = transform
self._xfm = DenseFieldTransform(
self._path_to_file,
field=self._transform,
is_deltas=is_deltas,
reference=reference
)
self._voxel_size = nb.load(path_to_file).header.get_zooms()
try:
"""if field provided by path"""
self._voxel_size = nb.load(transform).header.get_zooms()
except:
"""if field provided by numpy array (eg tests)"""
deltas = []
for i in range(self._xfm.ndim):
deltas.append((np.max(self._xfm._field[i]) - np.min(self._xfm._field[i])) / len(self._xfm._field[i]))

assert np.all(deltas == deltas[0])
assert np.mean(deltas) == deltas[0]

deltas.append(0)
self._voxel_size = deltas

if self._xfm._field.shape[-1] != self._xfm.ndim:
raise TransformError(
"The number of components of the field (%d) does not match "
"the number of dimensions (%d)" % (self._xfm._field.shape[-1], self._xfm.ndim)
)

def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, show_grid=True, lw=0.1, save_to_path=None):
"""
Expand All @@ -73,7 +81,7 @@ def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, sho
>>> plt.show()
>>> PlotDenseField(
... path_to_file = "test_dir/someones-displacement-field.nii.gz",
... transform = "test_dir/someones-displacement-field.nii.gz",
... is_deltas = True,
... ).show_transform(
... xslice = 70,
Expand All @@ -92,7 +100,7 @@ def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, sho
figsize=(9,9),
gs_rows=3,
gs_cols=3,
suptitle="Dense Field Transform \n" + os.path.basename(self._path_to_file),
suptitle="Dense Field Transform \n" + os.path.basename(self._transform),
)
fig.subplots_adjust(bottom=0.15)

Expand Down Expand Up @@ -143,7 +151,7 @@ def plot_distortion(self, axes, xslice, yslice, zslice, show_brain=True, show_gr
Example:
fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)
PlotDenseField(
path_to_file="test_dir/someones-displacement-field.nii.gz",
transform="test_dir/someones-displacement-field.nii.gz",
is_deltas=True,
).plot_distortion(
axes=[axes[2], axes[1], axes[0]],
Expand Down Expand Up @@ -215,7 +223,7 @@ def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False,
Example:
fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)
PlotDenseField(
path_to_file="test_dir/someones-displacement-field.nii.gz",
transform="test_dir/someones-displacement-field.nii.gz",
is_deltas=True,
).plot_quiverdsm(
axes=[axes[2], axes[1], axes[0]],
Expand All @@ -230,7 +238,7 @@ def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False,
#Example 2: 3D quiver
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
PlotDenseField(path_to_file, is_deltas=True).plot_quiverdsm(
PlotDenseField(transform, is_deltas=True).plot_quiverdsm(
ax,
xslice=None,
yslice=None,
Expand Down Expand Up @@ -310,7 +318,7 @@ def plot_jacobian(self, axes, xslice, yslice, zslice, show_titles=True):
Example:
fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)
PlotDenseField(
path_to_file="test_dir/someones-displacement-field.nii.gz",
transform="test_dir/someones-displacement-field.nii.gz",
is_deltas=True,
).plot_jacobian(
axes=[axes[2], axes[1], axes[0]],
Expand Down

0 comments on commit af24af8

Please sign in to comment.