Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Remove unsafe cast during TransformBase.apply() #189

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,15 @@ def apply(
slightly blurred if *order > 1*, unless the input is prefiltered,
i.e. it is the result of calling the spline filter on the original
input.
output_dtype: dtype specifier, optional
The dtype of the returned array or image, if specified.
If ``None``, the default behavior is to use the effective dtype of
the input image. If slope and/or intercept are defined, the effective
dtype is float64, otherwise it is equivalent to the input image's
``get_data_dtype()`` (on-disk type).
If ``reference`` is defined, then the return value is an image, with
a data array of the effective dtype but with the on-disk dtype set to
the input image's on-disk dtype.

Returns
-------
Expand All @@ -279,11 +288,7 @@ def apply(
if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

data = np.asanyarray(
spatialimage.dataobj,
dtype=spatialimage.get_data_dtype()
)
output_dtype = output_dtype or data.dtype
data = np.asanyarray(spatialimage.dataobj)
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim)
)
Expand All @@ -302,9 +307,9 @@ def apply(
hdr = None
if _ref.header is not None:
hdr = _ref.header.copy()
hdr.set_data_dtype(output_dtype)
hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype())
moved = spatialimage.__class__(
resampled.reshape(_ref.shape).astype(output_dtype),
resampled.reshape(_ref.shape),
_ref.affine,
hdr,
)
Expand Down