Skip to content

Commit

Permalink
Include cropping regions in iter_slices
Browse files Browse the repository at this point in the history
  • Loading branch information
multimeric committed Aug 23, 2024
1 parent b33cc4c commit 6407456
Showing 1 changed file with 64 additions and 51 deletions.
115 changes: 64 additions & 51 deletions core/lls_core/models/lattice_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from typing import Tuple
# class for initializing lattice data and setting metadata
# TODO: handle scenes
from pydantic import Field, root_validator, validator
Expand All @@ -20,6 +21,7 @@

if TYPE_CHECKING:
from lls_core.models.results import ImageSlice, ImageSlices, ProcessedSlice
from lls_core.cropping import Roi
from lls_core.writers import Writer
from xarray import DataArray

Expand Down Expand Up @@ -229,28 +231,36 @@ def slice_data(self, time: int, channel: int) -> DataArray:

return self.input_image.isel(T=time, C=channel)

def iter_slices(self) -> Iterable[ProcessedSlice[ArrayLike]]:
def iter_roi_indices(self) -> Iterable[Optional[int]]:
"""
Yields array slices for each time and channel of interest.
Params:
progress: If the progress bar is enabled
Yields region of interest indices, with a progress bar.
This yields `None` exactly once if cropping is disabled, for compatibility.
"""
from tqdm import tqdm
if self.cropping_enabled:
for i, _roi in tqdm(enumerate(self.crop.selected_rois), desc="ROI", position=0):
yield i
else:
yield None

Returns:
An iterable of tuples. Each tuple contains (time_index, time, channel_index, channel, slice)
def iter_slices(self) -> Iterable[ProcessedSlice[ArrayLike]]:
"""
Yields 3D array slices for each time, channel and region of interest.
"""
from lls_core.models.results import ProcessedSlice
from tqdm import tqdm

for time_idx, time in tqdm(enumerate(self.time_range), desc="Timepoints", total=len(self.time_range)):
for ch_idx, ch in tqdm(enumerate(self.channel_range), desc="Channels", total=len(self.channel_range), leave=False):
yield ProcessedSlice(
data=self.slice_data(time=time, channel=ch),
time_index=time_idx,
time= time,
channel_index=ch_idx,
channel=ch,
)
for roi_index in self.iter_roi_indices():
for time_idx, time in tqdm(enumerate(self.time_range), desc="Timepoints", total=len(self.time_range)):
for ch_idx, ch in tqdm(enumerate(self.channel_range), desc="Channels", total=len(self.channel_range), leave=False):
yield ProcessedSlice(
data=self.slice_data(time=time, channel=ch),
roi_index=roi_index,
time_index=time_idx,
time=time,
channel_index=ch_idx,
channel=ch,
)

@property
def n_slices(self) -> int:
Expand All @@ -267,13 +277,21 @@ def iter_sublattices(self, update_with: dict = {}) -> Iterable[ProcessedSlice[La
update_with: dictionary of arguments to update the generated lattices with
"""
for subarray in self.iter_slices():

if subarray.roi_index is not None and self.crop is not None:
crop = self.crop.copy_validate(update = {
"roi_subset": [subarray.roi_index]
})
else:
crop = None
new_lattice = self.copy_validate(update={
"input_image": subarray.data,
"time_range": range(1),
"channel_range": range(1),
"crop": crop,
**update_with
})
yield subarray.copy_with_data( new_lattice)
yield subarray.copy_with_data(new_lattice)

def generate_workflows(
self,
Expand Down Expand Up @@ -332,44 +350,39 @@ def _process_crop(self) -> Iterable[ImageSlice]:
"""
Yields processed image slices with cropping enabled
"""
from tqdm import tqdm
if self.crop is None:
raise Exception("This function can only be called when crop is set")

# We have an extra level of iteration for the crop path: iterating over each ROI
for roi_index, roi in enumerate(tqdm(self.crop.selected_rois, desc="ROI", position=0)):
# pass arguments for save tiff, callable and function arguments
logger.info(f"Processing ROI {roi_index}")

for slice in self.iter_slices():
deconv_args: dict[Any, Any] = {}
if self.deconvolution is not None:
deconv_args = dict(
num_iter = self.deconvolution.psf_num_iter,
psf = self.deconvolution.psf[slice.channel].to_numpy(),
decon_processing=self.deconvolution.decon_processing
)
for slice in self.iter_slices():
deconv_args: dict[Any, Any] = {}
if self.deconvolution is not None:
deconv_args = dict(
num_iter = self.deconvolution.psf_num_iter,
psf = self.deconvolution.psf[slice.channel].to_numpy(),
decon_processing=self.deconvolution.decon_processing
)

yield slice.copy(update={
"data": crop_volume_deskew(
original_volume=slice.data,
deconvolution=self.deconv_enabled,
get_deskew_and_decon=False,
debug=False,
roi_shape=list(roi),
linear_interpolation=True,
voxel_size_x=self.dx,
voxel_size_y=self.dy,
voxel_size_z=self.dz,
angle_in_degrees=self.angle,
deskewed_volume=self.deskewed_volume,
z_start=self.crop.z_range[0],
z_end=self.crop.z_range[1],
**deconv_args
),
"roi_index": roi_index
})

yield slice.copy(update={
"data": crop_volume_deskew(
original_volume=slice.data,
deconvolution=self.deconv_enabled,
get_deskew_and_decon=False,
debug=False,
# There is guaranteed to be exactly one ROI at this stage, due to `iter_slices()`
roi_shape=list(next(iter(self.crop.selected_rois))) if self.crop else None,
linear_interpolation=True,
voxel_size_x=self.dx,
voxel_size_y=self.dy,
voxel_size_z=self.dz,
angle_in_degrees=self.angle,
deskewed_volume=self.deskewed_volume,
z_start=self.crop.z_range[0],
z_end=self.crop.z_range[1],
**deconv_args
),
"roi_index": slice.roi_index
})

def _process_non_crop(self) -> Iterable[ImageSlice]:
"""
Yields processed image slices without cropping
Expand Down

0 comments on commit 6407456

Please sign in to comment.