diff --git a/nitransforms/io/afni.py b/nitransforms/io/afni.py index b7fc657b..e95c4494 100644 --- a/nitransforms/io/afni.py +++ b/nitransforms/io/afni.py @@ -193,6 +193,17 @@ def from_image(cls, imgobj): return imgobj.__class__(field, imgobj.affine, hdr) + @classmethod + def to_image(cls, imgobj): + """Export a displacements field from a nibabel object.""" + + hdr = imgobj.header.copy() + + warp_data = imgobj.get_fdata().reshape(imgobj.shape[:3] + (1, imgobj.shape[-1])) + warp_data[..., (0, 1)] *= -1 + + return imgobj.__class__(warp_data, imgobj.affine, hdr) + def _is_oblique(affine, thres=OBLIQUITY_THRESHOLD_DEG): """ diff --git a/nitransforms/io/base.py b/nitransforms/io/base.py index 6d1a7c8e..d86c8539 100644 --- a/nitransforms/io/base.py +++ b/nitransforms/io/base.py @@ -146,6 +146,17 @@ def from_image(cls, imgobj): """Import a displacements field from a nibabel image object.""" raise NotImplementedError + @classmethod + def to_filename(cls, img, filename): + """Export a displacements field to a NIfTI file.""" + imgobj = cls.to_image(img) + imgobj.to_filename(filename) + + @classmethod + def to_image(cls, imgobj): + """Export a displacements field image from a nitransforms image object.""" + raise NotImplementedError + def _ensure_image(img): if isinstance(img, (str, Path)): diff --git a/nitransforms/io/fsl.py b/nitransforms/io/fsl.py index 8e4c8264..f454227e 100644 --- a/nitransforms/io/fsl.py +++ b/nitransforms/io/fsl.py @@ -190,6 +190,17 @@ def from_image(cls, imgobj): return imgobj.__class__(field, imgobj.affine, hdr) + @classmethod + def to_image(cls, imgobj): + """Export a displacements field from a nibabel object.""" + + hdr = imgobj.header.copy() + + warp_data = imgobj.get_fdata() + warp_data[..., 0] *= -1 + + return imgobj.__class__(warp_data, imgobj.affine, hdr) + def _fsl_aff_adapt(space): """ diff --git a/nitransforms/io/itk.py b/nitransforms/io/itk.py index d7a093eb..ddeb78e6 100644 --- a/nitransforms/io/itk.py +++ b/nitransforms/io/itk.py @@ -352,6 +352,18 @@ def from_image(cls, imgobj): return imgobj.__class__(field, imgobj.affine, hdr) + @classmethod + def to_image(cls, imgobj): + """Export a displacements field from a nibabel object.""" + + hdr = imgobj.header.copy() + hdr.set_intent("vector") + + warp_data = imgobj.get_fdata().reshape(imgobj.shape[:3] + (1, imgobj.shape[-1])) + warp_data[..., (0, 1)] *= -1 + + return imgobj.__class__(warp_data, imgobj.affine, hdr) + class ITKCompositeH5: """A data structure for ITK's HDF5 files."""