Skip to content

Commit

Permalink
Cleanup old code
Browse files Browse the repository at this point in the history
  • Loading branch information
multimeric committed Sep 11, 2023
1 parent 23224b1 commit 17e10b6
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 1,072 deletions.
75 changes: 48 additions & 27 deletions core/lls_core/lattice_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from lls_core.types import ArrayLike
from napari_workflows import Workflow
from napari.types import ShapesData
from xarray import DataArray

if TYPE_CHECKING:
import pyclesperanto_prototype as cle
Expand Down Expand Up @@ -75,7 +76,7 @@ def save_image(self):
for roi, roi_results in groupby(self.slices, key=lambda it: it.roi_index):
if self.lattice_data.save_type == SaveFileType.h5:
bdv_writer = npy2bdv.BdvWriter(
make_filename_prefix(prefix=self.lattice_data.save_name, roi_index=roi),
make_filename_prefix(prefix=self.lattice_data.save_name, roi_index=roi) + ".h5",
compression='gzip',
nchannels=len(self.lattice_data.channel_range),
subsamp=((1, 1, 1), (1, 2, 2), (2, 4, 4)),
Expand Down Expand Up @@ -180,7 +181,7 @@ def default_time_range(cls, v: Any, values: dict) -> range:
Sets the default time range if undefined
"""
if v is None:
return range(values["dims"].T + 1)
return range(values["data"].sizes["T"] + 1)
return v

@validator("channel_range")
Expand All @@ -189,20 +190,12 @@ def default_channel_range(cls, v: Any, values: dict) -> range:
Sets the default channel range if undefined
"""
if v is None:
return range(values["dims"].C + 1)
return range(values["data"].sizes["C"] + 1)
return v

class DeskewParams(DefaultMixin, arbitrary_types_allowed=True):
#: A 3-5D array containing the image data
data: ArrayLike

#: Dimensions of `data`
dims: Dimensions

#: Dimensions of the deskewed output
deskew_vol_shape: Tuple[int, ...] = Field(init_var=False)

deskew_affine_transform: cle.AffineTransform3D = Field(init_var=False)
data: DataArray

#: Geometry of the light path
skew: DeskewDirection = DeskewDirection.Y
Expand All @@ -211,17 +204,42 @@ class DeskewParams(DefaultMixin, arbitrary_types_allowed=True):
#: Pixel size in microns
physical_pixel_sizes: DefinedPixelSizes = Field(default_factory=DefinedPixelSizes)

#: Dimensions of the deskewed output
deskew_vol_shape: Tuple[int, ...] = Field(init_var=False, default=None)

deskew_affine_transform: cle.AffineTransform3D = Field(init_var=False, default=None)

@property
def dims(self):
return self.data.dims

@validator("data", pre=True)
def reshaping(cls, v: Any):
# This allows a user to pass in any array-like object and have it
# converted and reshaped appropriately
array = DataArray(v)
if not set(array.dims).issuperset({"X", "Y", "Z"}):
raise ValueError("The input array must at least have XYZ coordinates")
if "T" not in array.dims:
array = array.expand_dims("T")
if "C" not in array.dims:
array = array.expand_dims("C")
return array.transpose("T", "C", "Z", "Y", "X")

def get_3d_slice(self) -> DataArray:
return self.data.sel(C=0, T=0)

@root_validator(pre=True)
def set_deskew(cls, values: dict) -> dict:
"""
Sets the default deskew shape values if the user has not provided them
"""
# process the file to get shape of final deskewed image
data = values["data"]
data: DataArray = cls.reshaping(values["data"])
if values.get('deskew_vol_shape') is None:
if values.get('deskew_affine_transform') is None:
# If neither has been set, calculate them ourselves
values["deskew_vol_shape"], values["deskew_affine_transform"] = get_deskewed_shape(values["data"], values["angle"], values["physical_pixel_sizes"].X, values["physical_pixel_sizes"].Y, values["physical_pixel_sizes"].Z, values["skew"])
values["deskew_vol_shape"], values["deskew_affine_transform"] = get_deskewed_shape(data.sel(C=0, T=0).to_numpy(), values["angle"], values["physical_pixel_sizes"].X, values["physical_pixel_sizes"].Y, values["physical_pixel_sizes"].Z, values["skew"])
else:
raise ValueError("deskew_vol_shape and deskew_affine_transform must be either both specified or neither specified")
return values
Expand Down Expand Up @@ -249,7 +267,7 @@ def disjoint_time_range(cls, v: range, values: dict):
"""
Validates that the time range is within the range of channels in our array
"""
max_time = values["dims"].T
max_time = values["data"].sizes["T"]
if v.start < 0:
raise ValueError("The lowest valid start value is 0")
if v.stop > max_time:
Expand All @@ -261,7 +279,7 @@ def disjoint_channel_range(cls, v: range, values: dict):
"""
Validates that the channel range is within the range of channels in our array
"""
max_channel = values["dims"].T
max_channel = values["data"].sizes["C"]
if v.start < 0:
raise ValueError("The lowest valid start value is 0")
if v.stop > max_channel:
Expand All @@ -270,13 +288,15 @@ def disjoint_channel_range(cls, v: range, values: dict):

@validator("channel_range")
def channel_range_subset(cls, v: range, values: dict):
if min(v) < 0 or max(v) > values["dims"].C:
if min(v) < 0 or max(v) > values["data"].sizes["C"]:
raise ValueError("The output channel range must be a subset of the total available channels")
return v

@validator("time_range")
def time_range_subset(cls, v: range, values: dict):
if min(v) < 0 or max(v) > values["dims"].T:
raise ValueError("The output time range must be a subset of the total available time points")
if min(v) < 0 or max(v) > values["data"].sizes["T"]:
raise ValueError("The output time range must be a subset of the total available time points")
return v

# Hack to ensure that .skew_dir behaves identically to .skew
@property
Expand Down Expand Up @@ -343,12 +363,12 @@ def deconv_enabled(self) -> bool:
@property
def time(self) -> int:
"""Number of time points"""
return self.dims.T
return self.data.sizes["T"]

@property
def channels(self) -> int:
"""Number of channels"""
return self.dims.C
return self.data.sizes["C"]

@property
def new_dz(self):
Expand All @@ -358,17 +378,19 @@ def __post_init__(self):
logger.info(f"Channels: {self.channels}, Time: {self.time}")
logger.info("If channel and time need to be swapped, you can enforce this by choosing 'Last dimension is channel' when initialising the plugin")

def slice_data(self, time: int, channel: int) -> ArrayLike:
def slice_data(self, time: int, channel: int) -> DataArray:
if time > self.time:
raise ValueError("time is out of range")
if channel > self.channels:
raise ValueError("channel is out of range")

if len(self.dims.shape) == 3:
return self.data.sel(T=time, C=channel)

if len(self.data.shape) == 3:
return self.data
elif len(self.dims.shape) == 4:
elif len(self.data.shape) == 4:
return self.data[time, :, :, :]
elif len(self.dims.shape) == 5:
elif len(self.data.shape) == 5:
return self.data[time, channel, :, :, :]

raise Exception("Lattice data must be 3-5 dimensions")
Expand Down Expand Up @@ -511,8 +533,7 @@ def process(self) -> ProcessedSlices:
)

class AicsLatticeParams(TypedDict):
data: DaskArray
dims: Dimensions
data: DataArray
physical_pixel_sizes: DefinedPixelSizes

def lattice_params_from_aics(img: AICSImage, physical_pixel_sizes: PhysicalPixelSizes = PhysicalPixelSizes(None, None, None)) -> AicsLatticeParams:
Expand Down
Loading

0 comments on commit 17e10b6

Please sign in to comment.