Skip to content

Commit

Permalink
fix: postpone coordinate mapping on linear array transforms
Browse files Browse the repository at this point in the history
Resolves: #173.
  • Loading branch information
oesteban committed Nov 16, 2023
1 parent 6e70c02 commit db1b250
Showing 1 changed file with 41 additions and 27 deletions.
68 changes: 41 additions & 27 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def apply(
The data imaged after resampling to reference space.
"""

if reference is not None and isinstance(reference, (str, Path)):
reference = _nbload(str(reference))

Expand All @@ -446,40 +447,53 @@ def apply(
if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

data = np.squeeze(np.asanyarray(spatialimage.dataobj))
output_dtype = output_dtype or data.dtype
# Avoid opening the data array just yet
input_dtype = spatialimage.header.get_data_dtype()
output_dtype = output_dtype or input_dtype

ycoords = self.map(_ref.ndcoords.T)
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(np.vstack(ycoords), dim=_ref.ndim)
)
# Prepare physical coordinates of input (grid, points)
xcoords = _ref.ndcoords.astype("f4")

if data.ndim == 4:
if len(self) != data.shape[-1]:
# Invert target's (moving) affine once
ras2vox = ~Affine(spatialimage.affine)

if spatialimage.ndim == 4:
if len(self) != spatialimage.shape[-1]:
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), data.shape[-1])
"%d timepoints" % (len(self), spatialimage.shape[-1])
)
targets = targets.reshape((len(self), -1, targets.shape[-1]))
resampled = np.stack(
[
ndi.map_coordinates(
data[..., t],
targets[t, ..., : _ref.ndim].T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
for t in range(data.shape[-1])
],
axis=0,

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(xcoords.T.shape[0], ) + spatialimage.shape[-1:], dtype=output_dtype, order="F"
)
elif data.ndim in (2, 3):

for t in range(spatialimage.shape[-1]):
# Map the input coordinates on to timepoint t of the target (moving)
ycoords = Affine(self.matrix[t]).map(xcoords.T)[..., : _ref.ndim]

# Calculate corresponding voxel coordinates
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]

# Interpolate
resampled[..., t] = ndi.map_coordinates(
spatialimage.dataobj[..., t].astype(input_dtype, copy=False),
yvoxels.T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
elif spatialimage.ndim in (2, 3):
ycoords = self.map(xcoords.T)[..., : _ref.ndim]
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]

resampled = ndi.map_coordinates(
data,
targets[..., : _ref.ndim].T,
spatialimage.dataobj.astype(input_dtype, copy=False),
yvoxels.T,
output=output_dtype,
order=order,
mode=mode,
Expand Down

0 comments on commit db1b250

Please sign in to comment.