diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index c36750ef..52f831ef 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -102,15 +102,7 @@ def apply( xfm_nvols = len(transform) else: xfm_nvols = transform.ndim - """ - if data_nvols == 1 and xfm_nvols > 1: - data = data[..., np.newaxis] - elif data_nvols != xfm_nvols: - raise ValueError( - "The fourth dimension of the data does not match the transform's shape." - ) - RESAMPLING FAILS. SUGGEST: - """ + if data.ndim < transform.ndim: data = data[..., np.newaxis] elif data_nvols > 1 and data_nvols != xfm_nvols: @@ -119,26 +111,38 @@ def apply( ) serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf - serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols + serialize_4d = max(data_nvols, xfm_nvols) >= serialize_nvols + if serialize_4d: - for t, xfm_t in enumerate(transform): - ras2vox = ~Affine(spatialimage.affine) - input_dtype = get_obj_dtype(spatialimage.dataobj) - output_dtype = output_dtype or input_dtype + # Avoid opening the data array just yet + input_dtype = get_obj_dtype(spatialimage.dataobj) + output_dtype = output_dtype or input_dtype + + # Prepare physical coordinates of input (grid, points) + xcoords = _ref.ndcoords.astype("f4").T + + # Invert target's (moving) affine once + ras2vox = ~Affine(spatialimage.affine) + dataobj = ( + np.asanyarray(spatialimage.dataobj, dtype=input_dtype) + if spatialimage.ndim in (2, 3) + else None + ) + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + resampled = np.zeros( + (xcoords.shape[0], len(transform)), dtype=output_dtype, order="F" + ) + + for t, xfm_t in enumerate(transform): # Map the input coordinates on to timepoint t of the target (moving) - xcoords = _ref.ndcoords.astype("f4").T ycoords = xfm_t.map(xcoords)[..., : _ref.ndim] # Calculate corresponding voxel coordinates yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim] # Interpolate - dataobj = ( - np.asanyarray(spatialimage.dataobj, dtype=input_dtype) - if spatialimage.ndim in (2, 3) - else None - ) resampled[..., t] = ndi.map_coordinates( ( dataobj