

Question about ImageStackDataset #36

Open
JoOkuma opened this issue Jun 21, 2021 · 4 comments

JoOkuma commented Jun 21, 2021

I was playing with torch-em and ran into some issues regarding the shape of the images.

I noticed that the number of dimensions is fixed to 3 when using the ImageStackDataset:

self._ndim = 3

Is this intentional? When I changed it to the length of the shape, my script worked.
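Roughly, the change amounts to deriving the dimensionality from the data instead of hard-coding it (a sketch with an illustrative shape; the variable names are not the actual ImageStackDataset attributes):

```python
# Hard-coded, as in ImageStackDataset today:
ndim = 3

# Derived from the data instead (shape is illustrative, e.g. 3D + time):
shape = (5, 10, 512, 512)
ndim = len(shape)
print(ndim)  # 4
```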

Thanks in advance,

@constantinpape (Owner)

Yes, that's intentional. The idea in torch_em is that the images are stacked along a first axis (like a virtual z axis).
In torch_em you should then set patch_shape=(1, SHAPE_Y, SHAPE_X), so that images are loaded as 2d slices.
(They will then be squeezed in the ImageStackDataset, so that 2d models can work with it; this solution is a bit hacky, but it worked best in combination with the other logic in torch_em.)
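As a minimal sketch of this idea (using numpy stand-ins, not the actual torch_em internals; all shapes here are illustrative):

```python
import numpy as np

# A stack of 2D images treated as a virtual 3D volume (z, y, x).
images = [np.random.rand(512, 512) for _ in range(10)]
stack = np.stack(images, axis=0)      # shape (10, 512, 512)

# A patch_shape of (1, SHAPE_Y, SHAPE_X) selects a single z slice ...
patch = stack[0:1, 0:256, 0:256]      # shape (1, 256, 256)

# ... which is then squeezed so 2d models see a plain (y, x) array.
patch_2d = patch.squeeze(axis=0)
print(patch_2d.shape)                 # (256, 256)
```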

I hope this answers your question; if not, it would be helpful to have a small example that shows how things go wrong for your use case.


JoOkuma commented Jun 21, 2021

My use case is that I have 3D + time data and I'm trying to process it as 2D slices.

I'm using it as shown below; it works after changing the ndim of the ImageStackDataset:

    from pathlib import Path

    import torch_em
    from torch_em.model import UNet2d

    model = UNet2d(in_channels=1, out_channels=2)
    path = Path('<my directory>')

    images_key = 'images'
    labels_key = 'labels'

    label_transform = torch_em.transform.BoundaryTransform(
        add_binary_target=True, ndim=2
    )

    def transform(x, y):
        return x.squeeze(), y.squeeze()

    # training and validation data loader
    train_loader = torch_em.default_segmentation_loader(
        str(path / images_key), "*.tif",
        str(path / labels_key), "*.tif",
        batch_size=16, patch_shape=(1, 1, 256, 256),
        transform=transform,
        label_transform2=label_transform,
        n_samples=250,
        ndim=2,
    )

Without the ndim changes I'm limited to using a 3-dimensional patch due to this assertion, which indexes the data incorrectly because it tries to access the 4D array with a 3D slice.

Removing the array-and-patch-shape assertion and using a 4-dimensional patch also works.

@constantinpape (Owner)

Ok, I see. This probably happens because I haven't taken 4D datasets into account. I don't have much time to look into this right now, but it looks like removing the assertion fixes the issue for you for now.
If you want you can make a PR with that change; I have some time to take a closer look at it next week (to think about whether removing the assertion might cause problems elsewhere).


JoOkuma commented Jun 24, 2021

Thanks, I will do that :)
