Skip to content

Commit

Permalink
_get_stack_data_type for BaseDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc committed Jul 19, 2024
1 parent 5c99124 commit 453431e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
9 changes: 4 additions & 5 deletions mantidimaging/core/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def is_processed(self) -> bool:
return False


def _get_stack_data_type(stack_id: uuid.UUID, dataset: MixedDataset | StrictDataset) -> str:
def _get_stack_data_type(stack_id: uuid.UUID, dataset: BaseDataset) -> str:
"""
Find the data type as a string of a stack.
:param stack_id: The ID of the stack.
Expand All @@ -217,10 +217,9 @@ def _get_stack_data_type(stack_id: uuid.UUID, dataset: MixedDataset | StrictData
"""
if stack_id in [recon.id for recon in dataset.recons]:
return "Recon"
if isinstance(dataset, MixedDataset):
if stack_id in dataset:
return "Images"
else:
if stack_id in [stack.id for stack in dataset._stacks]:
return "Images"
if isinstance(dataset, StrictDataset):
if stack_id == dataset.sample.id:
return "Sample"
if dataset.flat_before is not None and stack_id == dataset.flat_before.id:
Expand Down
15 changes: 14 additions & 1 deletion mantidimaging/core/data/test/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from unittest import mock
import uuid

from mantidimaging.core.data.dataset import BaseDataset
from mantidimaging.core.data.dataset import BaseDataset, _get_stack_data_type
from mantidimaging.test_helpers.unit_test_helper import generate_images


Expand Down Expand Up @@ -76,3 +76,16 @@ def test_delete_stack_from_stacks_list(self):
prev_stacks = image_stacks.copy()
ds.delete_stack(image_stacks[-1].id)
self.assertListEqual(ds.all, prev_stacks[:-1])

def test_get_stack_data_type_returns_recon(self):
recon = generate_images()
recon_id = recon.id
dataset = BaseDataset()
dataset.recons.append(recon)
self.assertEqual(_get_stack_data_type(recon_id, dataset), "Recon")

def test_get_stack_data_type_returns_images(self):
images = generate_images()
images_id = images.id
dataset = BaseDataset(stacks=[images])
self.assertEqual(_get_stack_data_type(images_id, dataset), "Images")

0 comments on commit 453431e

Please sign in to comment.