Skip to content

Commit

Permalink
StrictDataset: force named parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc committed Jul 24, 2024
1 parent cf7bf35 commit 4a975bd
Show file tree
Hide file tree
Showing 10 changed files with 83 additions and 61 deletions.
1 change: 1 addition & 0 deletions mantidimaging/core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class StrictDataset(BaseDataset):
dark_after: ImageStack | None = None

def __init__(self,
*,
sample: ImageStack,
flat_before: ImageStack | None = None,
flat_after: ImageStack | None = None,
Expand Down
10 changes: 5 additions & 5 deletions mantidimaging/core/data/test/strictdataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setUp(self) -> None:

def test_attribute_not_set_returns_none(self):
sample = generate_images()
dataset = StrictDataset(sample)
dataset = StrictDataset(sample=sample)

self.assertIsNone(dataset.flat_before)
self.assertIsNone(dataset.flat_after)
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_partially_incomplete_nexus_rotation_angles(self):
def test_get_stack_data_type_returns_sample(self):
sample = generate_images()
sample_id = sample.id
dataset = StrictDataset(sample)
dataset = StrictDataset(sample=sample)
self.assertEqual(_get_stack_data_type(sample_id, dataset), "Sample")

def test_get_stack_data_type_returns_flat_before(self):
Expand Down Expand Up @@ -279,15 +279,15 @@ def test_get_stack_data_type_returns_dark_after(self):
self.assertEqual(_get_stack_data_type(dark_after_id, dataset), "Dark After")

def test_get_stack_data_type_raises(self):
empty_ds = StrictDataset(generate_images())
empty_ds = StrictDataset(sample=generate_images())
with self.assertRaises(RuntimeError):
_get_stack_data_type("bad-id", empty_ds)

def test_processed_is_true(self):
ds = StrictDataset(generate_images())
ds = StrictDataset(sample=generate_images())
ds.sample.record_operation("", "")
self.assertTrue(ds.is_processed)

def test_processed_is_false(self):
ds = StrictDataset(generate_images())
ds = StrictDataset(sample=generate_images())
self.assertFalse(ds.is_processed)
34 changes: 21 additions & 13 deletions mantidimaging/core/io/test/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_nexus_simple_dataset_save(self):
sample.data *= 12
sample._projection_angles = sample.projection_angles()

sd = StrictDataset(sample)
sd = StrictDataset(sample=sample)
sd.sample.record_operation("", "")

path = "nexus/file/path"
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_nexus_missing_projection_angles_save_as_zeros(self):
flat_before = th.generate_images(shape)
flat_before._projection_angles = flat_before.projection_angles()

sd = StrictDataset(sample, flat_before=flat_before)
sd = StrictDataset(sample=sample, flat_before=flat_before)
path = "nexus/file/path"
sample_name = "sample-name"

Expand All @@ -267,7 +267,11 @@ def test_nexus_complex_processed_dataset_save(self):
image_stacks.append(image_stack)
image_stack._projection_angles = image_stack.projection_angles()

sd = StrictDataset(*image_stacks)
sd = StrictDataset(sample=image_stacks[0],
flat_before=image_stacks[1],
flat_after=image_stacks[2],
dark_before=image_stacks[3],
dark_after=image_stacks[4])
sd.sample.record_operation("", "")

with h5py.File("nexus/file/path", "w", driver="core", backing_store=False) as nexus_file:
Expand Down Expand Up @@ -302,7 +306,11 @@ def test_nexus_unprocessed_dataset_save(self):
image_stacks.append(image_stack)
image_stack._projection_angles = image_stack.projection_angles()

sd = StrictDataset(*image_stacks)
sd = StrictDataset(sample=image_stacks[0],
flat_before=image_stacks[1],
flat_after=image_stacks[2],
dark_before=image_stacks[3],
dark_after=image_stacks[4])

with h5py.File("nexus/file/path", "w", driver="core", backing_store=False) as nexus_file:
saver._nexus_save(nexus_file, sd, "sample-name", True)
Expand All @@ -319,7 +327,7 @@ def test_nexus_unprocessed_dataset_save(self):
def test_h5py_os_error_returns(self, nexus_save_mock: mock.Mock, file_mock: mock.Mock):
file_mock.side_effect = OSError
with self.assertRaises(RuntimeError):
saver.nexus_save(StrictDataset(th.generate_images()), "path", "sample-name", True)
saver.nexus_save(StrictDataset(sample=th.generate_images()), "path", "sample-name", True)
nexus_save_mock.assert_not_called()

@mock.patch("mantidimaging.core.io.saver.h5py.File")
Expand All @@ -329,22 +337,22 @@ def test_failed_nexus_save_deletes_file(self, os_mock: mock.Mock, nexus_save_moc
nexus_save_mock.side_effect = OSError
save_path = "failed/save/path"
with self.assertRaises(RuntimeError):
saver.nexus_save(StrictDataset(th.generate_images()), save_path, "sample-name", True)
saver.nexus_save(StrictDataset(sample=th.generate_images()), save_path, "sample-name", True)
file_mock.return_value.close.assert_called_once()
os_mock.remove.assert_called_once_with(save_path)

@mock.patch("mantidimaging.core.io.saver.h5py.File")
@mock.patch("mantidimaging.core.io.saver._nexus_save")
def test_successful_nexus_save_closes_file(self, nexus_save_mock: mock.Mock, file_mock: mock.Mock):
saver.nexus_save(StrictDataset(th.generate_images()), "path", "sample-name", True)
saver.nexus_save(StrictDataset(sample=th.generate_images()), "path", "sample-name", True)
file_mock.return_value.close.assert_called_once()

@mock.patch("mantidimaging.core.io.saver._save_recon_to_nexus")
def test_save_recons_if_present(self, recon_save_mock: mock.Mock):
sample = _create_sample_with_filename()
sample._projection_angles = sample.projection_angles()

sd = StrictDataset(sample)
sd = StrictDataset(sample=sample)
sd.recons.data = [th.generate_images(), th.generate_images()]

with h5py.File("path", "w", driver="core", backing_store=False) as nexus_file:
Expand All @@ -353,7 +361,7 @@ def test_save_recons_if_present(self, recon_save_mock: mock.Mock):
self.assertEqual(recon_save_mock.call_count, len(sd.recons))

def test_save_process(self):
ds = StrictDataset(th.generate_images())
ds = StrictDataset(sample=th.generate_images())
process_path = "processed-data/process"
with h5py.File("path", "w", driver="core", backing_store=False) as nexus_file:
rotation_angle = nexus_file.create_dataset("rotation_angle", dtype="float")
Expand All @@ -373,7 +381,7 @@ def test_dont_save_recons_if_none_present(self, recon_save_mock: mock.Mock):
sample = th.generate_images()
sample._projection_angles = sample.projection_angles()

sd = StrictDataset(sample)
sd = StrictDataset(sample=sample)

with h5py.File("path", "w", driver="core", backing_store=False) as nexus_file:
saver._nexus_save(nexus_file, sd, "sample-name", True)
Expand All @@ -385,7 +393,7 @@ def test_save_recon_to_nexus(self):
sample = _create_sample_with_filename()
sample._projection_angles = sample.projection_angles()

sd = StrictDataset(sample)
sd = StrictDataset(sample=sample)

recon = th.generate_images(seed=2)
recon.metadata[TIMESTAMP] = None
Expand Down Expand Up @@ -424,7 +432,7 @@ def test_use_recon_date_from_image_stack(self):
sample = _create_sample_with_filename()
sample._projection_angles = sample.projection_angles()

sd = StrictDataset(sample)
sd = StrictDataset(sample=sample)

recon = th.generate_images(seed=2)
recon.name = recon_name = "Recon"
Expand Down Expand Up @@ -462,7 +470,7 @@ def test_raw_file_field(self):
self.sample_path)

def test_save_image_stacks_to_nexus_as_int(self):
ds = StrictDataset(th.generate_images())
ds = StrictDataset(sample=th.generate_images())

with h5py.File("path", "w", driver="core", backing_store=False) as nexus_file:
data = nexus_file.create_group("data")
Expand Down
2 changes: 1 addition & 1 deletion mantidimaging/eyes_tests/base_eyes.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _load_strict_data_set(self, set_180: bool = False):
filename_group = FilenameGroup.from_file(Path(LOAD_SAMPLE))
filename_group.find_all_files()
image_stack = loader.load(filename_group, indices=Indices(0, 100, 2))
dataset = StrictDataset(image_stack)
dataset = StrictDataset(sample=image_stack)
image_stack.name = "Stack 1"
vis = self.imaging.presenter.create_strict_dataset_stack_windows(dataset)
self.imaging.presenter.create_strict_dataset_tree_view_items(dataset)
Expand Down
2 changes: 1 addition & 1 deletion mantidimaging/eyes_tests/spectrum_viewer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _generate_spectrum_dataset(self):
sample_stack.name = "Sample Stack"
open_stack = generate_images(seed=666, shape=(20, 10, 10))
open_stack.name = "Open Beam Stack"
dataset = StrictDataset(sample_stack, flat_before=open_stack)
dataset = StrictDataset(sample=sample_stack, flat_before=open_stack)
vis = self.imaging.presenter.create_strict_dataset_stack_windows(dataset)
self.imaging.presenter.create_strict_dataset_tree_view_items(dataset)
self.imaging.presenter.model.add_dataset_to_model(dataset)
Expand Down
2 changes: 1 addition & 1 deletion mantidimaging/gui/windows/main/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def load(im_param: ImageParameters) -> ImageStack:
return loader.load_stack_from_image_params(im_param, progress, dtype=parameters.dtype)

sample = load(parameters.image_stacks[FILE_TYPES.SAMPLE])
ds = StrictDataset(sample)
ds = StrictDataset(sample=sample)
sample._is_sinograms = parameters.sinograms
sample.pixel_size = parameters.pixel_size

Expand Down
46 changes: 25 additions & 21 deletions mantidimaging/gui/windows/main/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_do_load_stack_sample_and_flat(self, dataset_mock: mock.Mock, load_mock:
mock.call(flat_after_mock, progress_mock, dtype=lp.dtype)
])

dataset_mock.assert_called_with(sample_images_mock)
dataset_mock.assert_called_with(sample=sample_images_mock)

ds_mock.set_stack.assert_has_calls([
mock.call(FILE_TYPES.FLAT_BEFORE, flatb_images_mock),
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_do_load_stack_sample_and_dark_and_180deg(self, dataset_mock: mock.Mock,
mock.call(proj_180deg_mock, progress_mock, dtype=lp.dtype),
])

dataset_mock.assert_called_with(sample_images_mock)
dataset_mock.assert_called_with(sample=sample_images_mock)

ds_mock.set_stack.assert_has_calls([
mock.call(FILE_TYPES.DARK_BEFORE, darkb_images_mock),
Expand Down Expand Up @@ -223,7 +223,7 @@ def test_add_log_to_sample_no_stack(self, load_log: mock.Mock):
def test_add_180_deg_to_dataset(self, load: mock.Mock):
_180_file = "180 file"
dataset_id = "id"
self.model.datasets[dataset_id] = dataset_mock = StrictDataset(generate_images())
self.model.datasets[dataset_id] = dataset_mock = StrictDataset(sample=generate_images())
load.return_value = _180_stack = generate_images()
self.model.add_180_deg_to_dataset(dataset_id=dataset_id, _180_deg_file=_180_file)

Expand Down Expand Up @@ -316,7 +316,11 @@ def test_remove_dataset_from_model(self):
images = [generate_images() for _ in range(5)]
ids = [image_stack.id for image_stack in images]

ds = StrictDataset(*images)
ds = StrictDataset(sample=images[0],
flat_before=images[1],
flat_after=images[2],
dark_before=images[3],
dark_after=images[4])
self.model.datasets[ds.id] = ds

stacks_to_close = self.model.remove_container(ds.id)
Expand All @@ -329,7 +333,7 @@ def test_failed_remove_container(self):

def test_remove_empty_dataset_from_model(self):
sample = generate_images()
ds = StrictDataset(sample)
ds = StrictDataset(sample=sample)
self.model.datasets[ds.id] = ds

self.model.remove_container(sample.id)
Expand All @@ -342,7 +346,7 @@ def test_remove_non_sample_images_from_dataset_with_sample(self):
images = [generate_images() for _ in range(2)]
# Set the sample 180 to check this isn't removed
images[0].proj180deg = generate_images()
ds = StrictDataset(*images)
ds = StrictDataset(sample=images[0], flat_before=images[1])
self.model.datasets[ds.id] = ds
id_to_remove = images[-1].id

Expand All @@ -353,7 +357,7 @@ def test_remove_non_sample_images_from_dataset_with_sample(self):

def test_remove_non_sample_images_from_dataset_without_sample(self):
images = [generate_images() for _ in range(2)]
ds = StrictDataset(*images)
ds = StrictDataset(sample=images[0], flat_before=images[1])
ds.sample = None
self.model.datasets[ds.id] = ds
id_to_remove = images[-1].id
Expand All @@ -366,7 +370,7 @@ def test_remove_non_sample_images_from_dataset_without_sample(self):
def test_remove_sample_with_180_from_dataset(self):
sample = generate_images()
sample.proj180deg = generate_images()
ds = StrictDataset(sample)
ds = StrictDataset(sample=sample)
self.model.datasets[ds.id] = ds

expected_result = [sample.id, sample.proj180deg.id]
Expand All @@ -377,7 +381,7 @@ def test_remove_sample_with_180_from_dataset(self):

def test_remove_sample_without_180_from_dataset(self):
sample = generate_images()
ds = StrictDataset(sample)
ds = StrictDataset(sample=sample)
self.model.datasets[ds.id] = ds

expected_result = [sample.id]
Expand All @@ -399,7 +403,7 @@ def test_remove_images_from_mixed_dataset(self):
self.assertListEqual([id_to_remove], deleted_stacks)

def test_add_dataset_to_model(self):
ds = StrictDataset(generate_images())
ds = StrictDataset(sample=generate_images())
self.model.add_dataset_to_model(ds)
self.assertIn(ds, self.model.datasets.values())

Expand All @@ -408,14 +412,14 @@ def test_image_ids(self):
for _ in range(3):
images = [generate_images() for _ in range(3)]
all_ids += [image.id for image in images]
ds = StrictDataset(*images)
ds = StrictDataset(sample=images[0], flat_before=images[1], flat_after=images[2])
self.model.add_dataset_to_model(ds)
self.assertListEqual(all_ids, self.model.image_ids)

def test_add_recon_to_dataset(self):
sample = generate_images()
sample_id = sample.id
ds = StrictDataset(sample)
ds = StrictDataset(sample=sample)

recon = generate_images()
self.model.add_dataset_to_model(ds)
Expand All @@ -425,8 +429,8 @@ def test_add_recon_to_dataset(self):

def test_proj180s(self):

ds1 = StrictDataset(generate_images())
ds2 = StrictDataset(generate_images())
ds1 = StrictDataset(sample=generate_images())
ds2 = StrictDataset(sample=generate_images())
ds3 = MixedDataset(stacks=[generate_images()])

proj180s = [ImageStack(ds1.sample.data[0]), ImageStack(ds2.sample.data[0])]
Expand All @@ -444,18 +448,18 @@ def test_exception_when_dataset_for_recons_not_found(self):
self.model.add_recon_to_dataset(generate_images(), "bad-id")

def test_get_parent_strict_dataset_success(self):
ds = StrictDataset(generate_images())
ds = StrictDataset(sample=generate_images())
self.model.add_dataset_to_model(ds)
self.assertIs(self.model.get_parent_dataset(ds.sample.id), ds.id)

def test_get_parent_dataset_doesnt_find_any_parent(self):
ds = StrictDataset(generate_images())
ds = StrictDataset(sample=generate_images())
self.model.add_dataset_to_model(ds)
with self.assertRaises(RuntimeError):
self.model.get_parent_dataset("unrecognised-id")

def test_delete_all_recons_in_dataset(self):
ds = StrictDataset(generate_images())
ds = StrictDataset(sample=generate_images())
[ds.add_recon(generate_images()) for _ in range(3)]
recon_ids = ds.recons.ids
self.model.add_dataset_to_model(ds)
Expand Down Expand Up @@ -490,14 +494,14 @@ def test_wrong_dataset_type_for_180_raises(self):
self.model.get_existing_180_id(md.id)

def test_get_existing_180_id_finds_id(self):
sd = StrictDataset(generate_images((5, 20, 20)))
sd = StrictDataset(sample=generate_images((5, 20, 20)))
sd.proj180deg = _180 = generate_images((1, 20, 20))
self.model.add_dataset_to_model(sd)

assert self.model.get_existing_180_id(sd.id) == _180.id

def test_get_existing_id_returns_none_for_dataset_without_180(self):
sd = StrictDataset(generate_images((5, 20, 20)))
sd = StrictDataset(sample=generate_images((5, 20, 20)))
self.model.add_dataset_to_model(sd)

self.assertIsNone(self.model.get_existing_180_id(sd.id))
Expand Down Expand Up @@ -526,7 +530,7 @@ def test_do_nexus_saving_fails_from_wrong_dataset(self):

@mock.patch("mantidimaging.gui.windows.main.model.saver.nexus_save")
def test_do_nexus_save_success(self, nexus_save):
sd = StrictDataset(generate_images())
sd = StrictDataset(sample=generate_images())
self.model.add_dataset_to_model(sd)
path = "path"
sample_name = "sample-name"
Expand All @@ -536,7 +540,7 @@ def test_do_nexus_save_success(self, nexus_save):
nexus_save.assert_called_once_with(sd, path, sample_name, save_as_float)

def test_is_dataset_strict_returns_true(self):
strict_ds = StrictDataset(generate_images())
strict_ds = StrictDataset(sample=generate_images())
self.model.add_dataset_to_model(strict_ds)
self.assertTrue(self.model.is_dataset_strict(strict_ds.id))

Expand Down
Loading

0 comments on commit 4a975bd

Please sign in to comment.