Skip to content

Commit

Permalink
enh: integrating @jmarabotto's code
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Jul 30, 2024
1 parent 23daabb commit 79e5cad
Showing 1 changed file with 45 additions and 39 deletions.
84 changes: 45 additions & 39 deletions nitransforms/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,21 @@
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Resampling utilities."""
from warnings import warn

from pathlib import Path
import numpy as np
from nibabel.loadsave import load as _nbload
from nibabel.arrayproxy import get_obj_dtype
from scipy import ndimage as ndi

from nitransforms.linear import Affine, LinearTransformsMapping
from nitransforms.base import (
ImageGrid,
TransformError,
SpatialReference,
_as_homogeneous,
)

SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8
SERIALIZE_VOLUME_WINDOW_WIDTH: int = 8
"""Minimum number of volumes to automatically serialize 4D transforms."""


Expand Down Expand Up @@ -96,58 +95,67 @@ def apply(
if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

data = np.asanyarray(spatialimage.dataobj)
data_nvols = 1 if data.ndim < 4 else data.shape[-1]
# Avoid opening the data array just yet
input_dtype = get_obj_dtype(spatialimage.dataobj)
output_dtype = output_dtype or input_dtype

# Number of transformations
data_nvols = 1 if spatialimage.ndim < 4 else spatialimage.shape[-1]
xfm_nvols = len(transform)

if data_nvols == 1 and xfm_nvols > 1:
data = data[..., np.newaxis]
elif data_nvols != xfm_nvols:
if data_nvols != xfm_nvols and min(data_nvols, xfm_nvols) > 1:
raise ValueError(
"The fourth dimension of the data does not match the transform's shape."
)

serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf
serialize_4d = max(data_nvols, xfm_nvols) >= serialize_nvols
serialize_nvols = (
serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf
)
n_resamplings = max(data_nvols, xfm_nvols)
serialize_4d = n_resamplings >= serialize_nvols

targets = None
if hasattr(transform, "to_field") and callable(transform.to_field):
targets = ImageGrid(spatialimage).index(
_as_homogeneous(
transform.to_field(reference=reference).map(_ref.ndcoords.T),
dim=_ref.ndim,
)
)
elif xfm_nvols == 1:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
)

if serialize_4d:
# Avoid opening the data array just yet
input_dtype = get_obj_dtype(spatialimage.dataobj)
output_dtype = output_dtype or input_dtype

# Prepare physical coordinates of input (grid, points)
xcoords = _ref.ndcoords.astype("f4").T

# Invert target's (moving) affine once
ras2vox = ~Affine(spatialimage.affine)
dataobj = (
data = (
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
if spatialimage.ndim in (2, 3)
if data_nvols == 1
else None
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(xcoords.shape[0], len(transform)), dtype=output_dtype, order="F"
(spatialimage.size, len(transform)), dtype=output_dtype, order="F"
)

for t, xfm_t in enumerate(transform):
# Map the input coordinates on to timepoint t of the target (moving)
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]
for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]

# Calculate corresponding voxel coordinates
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(_ref.ndcoords.T), dim=_ref.ndim)
)

# Interpolate
resampled[..., t] = ndi.map_coordinates(
(
dataobj
if dataobj is not None
data
if data is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
),
yvoxels.T,
targets,
output=output_dtype,
order=order,
mode=mode,
Expand All @@ -156,19 +164,17 @@ def apply(
)

else:
# For model-based nonlinear transforms, generate the corresponding dense field
if hasattr(transform, "to_field") and callable(transform.to_field):
targets = ImageGrid(spatialimage).index(
_as_homogeneous(
transform.to_field(reference=reference).map(_ref.ndcoords.T),
dim=_ref.ndim,
)
)
else:
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
)

# Cast 3D data into 4D if 4D nonsequential transform
if data_nvols == 1 and xfm_nvols > 1:
data = data[..., np.newaxis]

if transform.ndim == 4:
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T

Expand Down

0 comments on commit 79e5cad

Please sign in to comment.