diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 9de0d2d6..52f831ef 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -10,8 +10,10 @@ from pathlib import Path import numpy as np from nibabel.loadsave import load as _nbload +from nibabel.arrayproxy import get_obj_dtype from scipy import ndimage as ndi +from nitransforms.linear import Affine, LinearTransformsMapping from nitransforms.base import ( ImageGrid, TransformError, @@ -19,6 +21,9 @@ _as_homogeneous, ) +SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8 +"""Minimum number of volumes to automatically serialize 4D transforms.""" + def apply( transform, @@ -29,6 +34,8 @@ def apply( cval=0.0, prefilter=True, output_dtype=None, + serialize_nvols=SERIALIZE_VOLUME_WINDOW_WIDTH, + njobs=None, ): """ Apply a transformation to an image, resampling on the reference spatial object. @@ -89,40 +96,93 @@ def apply( spatialimage = _nbload(str(spatialimage)) data = np.asanyarray(spatialimage.dataobj) + data_nvols = 1 if data.ndim < 4 else data.shape[-1] - if data.ndim == 4 and data.shape[-1] != len(transform): - raise ValueError( - "The fourth dimension of the data does not match the tranform's shape." - ) + if type(transform) == Affine or type(transform) == LinearTransformsMapping: + xfm_nvols = len(transform) + else: + xfm_nvols = transform.ndim if data.ndim < transform.ndim: data = data[..., np.newaxis] + elif data_nvols > 1 and data_nvols != xfm_nvols: + raise ValueError( + "The fourth dimension of the data does not match the transform's shape." + ) - # For model-based nonlinear transforms, generate the corresponding dense field - if hasattr(transform, "to_field") and callable(transform.to_field): - targets = ImageGrid(spatialimage).index( - _as_homogeneous( - transform.to_field(reference=reference).map(_ref.ndcoords.T), - dim=_ref.ndim, - ) + serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf + serialize_4d = max(data_nvols, xfm_nvols) >= serialize_nvols + + if serialize_4d: + # Avoid opening the data array just yet + input_dtype = get_obj_dtype(spatialimage.dataobj) + output_dtype = output_dtype or input_dtype + + # Prepare physical coordinates of input (grid, points) + xcoords = _ref.ndcoords.astype("f4").T + + # Invert target's (moving) affine once + ras2vox = ~Affine(spatialimage.affine) + dataobj = ( + np.asanyarray(spatialimage.dataobj, dtype=input_dtype) + if spatialimage.ndim in (2, 3) + else None ) - else: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) + + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + resampled = np.zeros( + (xcoords.shape[0], len(transform)), dtype=output_dtype, order="F" ) - if transform.ndim == 4: - targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T - - resampled = ndi.map_coordinates( - data, - targets, - output=output_dtype, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - ) + for t, xfm_t in enumerate(transform): + # Map the input coordinates on to timepoint t of the target (moving) + ycoords = xfm_t.map(xcoords)[..., : _ref.ndim] + + # Calculate corresponding voxel coordinates + yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim] + + # Interpolate + resampled[..., t] = ndi.map_coordinates( + ( + dataobj + if dataobj is not None + else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) + ), + yvoxels.T, + output=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) + + else: + # For model-based nonlinear transforms, generate the corresponding dense field + if hasattr(transform, "to_field") and callable(transform.to_field): + targets = ImageGrid(spatialimage).index( + _as_homogeneous( + transform.to_field(reference=reference).map(_ref.ndcoords.T), + dim=_ref.ndim, + ) + ) + else: + targets = ImageGrid(spatialimage).index( # data should be an image + _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) + ) + + if transform.ndim == 4: + targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T + + resampled = ndi.map_coordinates( + data, + targets, + output=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) if isinstance(_ref, ImageGrid): # If reference is grid, reshape hdr = None