diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py index abfe2b71..1b76dba1 100644 --- a/nitransforms/resampling.py +++ b/nitransforms/resampling.py @@ -8,6 +8,8 @@ ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Resampling utilities.""" +from os import cpu_count +from concurrent.futures import ProcessPoolExecutor, as_completed from pathlib import Path import numpy as np from nibabel.loadsave import load as _nbload @@ -25,6 +27,25 @@ """Minimum number of volumes to automatically serialize 4D transforms.""" +def _apply_volume( + index, + data, + targets, + order=3, + mode="constant", + cval=0.0, + prefilter=True, +): + return index, ndi.map_coordinates( + data, + targets, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) + + def apply( transform, spatialimage, @@ -135,34 +156,47 @@ def apply( else None ) - # Order F ensures individual volumes are contiguous in memory - # Also matches NIfTI, making final save more efficient - resampled = np.zeros( - (len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F" - ) + if njobs is None: + njobs = cpu_count() - for t in range(n_resamplings): - xfm_t = transform if n_resamplings == 1 else transform[t] + with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor: + results = [] + for t in range(n_resamplings): + xfm_t = transform if n_resamplings == 1 else transform[t] - if targets is None: - targets = ImageGrid(spatialimage).index( # data should be an image - _as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim) - ) + if targets is None: + targets = ImageGrid(spatialimage).index( # data should be an image + _as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim) + ) - # Interpolate - resampled[..., t] = ndi.map_coordinates( - ( + data_t = ( data if data is not None else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) - ), - targets, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, + ) + + results.append( + executor.submit( + _apply_volume, + t, + data_t, + targets, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) + ) + + # Order F ensures individual volumes are contiguous in memory + # Also matches NIfTI, making final save more efficient + resampled = np.zeros( + (len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F" ) + for future in as_completed(results): + t, resampled_t = future.result() + resampled[..., t] = resampled_t else: data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)