Xarray datatree support #329

Merged: 19 commits (Nov 26, 2024)
7 changes: 6 additions & 1 deletion .github/workflows/test.yml
@@ -23,7 +23,7 @@ jobs:
strategy:
matrix:
platform: [ubuntu-latest, macos-latest] # [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.9", "3.11"]
python-version: ["3.10", "3.12"]

steps:
- uses: actions/checkout@v4
@@ -86,6 +86,11 @@ jobs:
python-version: 3.11
cache-dependency-path: pyproject.toml

- name: Install dependencies with 'pre' extras (since the above doesn't check pre-releases)
run: |
python -m pip install --upgrade pip
pip install .[pre]

- uses: tlambert03/setup-qt-libs@v1

- uses: octokit/[email protected]
2 changes: 1 addition & 1 deletion .mypy.ini
@@ -1,6 +1,6 @@
[mypy]
mypy_path = napari-spatialdata
python_version = 3.9
python_version = 3.10
plugins = numpy.typing.mypy_plugin

ignore_errors = False
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -2,8 +2,8 @@ fail_fast: false
default_language_version:
python: python3
default_stages:
- commit
- push
- pre-commit
- pre-push
minimum_pre_commit_version: 2.9.3
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -7,7 +7,7 @@ write_to = "src/napari_spatialdata/_version.py"

[tool.black]
line-length = 120
target-version = ['py39']
target-version = ['py310']
include = '\.pyi?$'
exclude = '''
(
@@ -38,7 +38,7 @@ exclude = [
"setup.py",
]
line-length = 120
target-version = "py39"
target-version = "py310"
[tool.ruff.lint]
ignore = [
# Do not assign a lambda expression, use a def -> lambda expression assignments are convenient
5 changes: 3 additions & 2 deletions setup.cfg
@@ -20,16 +20,17 @@ classifiers =
Topic :: Software Development :: Testing
Programming Language :: Python
Programming Language :: Python :: 3
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Operating System :: OS Independent
License :: OSI Approved :: BSD License


[options]
packages = find:
include_package_data = True
python_requires = >=3.9.2
python_requires = >=3.10
setup_requires = setuptools_scm
# add your package requirements here
install_requires =
2 changes: 1 addition & 1 deletion src/napari_spatialdata/_reader.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from collections.abc import Callable

from loguru import logger
from spatialdata import SpatialData
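Note on the `Callable` change above (not part of the diff): PEP 585 deprecates `typing.Callable` in favour of `collections.abc.Callable`, which can be used directly in annotations; with the project floor now at Python 3.10, the import is simply moved. A minimal, self-contained sketch of the pattern, using illustrative names rather than the plugin's actual reader API:

```python
from collections.abc import Callable  # PEP 585: preferred over typing.Callable

# Illustrative alias; the real napari-spatialdata reader signature may differ.
Reader = Callable[[str], list[tuple[object, dict, str]]]

def _read(path: str) -> list[tuple[object, dict, str]]:
    # A real reader would open the SpatialData zarr store at ``path`` here.
    return [(None, {"name": path}, "image")]

def get_reader(path: str) -> Reader | None:
    """Return a reader for ``.zarr`` paths, otherwise ``None``."""
    return _read if path.endswith(".zarr") else None
```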
2 changes: 1 addition & 1 deletion src/napari_spatialdata/_scatterwidgets.py
@@ -203,7 +203,7 @@ def setComponent(self, text: int | str | None) -> None:
super().setAdataLayer(text)
elif self.getAttribute() == "obsm":
if TYPE_CHECKING:
assert isinstance(text, (int, str))
assert isinstance(text, int | str)
self.text = text # type: ignore[assignment]
super().setIndex(text)

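For context on the `isinstance` rewrite above: since Python 3.10, `isinstance()` accepts PEP 604 unions (`int | str`) directly, so the tuple form is no longer needed once 3.9 support is dropped. A quick illustration with made-up values:

```python
# Requires Python 3.10+: isinstance() accepts X | Y unions at runtime (PEP 604).
def normalize(text: int | str | None) -> str | None:
    if text is None:
        return None
    assert isinstance(text, int | str)  # equivalent to isinstance(text, (int, str))
    return str(text)

print(normalize(3))       # "3"
print(normalize("obsm"))  # "obsm"
print(normalize(None))    # None
```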
6 changes: 4 additions & 2 deletions src/napari_spatialdata/_view.py
@@ -312,7 +312,9 @@ def _channel_changed(self, event: Event) -> None:
current_point = list(event.value)
displayed = self._viewer.dims.displayed
if layer.multiscale:
for i, (lo_size, hi_size, cord) in enumerate(zip(layer.data[-1].shape, layer.data[0].shape, current_point)):
for i, (lo_size, hi_size, cord) in enumerate(
zip(layer.data[-1].shape, layer.data[0].shape, current_point, strict=False)
):
if i in displayed:
current_point[i] = slice(None)
else:
@@ -385,7 +387,7 @@ def _select_layer(self) -> None:
self.var_widget.clear()
self.obsm_widget.clear()
self.color_by.clear()
if isinstance(layer, (Points, Shapes)) and (cols_df := layer.metadata.get("_columns_df")) is not None:
if isinstance(layer, Points | Shapes) and (cols_df := layer.metadata.get("_columns_df")) is not None:
self.dataframe_columns_widget.addItems(map(str, cols_df.columns))
self.model.system_name = layer.metadata.get("name", None)
self.model.adata = None
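Background for the `zip(..., strict=False)` changes in this file and below: Python 3.10 added the `strict` keyword to `zip()` (PEP 618). `strict=False` keeps the historical truncate-to-shortest behaviour while making the intent explicit; `strict=True` (used for the palette helpers in `utils/_utils.py`) raises when the iterables differ in length. A small sketch with made-up shapes:

```python
# PEP 618 (Python 3.10+): zip() takes a `strict` keyword. Toy values below.
low_res = (3, 32, 32)     # e.g. lowest multiscale level
high_res = (3, 256, 256)  # full-resolution level
point = [0, 128, 128]

# strict=False: stop at the shortest iterable (same as plain zip, but explicit).
for i, (lo, hi, coord) in enumerate(zip(low_res, high_res, point, strict=False)):
    print(i, lo, hi, coord)

# strict=True: mismatched lengths raise instead of being silently truncated.
try:
    list(zip(low_res, point[:2], strict=True))
except ValueError as err:
    print(err)  # e.g. "zip() argument 2 is shorter than argument 1"
```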
8 changes: 4 additions & 4 deletions src/napari_spatialdata/_viewer.py
@@ -322,7 +322,7 @@ def save_to_sdata(
parsed, cs = self._save_shapes_to_sdata(selected, spatial_element_name, overwrite)
if table_name:
self._save_table_to_sdata(selected, table_name, spatial_element_name, table_columns, overwrite)
elif isinstance(selected, (Image, Labels)):
elif isinstance(selected, Image | Labels):
raise NotImplementedError
else:
raise ValueError(f"Layer of type {type(selected)} cannot be saved.")
@@ -400,17 +400,17 @@ def inherit_metadata(self, layers: list[Layer], ref_layer: Layer) -> None:
for layer in (
layer
for layer in layers
if layer != ref_layer and isinstance(layer, (Labels, Points, Shapes)) and "sdata" not in layer.metadata
if layer != ref_layer and isinstance(layer, Labels | Points | Shapes) and "sdata" not in layer.metadata
):
layer.metadata["sdata"] = ref_layer.metadata["sdata"]
layer.metadata["_current_cs"] = ref_layer.metadata["_current_cs"]
layer.metadata["_active_in_cs"] = {ref_layer.metadata["_current_cs"]}
layer.metadata["name"] = None
layer.metadata["adata"] = None
if isinstance(layer, (Shapes, Labels)):
if isinstance(layer, Shapes | Labels):
layer.metadata["region_key"] = None
layer.metadata["instance_key"] = None
if isinstance(layer, (Shapes, Points)):
if isinstance(layer, Shapes | Points):
layer.metadata["_n_indices"] = None
layer.metadata["indices"] = None
self.layer_linked.emit(layer)
12 changes: 6 additions & 6 deletions src/napari_spatialdata/_widgets.py
@@ -141,7 +141,7 @@ def _onAction(self, items: Iterable[str]) -> None:

properties = self._get_points_properties(vec, key=item, layer=self.model.layer)
self.model.color_by = "" if self.model.system_name is None else item
if isinstance(self.model.layer, (Points, Shapes)):
if isinstance(self.model.layer, Points | Shapes):
self.model.layer.text = None # needed because of the text-feature order of updates
self.model.layer.face_color = properties["face_color"]
# self.model.layer.edge_color = properties["face_color"]
@@ -214,7 +214,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]:
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
colors = colorer.uns["vec_colors"]
color_dict = dict(zip(vec.cat.categories, colors))
color_dict = dict(zip(vec.cat.categories, colors, strict=False))
color_dict.update({np.nan: "#808080ff"})
else:
color_dict = self.model.adata.uns[vec_color_name]
@@ -224,7 +224,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]:
colorer = AnnData(shape=(len(vec), 0), obs=pd.DataFrame(index=vec.index, data={"vec": vec}))
_set_colors_for_categorical_obs(colorer, "vec", palette="tab20")
colors = colorer.uns["vec_colors"]
color_dict = dict(zip(vec.cat.categories, colors))
color_dict = dict(zip(vec.cat.categories, colors, strict=False))
color_dict.update({np.nan: "#808080ff"})
color_column = vec.apply(lambda x: color_dict[x])
df[vec_color_name] = color_column
@@ -247,7 +247,7 @@ def _(self, vec: pd.Series, **kwargs: Any) -> dict[str, Any]:

merge_df["color"] = merge_df[vec.name].map(color_dict)
if layer is not None and isinstance(layer, Labels):
index_color_mapping = dict(zip(merge_df["element_indices"], merge_df["color"]))
index_color_mapping = dict(zip(merge_df["element_indices"], merge_df["color"], strict=False))
return {
"color": index_color_mapping,
"properties": {"value": vec},
@@ -301,7 +301,7 @@ def _(self, vec: ArrayLike, **kwargs: Any) -> dict[str, Any]:

if layer is not None and isinstance(layer, Labels):
return {
"color": dict(zip(element_indices, color_vec)),
"color": dict(zip(element_indices, color_vec, strict=False)),
"properties": {"value": vec},
"text": None,
}
@@ -561,7 +561,7 @@ def _onValueChange(self, percentile: tuple[float, float]) -> None:
elif isinstance(layer, Labels):
norm_vec = self._scale_vec(clipped)
color_vec = self._cmap(norm_vec)
layer.color = dict(zip(layer.color.keys(), color_vec))
layer.color = dict(zip(layer.color.keys(), color_vec, strict=False))
layer.properties = {"value": clipped}
layer.refresh()

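The recurring `dict(zip(...))` pattern above builds a category-to-color lookup from the palette stored in AnnData, with an explicit grey entry for missing values. A standalone sketch with made-up categories and colors (not the widget's real data):

```python
import numpy as np
import pandas as pd

vec = pd.Series(["tumor", "stroma", "immune", None], dtype="category")
colors = ["#1f77b4ff", "#ff7f0eff", "#2ca02cff"]  # hypothetical per-category palette

# strict=False tolerates palettes that do not exactly match the category count;
# a grey fallback is then added for NaN, as in the widget code above.
color_dict = dict(zip(vec.cat.categories, colors, strict=False))
color_dict.update({np.nan: "#808080ff"})

face_colors = vec.apply(lambda x: color_dict.get(x, color_dict[np.nan]))
print(face_colors.tolist())
```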
11 changes: 6 additions & 5 deletions src/napari_spatialdata/constants/_pkg_constants.py
@@ -1,6 +1,7 @@
"""Internal constants not exposed to the user."""

from typing import Any, Callable, Optional
from collections.abc import Callable
from typing import Any

from anndata import AnnData

@@ -71,15 +72,15 @@ def size_key(cls) -> str:
return "spot_diameter_fullres"

@classmethod
def spatial_neighs(cls, value: Optional[str] = None) -> str:
def spatial_neighs(cls, value: str | None = None) -> str:
return f"{Key.obsm.spatial}_neighbors" if value is None else f"{value}_neighbors"

@classmethod
def colors(cls, cluster: str) -> str:
return f"{cluster}_colors"

@classmethod
def spot_diameter(cls, adata: AnnData, spatial_key: str, library_id: Optional[str] = None) -> float:
def spot_diameter(cls, adata: AnnData, spatial_key: str, library_id: str | None = None) -> float:
try:
return float(adata.uns[spatial_key][library_id]["scalefactors"]["spot_diameter_fullres"])
except KeyError:
@@ -90,9 +91,9 @@ def spot_diameter(cls, adata: AnnData, spatial_key: str, library_id: Optional[st

class obsp:
@classmethod
def spatial_dist(cls, value: Optional[str] = None) -> str:
def spatial_dist(cls, value: str | None = None) -> str:
return f"{Key.obsm.spatial}_distances" if value is None else f"{value}_distances"

@classmethod
def spatial_conn(cls, value: Optional[str] = None) -> str:
def spatial_conn(cls, value: str | None = None) -> str:
return f"{Key.obsm.spatial}_connectivities" if value is None else f"{value}_connectivities"
3 changes: 2 additions & 1 deletion src/napari_spatialdata/constants/_utils.py
@@ -1,7 +1,8 @@
from abc import ABC, ABCMeta
from collections.abc import Callable
from enum import Enum, EnumMeta
from functools import wraps
from typing import Any, Callable
from typing import Any


def _pretty_raise_enum(cls: type["ModeEnum"], fun: Callable[..., Any]) -> Callable[..., Any]:
23 changes: 13 additions & 10 deletions src/napari_spatialdata/utils/_utils.py
@@ -1,18 +1,17 @@
from __future__ import annotations

from collections import Counter
from collections.abc import Generator, Iterable, Sequence
from collections.abc import Callable, Generator, Iterable, Sequence
from contextlib import contextmanager
from functools import wraps
from random import randint
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any

import numpy as np
import packaging.version
import pandas as pd
from anndata import AnnData
from dask.dataframe import DataFrame as DaskDataFrame
from datatree import DataTree
from geopandas import GeoDataFrame
from loguru import logger
from matplotlib.colors import is_color_like, to_rgb
@@ -33,7 +32,7 @@
from spatialdata import SpatialData, get_extent, join_spatialelement_table
from spatialdata.models import SpatialElement, get_axes_names
from spatialdata.transformations import get_transformation
from xarray import DataArray
from xarray import DataArray, DataTree

from napari_spatialdata.constants._pkg_constants import Key
from napari_spatialdata.utils._categoricals_utils import (
@@ -49,7 +48,7 @@

from spatialdata._types import ArrayLike

Vector_name_index_t = tuple[Optional[Union[pd.Series, ArrayLike]], Optional[str], Optional[pd.Index]]
Vector_name_index_t = tuple[pd.Series | ArrayLike | None, str | None, pd.Index | None]


def _ensure_dense_vector(fn: Callable[..., Vector_name_index_t]) -> Callable[..., Vector_name_index_t]:
@@ -79,7 +78,7 @@ def decorator(self: Any, *args: Any, **kwargs: Any) -> Vector_name_index_t:
if TYPE_CHECKING:
assert isinstance(res, spmatrix)
res = res.toarray()
elif not isinstance(res, (np.ndarray, Sequence)):
elif not isinstance(res, np.ndarray | Sequence):
raise TypeError(f"Unable to process result of type `{type(res).__name__}`.")

res = np.atleast_1d(np.squeeze(res))
@@ -100,7 +99,7 @@ def _get_palette(
if key not in adata.obs:
raise KeyError("Missing key!") # TODO: Improve error message

return dict(zip(adata.obs[key].cat.categories, [to_rgb(i) for i in adata.uns[Key.uns.colors(key)]]))
return dict(zip(adata.obs[key].cat.categories, [to_rgb(i) for i in adata.uns[Key.uns.colors(key)]], strict=True))


def _set_palette(
@@ -121,7 +120,7 @@ def _set_palette(
)
vec = vec if vec is not None else adata.obs[key]
#
return dict(zip(vec.cat.categories, [to_rgb(i) for i in adata.uns[Key.uns.colors(key)]]))
return dict(zip(vec.cat.categories, [to_rgb(i) for i in adata.uns[Key.uns.colors(key)]], strict=True))


def _get_categorical(
@@ -183,7 +182,7 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]:


def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike:
if not isinstance(element, (DataArray, DataTree, DaskDataFrame, GeoDataFrame)):
if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
raise RuntimeError("Cannot get transform for {type(element)}")

transformations = get_transformation(element, get_all=True)
@@ -256,7 +255,11 @@ def _adjust_channels_order(element: DataArray | DataTree) -> tuple[DataArray | l

if len(c_coords) != 0 and set(c_coords) - {"r", "g", "b"} <= {"a"}:
rgb = True
new_raster = element.transpose("y", "x", "c").reindex(c=["r", "g", "b", "a"][: len(c_coords)])
if isinstance(element, DataArray):
new_raster = element.transpose("y", "x", "c").reindex(c=["r", "g", "b", "a"][: len(c_coords)])
else:
new_raster = element.msi.transpose("y", "x", "c")
new_raster = new_raster.msi.reindex_data_arrays({"c": ["r", "g", "b", "a"][: len(c_coords)]})
else:
rgb = False
new_raster = element
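This file carries the core of the PR: `DataTree` is now imported from `xarray` itself (available in recent xarray releases) instead of the standalone `datatree` package, and multiscale rasters are transposed and reindexed through the `multiscale-spatial-image` `.msi` accessor, as shown in `_adjust_channels_order`. Below is a minimal sketch of handling single-scale and multiscale elements with one union check; the data and helper are toy examples, not the plugin's real elements:

```python
import numpy as np
from xarray import DataArray, DataTree  # previously: from datatree import DataTree

full = DataArray(np.zeros((3, 64, 64), dtype=np.uint8), dims=("c", "y", "x"))
low = DataArray(np.zeros((3, 32, 32), dtype=np.uint8), dims=("c", "y", "x"))

# A toy multiscale image: one node per resolution level, highest resolution first.
multiscale = DataTree.from_dict(
    {"scale0": full.to_dataset(name="image"), "scale1": low.to_dataset(name="image")}
)

def lowest_resolution_shape(element: DataArray | DataTree) -> tuple[int, ...]:
    """A single union check covers both raster kinds, as in the updated utils."""
    if isinstance(element, DataTree):
        last_level = list(element.children.values())[-1]  # insertion order: last = lowest res
        return tuple(last_level.ds["image"].shape)
    return tuple(element.shape)

print(lowest_resolution_shape(full))        # (3, 64, 64)
print(lowest_resolution_shape(multiscale))  # (3, 32, 32)
```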
3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -1,9 +1,10 @@
from __future__ import annotations

from abc import ABC, ABCMeta
from collections.abc import Callable
from functools import wraps
from pathlib import Path
from typing import Any, Callable
from typing import Any

import napari
import numpy as np
9 changes: 4 additions & 5 deletions tests/test_spatialdata.py
@@ -8,7 +8,6 @@
from dask.array.random import randint
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe import from_dask_array
from datatree import DataTree
from multiscale_spatial_image import to_multiscale
from napari.layers import Image, Labels, Points
from napari.utils.events import EventedList
@@ -19,7 +18,7 @@
from spatialdata.models import PointsModel, TableModel
from spatialdata.transformations import Identity
from spatialdata.transformations.operations import set_transformation
from xarray import DataArray
from xarray import DataArray, DataTree

from napari_spatialdata import QtAdataViewWidget
from napari_spatialdata._sdata_widgets import CoordinateSystemWidget, ElementWidget, SdataWidget
@@ -288,7 +287,7 @@ def test_partial_table_matching_with_arbitrary_ordering(qtbot, make_napari_viewe
"blobs_polygons",
]:
element = original_sdata[region]
if isinstance(element, (DataArray, DataTree)):
if isinstance(element, DataArray | DataTree):
index = get_element_instances(element).values
elif isinstance(element, DaskDataFrame):
index = element.index.compute().values
@@ -309,15 +308,15 @@
# when instance_key_type == 'str' (and when the element is not Labels), let's change the type of instance_key
# column and of the corresponding index in the spatial element to string. Labels need to have int as they are
# tensors of non-negative integers.
if not isinstance(element, (DataArray, DataTree)) and instance_key_type == "str":
if not isinstance(element, DataArray | DataTree) and instance_key_type == "str":
element.index = element.index.astype(str)
table.obs[INSTANCE_KEY] = table.obs[INSTANCE_KEY].astype(str)

shuffled_element = deepcopy(element)
shuffled_table = deepcopy(table)

# shuffle the order of the rows of the element (when the element is not Labels)
if not isinstance(element, (DataArray, DataTree)):
if not isinstance(element, DataArray | DataTree):
shuffled_element = shuffled_element.loc[RNG.permutation(shuffled_element.index)]
# shuffle the order of the rows of the table
shuffled_table = shuffled_table[RNG.permutation(shuffled_table.obs.index), :].copy()