Skip to content

Commit

Permalink
enh: prepare code for easy parallelization with a process pool executor
Browse files Browse the repository at this point in the history
Resolves: #214.
  • Loading branch information
oesteban committed Jul 31, 2024
1 parent 9f91e2f commit a7d2939
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions nitransforms/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Resampling utilities."""

from functools import partial
from pathlib import Path
import numpy as np
from nibabel.loadsave import load as _nbload
Expand Down Expand Up @@ -135,33 +136,37 @@ def apply(
else None
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
map_coordinates = partial(
ndi.map_coordinates,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]
def _apply_volume(index, data, transform, targets=None):
xfm_t = transform if n_resamplings == 1 else transform[index]

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
)

# Interpolate
resampled[..., t] = ndi.map_coordinates(
(
data
if data is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
),
targets,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
data_t = (
data
if data is not None
else spatialimage.dataobj[..., index].astype(input_dtype, copy=False)
)
return map_coordinates(data_t, targets)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
)
for t in range(n_resamplings):
# Interpolate
resampled[..., t] = _apply_volume(t, data, transform, targets=targets)

else:
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
Expand Down

0 comments on commit a7d2939

Please sign in to comment.