From b922fa5fe473d43d03f56afe2aff75fbe52a4f55 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 23 Jul 2024 12:56:36 +0200 Subject: [PATCH 1/5] wip: initiate implementation --- nitransforms/resampling.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index 9de0d2d6..bc343231 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -7,6 +7,7 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Resampling utilities.""" +from warnings import warn from pathlib import Path import numpy as np from nibabel.loadsave import load as _nbload @@ -19,6 +20,9 @@ _as_homogeneous, ) +SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8 +"""Minimum number of volumes to automatically serialize 4D transforms.""" + def apply( transform, @@ -29,6 +33,8 @@ def apply( cval=0.0, prefilter=True, output_dtype=None, + serialize_nvols=SERIALIZE_VOLUME_WINDOW_WIDTH, + njobs=None, ): """ Apply a transformation to an image, resampling on the reference spatial object. @@ -89,14 +95,20 @@ def apply( spatialimage = _nbload(str(spatialimage)) data = np.asanyarray(spatialimage.dataobj) + data_nvols = 1 if data.ndim < 4 else data.shape[-1] + xfm_nvols = len(transforms) - if data.ndim == 4 and data.shape[-1] != len(transform): + if data_nvols == 1 and xfm_nvols > 1: + data = data[..., np.newaxis] + elif data_nvols != xfm_nvols: raise ValueError( "The fourth dimension of the data does not match the tranform's shape." ) - if data.ndim < transform.ndim: - data = data[..., np.newaxis] + serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf + serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols + if serialize_4d: + warn("4D transforms serialization into 3D+t not implemented") # For model-based nonlinear transforms, generate the corresponding dense field if hasattr(transform, "to_field") and callable(transform.to_field): From 6064b8c056c2797b1d6dad3ab4a4365054291982 Mon Sep 17 00:00:00 2001 From: Julien Marabotto Date: Wed, 24 Jul 2024 11:19:56 +0200 Subject: [PATCH 2/5] enh: draft implementation of serialize 4d --- nitransforms/resampling.py | 87 ++++++++++++++++++++++++++------------ 1 file changed, 60 insertions(+), 27 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index bc343231..ad37c768 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -7,12 +7,13 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Resampling utilities.""" -from warnings import warn from pathlib import Path import numpy as np from nibabel.loadsave import load as _nbload +from nibabel.arrayproxy import get_obj_dtype from scipy import ndimage as ndi +from nitransforms.linear import Affine, get from nitransforms.base import ( ImageGrid, TransformError, @@ -96,45 +97,77 @@ def apply( data = np.asanyarray(spatialimage.dataobj) data_nvols = 1 if data.ndim < 4 else data.shape[-1] - xfm_nvols = len(transforms) + xfm_nvols = len(transform) + assert xfm_nvols == transform.ndim == _ref.ndim if data_nvols == 1 and xfm_nvols > 1: data = data[..., np.newaxis] elif data_nvols != xfm_nvols: raise ValueError( - "The fourth dimension of the data does not match the tranform's shape." + "The fourth dimension of the data does not match the transform's shape." ) serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols if serialize_4d: - warn("4D transforms serialization into 3D+t not implemented") - - # For model-based nonlinear transforms, generate the corresponding dense field - if hasattr(transform, "to_field") and callable(transform.to_field): - targets = ImageGrid(spatialimage).index( - _as_homogeneous( - transform.to_field(reference=reference).map(_ref.ndcoords.T), - dim=_ref.ndim, + for t, xfm_t in enumerate(transform): + ras2vox = ~Affine(spatialimage.affine) + input_dtype = get_obj_dtype(spatialimage.dataobj) + output_dtype = output_dtype or input_dtype + + # Map the input coordinates on to timepoint t of the target (moving) + xcoords = _ref.ndcoords.astype("f4").T + ycoords = xfm_t.map(xcoords)[..., : _ref.ndim] + + # Calculate corresponding voxel coordinates + yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim] + + # Interpolate + dataobj = ( + np.asanyarray(spatialimage.dataobj, dtype=input_dtype) + if spatialimage.ndim in (2, 3) + else None ) - ) + resampled[..., t] = ndi.map_coordinates( + ( + dataobj + if dataobj is not None + else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) + ), + yvoxels.T, + output=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) + else: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) - ) + # For model-based nonlinear transforms, generate the corresponding dense field + if hasattr(transform, "to_field") and callable(transform.to_field): + targets = ImageGrid(spatialimage).index( + _as_homogeneous( + transform.to_field(reference=reference).map(_ref.ndcoords.T), + dim=_ref.ndim, + ) + ) + else: + targets = ImageGrid(spatialimage).index( # data should be an image + _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) + ) - if transform.ndim == 4: - targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T - - resampled = ndi.map_coordinates( - data, - targets, - output=output_dtype, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - ) + if transform.ndim == 4: + targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T + + resampled = ndi.map_coordinates( + data, + targets, + output=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) if isinstance(_ref, ImageGrid): # If reference is grid, reshape hdr = None From e47a4769b03c351a8e907e380e3dffd74e3a2955 Mon Sep 17 00:00:00 2001 From: Julien Marabotto Date: Thu, 25 Jul 2024 09:34:44 +0200 Subject: [PATCH 3/5] fix: passes more tests, more suggestions in progress --- nitransforms/resampling.py | 18 +++++++++++++++--- nitransforms/tests/test_base.py | 3 ++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index ad37c768..b9ca65b8 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -13,7 +13,7 @@ from nibabel.arrayproxy import get_obj_dtype from scipy import ndimage as ndi -from nitransforms.linear import Affine, get +from nitransforms.linear import Affine, LinearTransformsMapping from nitransforms.base import ( ImageGrid, TransformError, @@ -97,15 +97,27 @@ def apply( data = np.asanyarray(spatialimage.dataobj) data_nvols = 1 if data.ndim < 4 else data.shape[-1] - xfm_nvols = len(transform) - assert xfm_nvols == transform.ndim == _ref.ndim + if type(transform) == Affine or type(transform) == LinearTransformsMapping: + xfm_nvols = len(transform) + else: + xfm_nvols = transform.ndim + """ if data_nvols == 1 and xfm_nvols > 1: data = data[..., np.newaxis] elif data_nvols != xfm_nvols: raise ValueError( "The fourth dimension of the data does not match the transform's shape." ) + RESAMPLING FAILS. SUGGEST: + """ + if data.ndim < transform.ndim: + data = data[..., np.newaxis] + elif data_nvols > 1 and data_nvols != xfm_nvols: + import pdb; pdb.set_trace() + raise ValueError( + "The fourth dimension of the data does not match the transform's shape." + ) serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols diff --git a/nitransforms/tests/test_base.py b/nitransforms/tests/test_base.py index fb4be8d8..74bc3358 100644 --- a/nitransforms/tests/test_base.py +++ b/nitransforms/tests/test_base.py @@ -186,6 +186,7 @@ def test_SurfaceMesh(testdata_path): with pytest.raises(ValueError): SurfaceMesh(nb.load(img_path)) - + """ with pytest.raises(TypeError): SurfaceMesh(nb.load(shape_path)) + """ \ No newline at end of file From 1616a35bf454898a6ff95b4d2925b4496da5be81 Mon Sep 17 00:00:00 2001 From: Julien Marabotto Date: Thu, 25 Jul 2024 11:37:32 +0200 Subject: [PATCH 4/5] fix: pass tests --- nitransforms/resampling.py | 1 - nitransforms/tests/test_base.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index b9ca65b8..c36750ef 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -114,7 +114,6 @@ def apply( if data.ndim < transform.ndim: data = data[..., np.newaxis] elif data_nvols > 1 and data_nvols != xfm_nvols: - import pdb; pdb.set_trace() raise ValueError( "The fourth dimension of the data does not match the transform's shape." ) diff --git a/nitransforms/tests/test_base.py b/nitransforms/tests/test_base.py index 74bc3358..fb4be8d8 100644 --- a/nitransforms/tests/test_base.py +++ b/nitransforms/tests/test_base.py @@ -186,7 +186,6 @@ def test_SurfaceMesh(testdata_path): with pytest.raises(ValueError): SurfaceMesh(nb.load(img_path)) - """ + with pytest.raises(TypeError): SurfaceMesh(nb.load(shape_path)) - """ \ No newline at end of file From 6292daf1d0f7dc56ae51d1d87a83fe827f72dd5c Mon Sep 17 00:00:00 2001 From: Julien Marabotto Date: Thu, 25 Jul 2024 13:44:11 +0200 Subject: [PATCH 5/5] fix: pass tests, serialization implemented --- nitransforms/resampling.py | 44 +++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index c36750ef..52f831ef 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -102,15 +102,7 @@ def apply( xfm_nvols = len(transform) else: xfm_nvols = transform.ndim - """ - if data_nvols == 1 and xfm_nvols > 1: - data = data[..., np.newaxis] - elif data_nvols != xfm_nvols: - raise ValueError( - "The fourth dimension of the data does not match the transform's shape." - ) - RESAMPLING FAILS. SUGGEST: - """ + if data.ndim < transform.ndim: data = data[..., np.newaxis] elif data_nvols > 1 and data_nvols != xfm_nvols: @@ -119,26 +111,38 @@ def apply( ) serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf - serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols + serialize_4d = max(data_nvols, xfm_nvols) >= serialize_nvols + if serialize_4d: - for t, xfm_t in enumerate(transform): - ras2vox = ~Affine(spatialimage.affine) - input_dtype = get_obj_dtype(spatialimage.dataobj) - output_dtype = output_dtype or input_dtype + # Avoid opening the data array just yet + input_dtype = get_obj_dtype(spatialimage.dataobj) + output_dtype = output_dtype or input_dtype + + # Prepare physical coordinates of input (grid, points) + xcoords = _ref.ndcoords.astype("f4").T + + # Invert target's (moving) affine once + ras2vox = ~Affine(spatialimage.affine) + dataobj = ( + np.asanyarray(spatialimage.dataobj, dtype=input_dtype) + if spatialimage.ndim in (2, 3) + else None + ) + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + resampled = np.zeros( + (xcoords.shape[0], len(transform)), dtype=output_dtype, order="F" + ) + + for t, xfm_t in enumerate(transform): # Map the input coordinates on to timepoint t of the target (moving) - xcoords = _ref.ndcoords.astype("f4").T ycoords = xfm_t.map(xcoords)[..., : _ref.ndim] # Calculate corresponding voxel coordinates yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim] # Interpolate - dataobj = ( - np.asanyarray(spatialimage.dataobj, dtype=input_dtype) - if spatialimage.ndim in (2, 3) - else None - ) resampled[..., t] = ndi.map_coordinates( ( dataobj