Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiprocess fitting script #80

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 138 additions & 79 deletions WrapImage/nifti_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,13 @@
import argparse
import json
import os
import nibabel as nib
from utilities.process.file_io import read_nifti_file, read_bval_file, read_bvec_file, save_nifti_file
from utilities.process.diffusion_utils import find_directions, find_shells, normalize_series
from src.wrappers.OsipiBase import OsipiBase
import numpy as np
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
from functools import partial


def read_nifti_file(input_file):
"""
For reading the 4d nifti image
"""
nifti_img = nib.load(input_file)
return nifti_img.get_fdata(), nifti_img.header

def read_json_file(json_file):
"""
For reading the json file
"""

if not os.path.exists(json_file):
raise FileNotFoundError(f"File '{json_file}' not found.")

with open(json_file, "r") as f:
try:
json_data = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(f"Error decoding JSON in file '{json_file}': {e}")

return json_data

def read_bval_file(bval_file):
"""
For reading the bval file
"""
if not os.path.exists(bval_file):
raise FileNotFoundError(f"File '{bval_file}' not found.")

bval_data = np.genfromtxt(bval_file, dtype=float)
return bval_data

def read_bvec_file(bvec_file):
"""
For reading the bvec file
"""
if not os.path.exists(bvec_file):
raise FileNotFoundError(f"File '{bvec_file}' not found.")

bvec_data = np.genfromtxt(bvec_file)
bvec_data = np.transpose(bvec_data) # Transpose the array
return bvec_data

def save_nifti_file(data, output_file, affine=None, **kwargs):
"""
For saving the 3d nifti images of the output of the algorithm
"""
if affine is None:
affine = np.eye(data.ndim + 1)
output_img = nib.nifti1.Nifti1Image(data, affine , **kwargs)
nib.save(output_img, output_file)

def loop_over_first_n_minus_1_dimensions(arr):
"""
Expand All @@ -75,52 +24,162 @@ def loop_over_first_n_minus_1_dimensions(arr):
flat_view = arr[idx].flatten()
yield idx, flat_view

def generate_data(data, bvals, b0_indices, groups, total_iteration):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1: Is b0_indices something you want to tell the algorithm, or is it more fool proof if we generate them locally from bvals?
2: Maybe we need to start documenting some code; for example, to me "generate_data" sounds like it will generate data, but it is unclear why it needs data for this (if there is already data, why generate it?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The source could come from anywhere, this splits the data up based on the bvals and directions. The b0 indices are included in order to give the b0 data to every fit.

True, I should update it with some more documentation.

num_directions = groups.shape[1]
data = data.reshape(total_iteration, -1)
for idx in range(total_iteration):
for dir in range(num_directions):
# print('yielding')
yield (data[idx, groups[:, dir]].flatten(), bvals[:, groups[:, dir]].ravel(), b0_indices[:, groups[:, dir]].ravel())

def osipi_fit(fitfunc, data_bvals):
data, bvals, b0_indices = data_bvals
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make more sense to use dictionaries for these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think Ivan has a PR for that so maybe there's a conflict between these. I guess either his or mine should get merged first and then the other will need to update the other.

data = normalize_series(data, b0_indices)
# print(f'data.shape {data.shape} data {data} bvals {bvals}')
return fitfunc(data, bvals)

# def osipi_fit(fitfunc, bvals, data, f_image, Dp_image, D_image, index):
# bval_index = len(f_image) % len(bvals)
# print(f'data.shape {data.shape} index {index} data[index] {data[index]} bvals.shape {bvals.shape} bval_index {bval_index} bvals {bvals[:, bval_index]}')
# [f_fit, Dp_fit, D_fit] = fitfunc(data[index], bvals[:, bval_index])
# f_image[index] = f_fit
# Dp_image[index] = Dp_fit
# D_image[index] = D_fit



if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Read a 4D NIfTI phantom file along with BIDS JSON, b-vector, and b-value files.")
parser.add_argument("input_file", type=str, help="Path to the input 4D NIfTI file.")
parser.add_argument("bvec_file", type=str, help="Path to the b-vector file.")
parser.add_argument("bval_file", type=str, help="Path to the b-value file.")
parser.add_argument("--affine", type=float, nargs="+", help="Affine matrix for NIfTI image.")
parser.add_argument("--nproc", type=int, default=0, help="Number of processes to use, -1 disabled multprocessing, 0 automatically determines number, >0 uses that number.")
parser.add_argument("--group_directions", default=False, action="store_true", help="Fit all directions together")
parser.add_argument("--affine", type=float, default=None, nargs="+", help="Affine matrix for NIfTI image.")
parser.add_argument("--algorithm", type=str, default="OJ_GU_seg", help="Select the algorithm to use.")
parser.add_argument("--algorithm_args", nargs=argparse.REMAINDER, help="Additional arguments for the algorithm.")
parser.add_argument("--algorithm_args", default={}, nargs=argparse.REMAINDER, help="Additional arguments for the algorithm.")


args = parser.parse_args()

try:
# Read the 4D NIfTI file
data, _ = read_nifti_file(args.input_file)
data = data[0::4, 0::4, 0::2, :]
print(f'data.shape {data.shape}')

# Read the b-vector, and b-value files
bvecs = read_bvec_file(args.bvec_file)
bvals = read_bval_file(args.bval_file)
# print(f'bvals.size {bvals.shape} bvecs.size {bvecs.shape}')
print(bvals)
print(bvecs)
shells, bval_indices, b0_indices = find_shells(bvals)
num_b0 = np.count_nonzero(b0_indices)
print(shells)
print(bval_indices)
# print(b0_indices)
# print(bvecs)

# print('vectors')
vectors, bvec_indices, groups = find_directions(bvecs, b0_indices)
print(vectors)
print(bvec_indices)
print(f'groups {groups}')

# split_bval_bvec(bvec_indices, num_vectors)
# quit()


# Pass additional arguments to the algorithm

fit = OsipiBase(algorithm=args.algorithm)
f_image = []
Dp_image = []
D_image = []
fit = OsipiBase(algorithm=args.algorithm, **args.algorithm_args)

# n = data.ndim
output_shape = list(data.shape[:-1])
# if args.group_directions:
# input_data = data
# input_bvals = np.atleast_2d(bvals)
# else:
# num_directions = groups.shape[1]
# measurements = np.count_nonzero(groups[:, 0])
# print(f"group_length {num_directions}")
# input_shape = output_shape.copy()

# print(f"groups[:, 0] {groups[:,0]} {np.count_nonzero(groups[:, 0])}")
# input_shape.append(num_directions)
# input_shape.append(measurements)
# output_shape.append(num_directions)
# print(f"input_shape {input_shape}")
# input_data = np.zeros(input_shape)
# input_bvals = np.zeros([measurements, num_directions])
# for group_idx in range(num_directions):
# print(f"group {group_idx} {groups[:, group_idx]}")
# input_data[..., group_idx, :] = data[..., groups[:, group_idx]]
# input_bvals[:, group_idx] = bvals[groups[:, group_idx]]
# if args.group_directions:
# input_data = data
# input_bvals = np.atleast_2d(bvals)
# else:
input_data = data
input_bvals = np.atleast_2d(bvals)
b0_indices = np.atleast_2d(b0_indices)
print(f"data.shape {data.shape}")
print(f"input_data.shape {input_data.shape}")




voxel_iteration = np.prod(output_shape)
group_iteration = groups.shape[1]
total_iteration = voxel_iteration * group_iteration
output_shape.append(group_iteration)
f_image = np.zeros(output_shape)
Dp_image = np.zeros(output_shape)
D_image = np.zeros(output_shape)
print(f_image.shape)

# This is necessary for the tqdm to display progress bar.
n = data.ndim
total_iteration = np.prod(data.shape[:n-1])
for idx, view in tqdm(loop_over_first_n_minus_1_dimensions(data), desc=f"{args.algorithm} is fitting", dynamic_ncols=True, total=total_iteration):
[f_fit, Dp_fit, D_fit] = fit.osipi_fit(view, bvals)
f_image.append(f_fit)
Dp_image.append(Dp_fit)
D_image.append(D_fit)

# Convert lists to NumPy arrays
f_image = np.array(f_image)
Dp_image = np.array(Dp_image)
D_image = np.array(D_image)

# Reshape arrays if needed
f_image = f_image.reshape(data.shape[:data.ndim-1])
Dp_image = Dp_image.reshape(data.shape[:data.ndim-1])
D_image = D_image.reshape(data.shape[:data.ndim-1])

# total_iteration = np.prod(data.shape[:n-1])
print(f'input_bvals {input_bvals}')
print(f'voxel_iteration {voxel_iteration} input_data.shape {input_data.shape}')
# print(f'input_data[5000] {input_data.reshape(total_iteration, -1)[5000]}')


# fit_partial = partial(osipi_fit, fit.osipi_fit, input_bvals, input_data.reshape(total_iteration, -1), f_image.reshape(total_iteration), Dp_image.reshape(total_iteration), D_image.reshape(total_iteration))
fit_partial = partial(osipi_fit, fit.osipi_fit)


if args.nproc >= 0:
print('multiprocess fitting')
gd = generate_data(input_data, input_bvals, b0_indices, groups, voxel_iteration)
map_args = [fit_partial, gd]
chunksize = round(total_iteration / args.nproc) if args.nproc > 0 else round(total_iteration / 128)
print(f'chunksize {chunksize}')
map_kwargs = {'desc':f"{args.algorithm} is fitting", 'dynamic_ncols':True, 'total':total_iteration, 'chunksize':chunksize}
if args.nproc > 0:
map_kwargs['max_workers'] = args.nproc
result = process_map(*map_args, **map_kwargs)
output = np.asarray(result)
print(f'output.shape {output.shape}')
output = output.reshape([*output_shape, 3])
f_image = output[..., 0]
print(f'f_img.shape {f_image.shape}')
Dp_image = output[..., 1]
D_image = output[..., 2]
# print(result)
# if args.nproc == 0: # TODO: can this be done more elegantly, I just want to omit a single parameter here
# process_map(fit_partial, range(total_iteration), desc=f"{args.algorithm} is fitting", dynamic_ncols=True, total=total_iteration)
# else:
# process_map(fit_partial, range(total_iteration), max_workers=args.nproc, desc=f"{args.algorithm} is fitting", dynamic_ncols=True, total=total_iteration)
else:
for idx, view in tqdm(loop_over_first_n_minus_1_dimensions(data), desc=f"{args.algorithm} is fitting", dynamic_ncols=True, total=total_iteration):
[f_fit, Dp_fit, D_fit] = fit.osipi_fit(view, bvals)
f_image[idx] = f_fit
Dp_image[idx] = Dp_fit
D_image[idx] = D_fit

print("finished fitting")

save_nifti_file(f_image, "f.nii.gz", args.affine)
save_nifti_file(Dp_image, "dp.nii.gz", args.affine)
Expand Down
46 changes: 46 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import pathlib
import json
import csv
import tempfile
import os
import random
import numpy as np
# import datetime


Expand Down Expand Up @@ -178,3 +182,45 @@ def data_list(filename):
bvals = bvals['bvalues']
for name, data in all_data.items():
yield name, bvals, data

@pytest.fixture
def bval_bvec_info():
shells = [0, 10, 20, 50, 100, 200, 500, 1000]
# random.shuffle(shells)
bvals = np.concatenate((shells, random.choices(shells, k=10)), axis=0)

vecs = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0.707, 0.707, 0], [0.5, 0.5, 0.5], [0, 0.707, 0.707], [0.707, 0, 0.707]]
for idx in range(len(vecs)):
if np.linalg.norm(vecs[idx]) != 0:
vecs[idx] = vecs[idx]/np.linalg.norm(vecs[idx])
bvecs = []
vecs_idx = 1 # the first index is needed for the true output, but not needed here
for idx in range(len(bvals)):
if bvals[idx] == 0:
bvecs.append(np.asarray([0, 0, 0]))
elif vecs_idx < len(vecs):
bvecs.append(vecs[vecs_idx])
vecs_idx += 1
else:
bvecs.append(random.choice(vecs[1:])) # don't put a b0 in where it shouldn't be
print(f'raw bvals {bvals}')
print(f'raw bvecs {bvecs}')


with tempfile.NamedTemporaryFile(mode='wt', delete=False) as fp_val, tempfile.NamedTemporaryFile(mode='wt', delete=False) as fp_vec:
writer = csv.writer(fp_val, delimiter=' ')
for bval in bvals:
writer.writerow((bval,))
fp_val.close()
writer = csv.writer(fp_vec, delimiter=' ')
for bvec in bvecs:
writer.writerow(bvec)
fp_vec.close()
yield (fp_val.name, np.asarray(shells), bvals, fp_vec.name, np.asarray(vecs), np.asarray(bvecs))
os.unlink(fp_val.name) # use NamedTemporaryFile with delete_on_close with later python versions
os.path.exists(fp_val.name)
os.unlink(fp_vec.name) # use NamedTemporaryFile with delete_on_close with later python versions
os.path.exists(fp_vec.name)



Empty file added tests/utilities/__init__.py
Empty file.
Empty file.
54 changes: 54 additions & 0 deletions tests/utilities/unit_tests/test_diffusion_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import numpy.testing as npt
from utilities.process.file_io import read_bval_file, read_bvec_file
from utilities.process.diffusion_utils import find_shells, find_directions, normalize_series


#TODO: test without b0
#TODO: test symmetry
def test_read_bval_bvec(bval_bvec_info):
bval_name, shells, bvals, bvec_name, directions, bvecs = bval_bvec_info
saved_bvals = read_bval_file(bval_name)
npt.assert_equal(bvals, np.asarray(saved_bvals))
saved_shells, bval_indices, b0 = find_shells(saved_bvals)
npt.assert_equal(shells, saved_shells, "Shells do not match")
npt.assert_equal(saved_bvals, [saved_shells[index] for index in bval_indices], "Bvalue indices are incorrect")

saved_bvecs = read_bvec_file(bvec_name)
npt.assert_allclose(np.asarray(bvecs), np.asarray(saved_bvecs), err_msg="Incorrectly saved bvectors")
vectors, bvec_indices, groups = find_directions(saved_bvecs, b0)
assert vectors.shape[0] == groups.shape[1] + 1, "Number of vectors is correct"
assert vectors.shape == np.asarray(directions).shape, "Number of elements in directions does not match"
directions_set = set()
for direction in directions:
directions_set.add(tuple(direction))
vectors_set = set()
for vector in vectors:
vectors_set.add(tuple(vector))
assert directions_set == vectors_set, "Elements in directions does not match"
npt.assert_equal(saved_bvecs, [vectors[index] for index in bvec_indices], "Bvector indices are incorrect")

def test_normalization():
original = np.atleast_2d([[10, 10], [10, 10], [5, 5], [5, 5]]).T

indices = [True, False, False, False]
updated = normalize_series(original.copy(), indices)
npt.assert_allclose(original / 10, updated, err_msg="Normalization with 1 point failed")

indices = [True, True, False, False]
updated = normalize_series(original.copy(), indices)
npt.assert_allclose(original / 10, updated, err_msg="Normalization with 2 points failed")

indices = [False, True, True, False]
updated = normalize_series(original.copy(), indices)
npt.assert_allclose(original / 7.5, updated, err_msg="Normalization with 2 different points failed")

indices = [False, False, False, True]
updated = normalize_series(original.copy(), indices)
npt.assert_allclose(original / 5, updated, err_msg="Normalization with 1 final point failed")

original = np.asarray([10, 5])
indices = [True, False]

updated = normalize_series(original.copy(), indices)
npt.assert_allclose(original / 10, updated, err_msg="Normalization of 1D failed")
23 changes: 23 additions & 0 deletions tests/utilities/unit_tests/test_file_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import tempfile
import os
import numpy as np
import numpy.testing as npt
from utilities.process.file_io import save_nifti_file, read_nifti_file, read_bval_file, read_bvec_file


def test_nifti_read_write():
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, 'my_nifti.nii.gz')
data = np.random.rand(7, 8, 9)
save_nifti_file(data, path)
assert os.path.exists(path), "Nifti file does not exist"
saved_data, saved_hdr = read_nifti_file(path)
npt.assert_equal(data, saved_data, "Nifti data does not match")

def test_read_bval_bvec(bval_bvec_info):
bval_name, shells, bvals, bvec_name, directions, bvecs = bval_bvec_info
assert bvecs.shape[1] == 3, "Bvec input is not Nx3"
saved_bvals = read_bval_file(bval_name)
npt.assert_equal(bvals, np.asarray(saved_bvals), "Bvalues do not match")
saved_bvecs = read_bvec_file(bvec_name)
npt.assert_allclose(bvecs, np.asarray(saved_bvecs), err_msg="Bvectors do not match")
Empty file added utilities/process/__init__.py
Empty file.
Loading
Loading