From a7d2939af520e35848707d5dc77a4ebd9d7c930a Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 31 Jul 2024 11:38:07 +0200 Subject: [PATCH] enh: prepare code for easy parallelization with a process pool executor Resolves: #214. --- nitransforms/resampling.py | 41 +++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index abfe2b71..bb1bb309 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -8,6 +8,7 @@ ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Resampling utilities.""" +from functools import partial from pathlib import Path import numpy as np from nibabel.loadsave import load as _nbload @@ -135,33 +136,37 @@ def apply( else None ) - # Order F ensures individual volumes are contiguous in memory - # Also matches NIfTI, making final save more efficient - resampled = np.zeros( - (len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F" + map_coordinates = partial( + ndi.map_coordinates, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, ) - for t in range(n_resamplings): - xfm_t = transform if n_resamplings == 1 else transform[t] + def _apply_volume(index, data, transform, targets=None): + xfm_t = transform if n_resamplings == 1 else transform[index] if targets is None: targets = ImageGrid(spatialimage).index( # data should be an image _as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim) ) - # Interpolate - resampled[..., t] = ndi.map_coordinates( - ( - data - if data is not None - else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) - ), - targets, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, + data_t = ( + data + if data is not None + else spatialimage.dataobj[..., index].astype(input_dtype, copy=False) ) + return map_coordinates(data_t, targets) + + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + resampled = np.zeros( + (len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F" + ) + for t in range(n_resamplings): + # Interpolate + resampled[..., t] = _apply_volume(t, data, transform, targets=targets) else: data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)