From db1b250dbbfa928785a2d469f246369e130f3624 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 16 Nov 2023 17:49:46 +0100 Subject: [PATCH] fix: postpone coordinate mapping on linear array transforms Resolves: #173. --- nitransforms/linear.py | 68 +++++++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 9c430d3b..0709d50b 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -436,6 +436,7 @@ def apply( The data imaged after resampling to reference space. """ + if reference is not None and isinstance(reference, (str, Path)): reference = _nbload(str(reference)) @@ -446,40 +447,53 @@ def apply( if isinstance(spatialimage, (str, Path)): spatialimage = _nbload(str(spatialimage)) - data = np.squeeze(np.asanyarray(spatialimage.dataobj)) - output_dtype = output_dtype or data.dtype + # Avoid opening the data array just yet + input_dtype = spatialimage.header.get_data_dtype() + output_dtype = output_dtype or input_dtype - ycoords = self.map(_ref.ndcoords.T) - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(np.vstack(ycoords), dim=_ref.ndim) - ) + # Prepare physical coordinates of input (grid, points) + xcoords = _ref.ndcoords.astype("f4") - if data.ndim == 4: - if len(self) != data.shape[-1]: + # Invert target's (moving) affine once + ras2vox = ~Affine(spatialimage.affine) + + if spatialimage.ndim == 4: + if len(self) != spatialimage.shape[-1]: raise ValueError( "Attempting to apply %d transforms on a file with " - "%d timepoints" % (len(self), data.shape[-1]) + "%d timepoints" % (len(self), spatialimage.shape[-1]) ) - targets = targets.reshape((len(self), -1, targets.shape[-1])) - resampled = np.stack( - [ - ndi.map_coordinates( - data[..., t], - targets[t, ..., : _ref.ndim].T, - output=output_dtype, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - ) - for t in range(data.shape[-1]) - ], - axis=0, + + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + resampled = np.zeros( + (xcoords.T.shape[0], ) + spatialimage.shape[-1:], dtype=output_dtype, order="F" ) - elif data.ndim in (2, 3): + + for t in range(spatialimage.shape[-1]): + # Map the input coordinates on to timepoint t of the target (moving) + ycoords = Affine(self.matrix[t]).map(xcoords.T)[..., : _ref.ndim] + + # Calculate corresponding voxel coordinates + yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim] + + # Interpolate + resampled[..., t] = ndi.map_coordinates( + spatialimage.dataobj[..., t].astype(input_dtype, copy=False), + yvoxels.T, + output=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) + elif spatialimage.ndim in (2, 3): + ycoords = self.map(xcoords.T)[..., : _ref.ndim] + yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim] + resampled = ndi.map_coordinates( - data, - targets[..., : _ref.ndim].T, + spatialimage.dataobj.astype(input_dtype, copy=False), + yvoxels.T, output=output_dtype, order=order, mode=mode,