diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 0d3e9956..e27f467a 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -30,6 +30,11 @@ class DenseFieldTransform(TransformBase): __slots__ = ("_field", "_deltas") + @property + def ndim(self): + """Access the dimensions of this Desne Field Transform.""" + return self._field.ndim - 1 + def __init__(self, field=None, is_deltas=True, reference=None): """ Create a dense field transform. @@ -82,11 +87,10 @@ def __init__(self, field=None, is_deltas=True, reference=None): "Reference is not a spatial image" ) - ndim = self._field.ndim - 1 - if self._field.shape[-1] != ndim: + if self._field.shape[-1] != self.ndim: raise TransformError( "The number of components of the field (%d) does not match " - "the number of dimensions (%d)" % (self._field.shape[-1], ndim) + "the number of dimensions (%d)" % (self._field.shape[-1], self.ndim) ) if is_deltas: @@ -240,6 +244,12 @@ class BSplineFieldTransform(TransformBase): __slots__ = ['_coeffs', '_knots', '_weights', '_order', '_moving'] + @property + def ndim(self): + """Access the dimensions of this BSpline.""" + #return ndim = self._coeffs.shape[-1] + return self._coeffs.ndim - 1 + def __init__(self, coefficients, reference=None, order=3): """Create a smooth deformation field using B-Spline basis.""" super().__init__() @@ -267,14 +277,12 @@ def to_field(self, reference=None, dtype="float32"): if _ref is None: raise TransformError("A reference must be defined") - ndim = self._coeffs.shape[-1] - if self._weights is None: self._weights = grid_bspline_weights(_ref, self._knots) - field = np.zeros((_ref.npoints, ndim)) + field = np.zeros((_ref.npoints, self.ndim)) - for d in range(ndim): + for d in range(self.ndim): # 1 x Nvox : (1 x K) @ (K x Nvox) field[:, d] = self._coeffs[..., d].reshape(-1) @ self._weights @@ -282,51 +290,6 @@ def to_field(self, reference=None, dtype="float32"): field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref ) - def apply( - self, - spatialimage, - reference=None, - order=3, - mode="constant", - cval=0.0, - prefilter=True, - output_dtype=None, - ): - """Apply a B-Spline transform on input data.""" - - _ref = ( - self.reference if reference is None else - SpatialReference.factory(_ensure_image(reference)) - ) - spatialimage = _ensure_image(spatialimage) - - # If locations to be interpolated are not on a grid, run map() - #import pdb; pdb.set_trace() - if not isinstance(_ref, ImageGrid): - return apply( - super(), - spatialimage, - reference=_ref, - output_dtype=output_dtype, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - - ) - - # If locations to be interpolated are on a grid, generate a displacements field - return apply( - self.to_field(reference=reference), - spatialimage, - reference=reference, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - output_dtype=output_dtype, - ) - def map(self, x, inverse=False): r""" Apply the transformation to a list of physical coordinate points. diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 7cbdd9b8..942ab07c 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -87,16 +87,21 @@ def apply( spatialimage = _nbload(str(spatialimage)) data = np.asanyarray(spatialimage.dataobj) - - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) - ) if data.ndim == 4 and data.shape[-1] != len(transform): raise ValueError("The fourth dimension of the data does not match the tranform's shape.") if data.ndim < transform.ndim: data = data[..., np.newaxis] + + if hasattr(transform, 'to_field') and callable(transform.to_field): + targets = ImageGrid(spatialimage).index( + _as_homogeneous(transform.to_field(reference=reference).map(_ref.ndcoords.T), dim=_ref.ndim) + ) + else: + targets = ImageGrid(spatialimage).index( # data should be an image + _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) + ) if transform.ndim == 4: targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index dd4cbf93..4a802b54 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -97,13 +97,14 @@ def test_bsplines_references(testdata_path): ).to_field() with pytest.raises(TransformError): - BSplineFieldTransform( - testdata_path / "someones_bspline_coefficients.nii.gz" - ).apply(testdata_path / "someones_anatomy.nii.gz") + apply( + BSplineFieldTransform(testdata_path / "someones_bspline_coefficients.nii.gz"), + testdata_path / "someones_anatomy.nii.gz", + ) - BSplineFieldTransform( - testdata_path / "someones_bspline_coefficients.nii.gz" - ).apply( + apply( + BSplineFieldTransform( + testdata_path / "someones_bspline_coefficients.nii.gz"), testdata_path / "someones_anatomy.nii.gz", reference=testdata_path / "someones_anatomy.nii.gz" )