Skip to content

Commit

Permalink
FIX: Offsource Apply
Browse files Browse the repository at this point in the history
Apply function offsourced. Tests:  139 passed, 163 Skipped, 15 Warnings
  • Loading branch information
Julien Marabotto authored and Julien Marabotto committed May 2, 2024
1 parent ab28efc commit aed5237
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 62 deletions.
67 changes: 15 additions & 52 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class DenseFieldTransform(TransformBase):

__slots__ = ("_field", "_deltas")

@property
def ndim(self):
"""Access the dimensions of this Desne Field Transform."""
return self._field.ndim - 1

def __init__(self, field=None, is_deltas=True, reference=None):
"""
Create a dense field transform.
Expand Down Expand Up @@ -82,11 +87,10 @@ def __init__(self, field=None, is_deltas=True, reference=None):
"Reference is not a spatial image"
)

ndim = self._field.ndim - 1
if self._field.shape[-1] != ndim:
if self._field.shape[-1] != self.ndim:
raise TransformError(
"The number of components of the field (%d) does not match "
"the number of dimensions (%d)" % (self._field.shape[-1], ndim)
"the number of dimensions (%d)" % (self._field.shape[-1], self.ndim)
)

if is_deltas:
Expand Down Expand Up @@ -240,6 +244,12 @@ class BSplineFieldTransform(TransformBase):

__slots__ = ['_coeffs', '_knots', '_weights', '_order', '_moving']

@property
def ndim(self):
"""Access the dimensions of this BSpline."""
#return ndim = self._coeffs.shape[-1]
return self._coeffs.ndim - 1

def __init__(self, coefficients, reference=None, order=3):
"""Create a smooth deformation field using B-Spline basis."""
super().__init__()
Expand Down Expand Up @@ -267,66 +277,19 @@ def to_field(self, reference=None, dtype="float32"):
if _ref is None:
raise TransformError("A reference must be defined")

ndim = self._coeffs.shape[-1]

if self._weights is None:
self._weights = grid_bspline_weights(_ref, self._knots)

field = np.zeros((_ref.npoints, ndim))
field = np.zeros((_ref.npoints, self.ndim))

for d in range(ndim):
for d in range(self.ndim):
# 1 x Nvox : (1 x K) @ (K x Nvox)
field[:, d] = self._coeffs[..., d].reshape(-1) @ self._weights

return DenseFieldTransform(
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
)

def apply(
self,
spatialimage,
reference=None,
order=3,
mode="constant",
cval=0.0,
prefilter=True,
output_dtype=None,
):
"""Apply a B-Spline transform on input data."""

_ref = (
self.reference if reference is None else
SpatialReference.factory(_ensure_image(reference))
)
spatialimage = _ensure_image(spatialimage)

# If locations to be interpolated are not on a grid, run map()
#import pdb; pdb.set_trace()
if not isinstance(_ref, ImageGrid):
return apply(
super(),
spatialimage,
reference=_ref,
output_dtype=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,

)

# If locations to be interpolated are on a grid, generate a displacements field
return apply(
self.to_field(reference=reference),
spatialimage,
reference=reference,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
output_dtype=output_dtype,
)

def map(self, x, inverse=False):
r"""
Apply the transformation to a list of physical coordinate points.
Expand Down
13 changes: 9 additions & 4 deletions nitransforms/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,21 @@ def apply(
spatialimage = _nbload(str(spatialimage))

data = np.asanyarray(spatialimage.dataobj)

targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
)

if data.ndim == 4 and data.shape[-1] != len(transform):
raise ValueError("The fourth dimension of the data does not match the tranform's shape.")

if data.ndim < transform.ndim:
data = data[..., np.newaxis]

if hasattr(transform, 'to_field') and callable(transform.to_field):
targets = ImageGrid(spatialimage).index(
_as_homogeneous(transform.to_field(reference=reference).map(_ref.ndcoords.T), dim=_ref.ndim)
)
else:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
)

if transform.ndim == 4:
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
Expand Down
13 changes: 7 additions & 6 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,14 @@ def test_bsplines_references(testdata_path):
).to_field()

with pytest.raises(TransformError):
BSplineFieldTransform(
testdata_path / "someones_bspline_coefficients.nii.gz"
).apply(testdata_path / "someones_anatomy.nii.gz")
apply(
BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"),
testdata_path / "someones_anatomy.nii.gz",
)

BSplineFieldTransform(
testdata_path / "someones_bspline_coefficients.nii.gz"
).apply(
apply(
BSplineFieldTransform(
testdata_path / "someones_bspline_coefficients.nii.gz"),
testdata_path / "someones_anatomy.nii.gz",
reference=testdata_path / "someones_anatomy.nii.gz"
)
Expand Down

0 comments on commit aed5237

Please sign in to comment.