Skip to content

Commit

Permalink
Merge pull request #308 from effigies/fix/grid_creation
Browse files Browse the repository at this point in the history
FIX: Calculate bspline grids separately from colocation matrices
  • Loading branch information
effigies authored Dec 6, 2022
2 parents 2350d4b + 6514401 commit 7898fac
Showing 1 changed file with 18 additions and 32 deletions.
50 changes: 18 additions & 32 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class BSplineApprox(SimpleInterface):

def _run_interface(self, runtime):
from sklearn import linear_model as lm
from scipy.sparse import vstack as sparse_vstack

# Output name baseline
out_name = fname_presuffix(
Expand All @@ -147,21 +148,25 @@ def _run_interface(self, runtime):
else None
)

# Determine the shape of bspline coefficients
# This should not change with resizing, so do it first
bs_grids = [bspline_grid(fmapnii, control_zooms_mm=sp) for sp in self.inputs.bs_spacing]

need_resize = np.any(np.array(zooms) < self.inputs.zooms_min)
if need_resize:
from sdcflows.utils.tools import resample_to_zooms

zooms_min = np.maximum(zooms, self.inputs.zooms_min)
target_zooms = np.maximum(zooms, self.inputs.zooms_min)

LOGGER.info(
"Resampling image with resolution exceeding 'zooms_min' "
f"({'x'.join(str(s) for s in zooms)} → "
f"{'x'.join(str(s) for s in zooms_min)})."
f"{'x'.join(str(s) for s in target_zooms)})."
)
fmapnii = resample_to_zooms(fmapnii, zooms_min)
fmapnii = resample_to_zooms(fmapnii, target_zooms)

if masknii is not None:
masknii = resample_to_zooms(masknii, zooms_min)
masknii = resample_to_zooms(masknii, target_zooms)

data = fmapnii.get_fdata(dtype="float32")

Expand All @@ -171,9 +176,6 @@ def _run_interface(self, runtime):
else np.asanyarray(masknii.dataobj) > 1e-4
)

# Convert spacings to numpy arrays
bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing]

# Recenter the fieldmap
if self.inputs.recenter == "mode":
from scipy.stats import mode
Expand All @@ -187,13 +189,13 @@ def _run_interface(self, runtime):
elif self.inputs.recenter == "mean":
data -= np.mean(data[mask])

# Calculate collocation matrix & the spatial location of control points
colmat, bs_levels = _collocation_matrix(fmapnii, bs_spacing)
# Calculate collocation matrix from (possibly resized) image and knot grids
colmat = sparse_vstack(grid_bspline_weights(fmapnii, grid) for grid in bs_grids).T.tocsr()

bs_levels_str = ['x'.join(str(s) for s in level.shape) for level in bs_levels]
bs_levels_str[-1] = f"and {bs_levels_str[-1]}"
bs_grids_str = ['x'.join(str(s) for s in grid.shape) for grid in bs_grids]
bs_grids_str[-1] = f"and {bs_grids_str[-1]}"
LOGGER.info(
f"Approximating B-Splines grids ({', '.join(bs_levels_str)} [knots]) on a grid of "
f"Approximating B-Splines grids ({', '.join(bs_grids_str)} [knots]) on a grid of "
f"{'x'.join(str(s) for s in fmapnii.shape)} ({np.prod(fmapnii.shape)}) voxels,"
f" of which {mask.sum()} fall within the mask."
)
Expand All @@ -205,7 +207,7 @@ def _run_interface(self, runtime):
# Store coefficients
index = 0
self._results["out_coeff"] = []
for i, bsl in enumerate(bs_levels):
for i, bsl in enumerate(bs_grids):
n = bsl.dataobj.size
out_level = out_name.replace("_field.", f"_coeff{i:03}.")
bsl.__class__(
Expand All @@ -226,7 +228,9 @@ def _run_interface(self, runtime):
np.ones_like(fmapnii.dataobj, dtype=bool) if masknii is None
else np.asanyarray(nb.load(self.inputs.in_mask).dataobj) > 1e-4
)
colmat, _ = _collocation_matrix(fmapnii, bs_spacing)
colmat = sparse_vstack(
grid_bspline_weights(fmapnii, grid) for grid in bs_grids
).T.tocsr()

regressors = colmat[mask.reshape(-1), :]
interp_data = np.zeros_like(data)
Expand Down Expand Up @@ -509,24 +513,6 @@ def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
return img.__class__(np.zeros(bs_shape, dtype="float32"), bs_affine)


def _collocation_matrix(image, knot_spacing):
from scipy.sparse import vstack as sparse_vstack

bs_levels = []
weights = None
for sp in knot_spacing:
level = bspline_grid(image, control_zooms_mm=sp)
bs_levels.append(level)

weights = (
grid_bspline_weights(image, level)
if weights is None
else sparse_vstack((weights, grid_bspline_weights(image, level)))
)

return weights.T.tocsr(), bs_levels


def _fix_topup_fieldcoeff(in_coeff, fmap_ref, pe_dir, out_file=None):
"""Read in a coefficients file generated by TOPUP and fix x-form headers."""
from pathlib import Path
Expand Down

0 comments on commit 7898fac

Please sign in to comment.