Skip to content

Commit

Permalink
TEST: Parametrize all DerivativesDataSink tests
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Jul 30, 2024
1 parent b297129 commit 4489a68
Showing 1 changed file with 94 additions and 47 deletions.
141 changes: 94 additions & 47 deletions niworkflows/interfaces/tests/test_bids.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,35 @@
BOLD_PATH = "ds054/sub-100185/func/sub-100185_task-machinegame_run-01_bold.nii.gz"


def make_prep_and_save(
prep_interface,
base_directory,
out_path_base=None,
**kwargs,
):
prep = save = prep_interface(
**kwargs,
**({"out_path_base": out_path_base} if prep_interface == bintfs.DerivativesDataSink else {}),
)
if prep_interface is bintfs.DerivativesDataSink:
prep.inputs.base_directory = base_directory
else:
save = bintfs.SaveDerivative(base_directory=base_directory)

return prep, save


def connect_and_run_save(prep_result, save):
if prep_result.interface is bintfs.DerivativesDataSink:
return prep_result

save.inputs.in_file = prep_result.outputs.out_file
save.inputs.relative_path = prep_result.outputs.out_path
save.inputs.metadata = prep_result.outputs.out_meta

return save.run()


@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
@pytest.mark.parametrize("out_path_base", [None, "fmriprep"])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -296,36 +325,26 @@ def test_DerivativesDataSink_build_path(
ds_inputs.append(str(fname))

base_directory = tmp_path / "output"
work_dir = tmp_path / "work"
base_directory.mkdir()
work_dir.mkdir()

prep = save = interface(
prep, save = make_prep_and_save(
interface,
base_directory=str(base_directory),
out_path_base=out_path_base,
in_file=ds_inputs,
source_file=source,
dismiss_entities=dismiss_entities,
**entities,
**({"out_path_base": out_path_base} if interface == bintfs.DerivativesDataSink else {}),
)
if interface == bintfs.DerivativesDataSink:
prep.inputs.base_directory = str(base_directory)
else:
save = bintfs.SaveDerivative(base_directory=str(base_directory))

if isinstance(expectation, type):
with pytest.raises(expectation):
prep.run()
return

prep_outputs = save_outputs = prep.run().outputs

if save is not prep:
save.inputs.in_file = prep_outputs.out_file
save.inputs.relative_path = prep_outputs.out_path
save.inputs.metadata = prep_outputs.out_meta
save_outputs = save.run().outputs
prep_result = prep.run()
save_result = connect_and_run_save(prep_result, save)

output = save_outputs.out_file
output = save_result.outputs.out_file
if isinstance(expectation, str):
expectation = [expectation]
output = [output]
Expand Down Expand Up @@ -363,7 +382,8 @@ def test_DerivativesDataSink_build_path(
assert sha1(Path(out).read_bytes()).hexdigest() == chksum


def test_DerivativesDataSink_dtseries_json(tmp_path):
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
def test_DerivativesDataSink_dtseries_json(tmp_path, interface):
cifti_fname = str(tmp_path / "test.dtseries.nii")

axes = (nb.cifti2.SeriesAxis(start=0, step=2, size=20),
Expand All @@ -378,20 +398,22 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
source_file.parent.mkdir(parents=True)
source_file.touch()

dds = bintfs.DerivativesDataSink(
in_file=cifti_fname,
prep, save = make_prep_and_save(
interface,
base_directory=str(tmp_path),
out_path_base="",
in_file=cifti_fname,
source_file=str(source_file),
compress=False,
out_path_base="",
space="fsLR",
grayordinates="91k",
RepetitionTime=2.0,
)

res = dds.run()
prep_result = prep.run()
save_result = connect_and_run_save(prep_result, save)

out_path = Path(res.outputs.out_file)
out_path = Path(save_result.outputs.out_file)

assert out_path.name == "sub-01_task-rest_space-fsLR_bold.dtseries.nii"
old_sidecar = out_path.with_name("sub-01_task-rest_space-fsLR_bold.dtseries.json")
Expand All @@ -402,6 +424,7 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
assert "RepetitionTime" in json.loads(new_sidecar.read_text())


@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
@pytest.mark.parametrize(
"space, size, units, xcodes, zipped, fixed, data_dtype",
[
Expand All @@ -427,7 +450,7 @@ def test_DerivativesDataSink_dtseries_json(tmp_path):
],
)
def test_DerivativesDataSink_bold(
tmp_path, space, size, units, xcodes, zipped, fixed, data_dtype
tmp_path, interface, space, size, units, xcodes, zipped, fixed, data_dtype
):
fname = str(tmp_path / "source.nii") + (".gz" if zipped else "")

Expand All @@ -438,25 +461,29 @@ def test_DerivativesDataSink_bold(
nb.Nifti1Image(np.zeros(size), np.eye(4), hdr).to_filename(fname)

# BOLD derivative in T1w space
dds = bintfs.DerivativesDataSink(
prep, _ = make_prep_and_save(
interface,
base_directory=str(tmp_path),
keep_dtype=True,
data_dtype=data_dtype or Undefined,
desc="preproc",
source_file=BOLD_PATH,
space=space or Undefined,
in_file=fname,
).run()
)

nii = nb.load(dds.outputs.out_file)
assert dds.outputs.fixed_hdr == fixed
prep_result = prep.run()

nii = nb.load(prep_result.outputs.out_file)
assert prep_result.outputs.fixed_hdr == fixed
if data_dtype:
assert nii.get_data_dtype() == np.dtype(data_dtype)
assert int(nii.header["qform_code"]) == XFORM_CODES[space]
assert int(nii.header["sform_code"]) == XFORM_CODES[space]
assert nii.header.get_xyzt_units() == ("mm", "sec")


@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
@pytest.mark.parametrize(
"space, size, units, xcodes, fixed",
[
Expand All @@ -480,7 +507,7 @@ def test_DerivativesDataSink_bold(
(None, (30, 30, 30), (None, "sec"), (0, 0), [True]),
],
)
def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
def test_DerivativesDataSink_t1w(tmp_path, interface, space, size, units, xcodes, fixed):
fname = str(tmp_path / "source.nii.gz")

hdr = nb.Nifti1Header()
Expand All @@ -490,22 +517,26 @@ def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
nb.Nifti1Image(np.zeros(size), np.eye(4), hdr).to_filename(fname)

# BOLD derivative in T1w space
dds = bintfs.DerivativesDataSink(
prep, _ = make_prep_and_save(
interface,
base_directory=str(tmp_path),
keep_dtype=True,
desc="preproc",
source_file=T1W_PATH,
space=space or Undefined,
in_file=fname,
).run()
)

prep_result = prep.run()

nii = nb.load(dds.outputs.out_file)
assert dds.outputs.fixed_hdr == fixed
nii = nb.load(prep_result.outputs.out_file)
assert prep_result.outputs.fixed_hdr == fixed
assert int(nii.header["qform_code"]) == XFORM_CODES[space]
assert int(nii.header["sform_code"]) == XFORM_CODES[space]
assert nii.header.get_xyzt_units() == ("mm", "unknown")


@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
@pytest.mark.parametrize(
"source_file",
[
Expand All @@ -517,7 +548,7 @@ def test_DerivativesDataSink_t1w(tmp_path, space, size, units, xcodes, fixed):
@pytest.mark.parametrize("source_dtype", ["<i4", "<f4"])
@pytest.mark.parametrize("in_dtype", ["<i4", "<f4"])
def test_DerivativesDataSink_data_dtype_source(
tmp_path, source_file, source_dtype, in_dtype
tmp_path, interface, source_file, source_dtype, in_dtype
):

def make_empty_nii_with_dtype(fname, dtype):
Expand All @@ -539,19 +570,23 @@ def make_empty_nii_with_dtype(fname, dtype):
for s in source_file:
make_empty_nii_with_dtype(s, source_dtype)

dds = bintfs.DerivativesDataSink(
prep, save = make_prep_and_save(
interface,
base_directory=str(tmp_path),
data_dtype="source",
desc="preproc",
source_file=source_file,
in_file=in_file,
).run()
)

prep_result = prep.run()

nii = nb.load(dds.outputs.out_file)
nii = nb.load(prep_result.outputs.out_file)
assert nii.get_data_dtype() == np.dtype(source_dtype)


def test_DerivativesDataSink_fmapid(tmp_path):
@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
def test_DerivativesDataSink_fmapid(tmp_path, interface):
"""Ascertain #637 is not regressing."""
source_file = [
(tmp_path / s)
Expand All @@ -569,7 +604,8 @@ def test_DerivativesDataSink_fmapid(tmp_path):
in_file = tmp_path / "report.svg"
in_file.write_text("")

dds = bintfs.DerivativesDataSink(
prep, save = make_prep_and_save(
interface,
base_directory=str(tmp_path),
datatype="figures",
suffix="fieldmap",
Expand All @@ -579,12 +615,17 @@ def test_DerivativesDataSink_fmapid(tmp_path):
fmapid="auto00000",
source_file=[str(s.absolute()) for s in source_file],
in_file=str(in_file),
).run()
assert dds.outputs.out_file.endswith("sub-36_fmapid-auto00000_desc-pepolar_fieldmap.svg")
)

prep_result = prep.run()
save_result = connect_and_run_save(prep_result, save)

assert save_result.outputs.out_file.endswith("sub-36_fmapid-auto00000_desc-pepolar_fieldmap.svg")


@pytest.mark.parametrize("interface", [bintfs.DerivativesDataSink, bintfs.PrepareDerivative])
@pytest.mark.parametrize("dtype", ("i2", "u2", "f4"))
def test_DerivativesDataSink_values(tmp_path, dtype):
def test_DerivativesDataSink_values(tmp_path, interface, dtype):
# We use static checksums above, which ensures we don't break things, but
# pins the tests to specific values.
# Here we use random values, check that the values are preserved, and then
Expand All @@ -599,16 +640,19 @@ def test_DerivativesDataSink_values(tmp_path, dtype):
orig_data = np.asanyarray(nb.load(fname).dataobj)
expected = np.rint(orig_data) if dtype[0] in "iu" else orig_data

dds = bintfs.DerivativesDataSink(
prep, _ = make_prep_and_save(
interface,
base_directory=str(tmp_path),
keep_dtype=True,
data_dtype=dtype,
desc="preproc",
source_file=T1W_PATH,
in_file=fname,
).run()
)

prep_result = prep.run()

out_file = Path(dds.outputs.out_file)
out_file = Path(prep_result.outputs.out_file)

nii = nb.load(out_file)
assert np.allclose(nii.dataobj, expected)
Expand All @@ -617,14 +661,17 @@ def test_DerivativesDataSink_values(tmp_path, dtype):
out_file.unlink()

# Rerun to ensure determinism with non-zero data
dds = bintfs.DerivativesDataSink(
prep, _ = make_prep_and_save(
interface,
base_directory=str(tmp_path),
keep_dtype=True,
data_dtype=dtype,
desc="preproc",
source_file=T1W_PATH,
in_file=fname,
).run()
)

prep_result = prep.run()

assert sha1(out_file.read_bytes()).hexdigest() == checksum

Expand Down

0 comments on commit 4489a68

Please sign in to comment.