diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 73bdc27c..c6b6e605 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -11,10 +11,10 @@ jobs:
         runs-on: ubuntu-latest
         steps:
             - uses: actions/checkout@v2
-            - name: Set up Python 3.10
+            - name: Set up Python 3.12
              uses: actions/setup-python@v4
              with:
-                  python-version: "3.10"
+                  python-version: "3.12"
                  cache: pip
            - name: Install build dependencies
              run: python -m pip install --upgrade pip wheel twine build
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 1aa90822..ee38475e 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -14,10 +14,10 @@ jobs:
            - name: Checkout code
              uses: actions/checkout@v3

-            - name: Set up Python 3.10
+            - name: Set up Python 3.12
              uses: actions/setup-python@v4
              with:
-                  python-version: "3.10"
+                  python-version: "3.12"

            - name: Install hatch
              run: pip install hatch
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 87b728c2..165d7fd3 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -18,15 +18,15 @@ jobs:
        strategy:
            fail-fast: false
            matrix:
-                python: ["3.9", "3.10"]
+                python: ["3.9", "3.12"]
                os: [ubuntu-latest]
                include:
                    - os: macos-latest
                      python: "3.9"
                    - os: macos-latest
-                      python: "3.10"
+                      python: "3.12"
                      pip-flags: "--pre"
-                      name: "Python 3.10 (pre-release)"
+                      name: "Python 3.12 (pre-release)"

        env:
            OS: ${{ matrix.os }}
diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks
index 00a9c79a..ee68f0db 160000
--- a/docs/tutorials/notebooks
+++ b/docs/tutorials/notebooks
@@ -1 +1 @@
-Subproject commit 00a9c79a87d012b8666c762946c553c659d80ac1
+Subproject commit ee68f0db7300b23eca341d65415ed66a0dd9ee03
diff --git a/pyproject.toml b/pyproject.toml
index 094b20a3..998cebb7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
     "anndata>=0.9.1",
     "click",
     "dask-image",
-    "dask<=2024.2.1",
+    "dask>=2024.2.1",
     "fsspec<=2023.6",
     "geopandas>=0.14",
     "multiscale_spatial_image>=1.0.0",
@@ -36,6 +36,7 @@ dependencies = [
     "pooch",
     "pyarrow",
     "rich",
+    "setuptools",
     "shapely>=2.0.1",
     "spatial_image>=1.1.0",
     "scikit-image",
diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py
index e99f19cf..c9757d37 100644
--- a/src/spatialdata/__init__.py
+++ b/src/spatialdata/__init__.py
@@ -1,5 +1,20 @@
 from __future__ import annotations

+import dask
+
+dask.config.set({"dataframe.query-planning": False})
+from dask.dataframe import DASK_EXPR_ENABLED
+
+# Setting `dataframe.query-planning` to False is effective only if run before `dask.dataframe` is initialized. If
+# the user had already initialized `dask.dataframe`, `DASK_EXPR_ENABLED` would be set to `True`.
+# Here we check that this has not happened.
+if DASK_EXPR_ENABLED:
+    raise RuntimeError(
+        "Unsupported backend: dask-expr has been detected as the backend of dask.dataframe. Please "
+        "use:\nimport dask\ndask.config.set({'dataframe.query-planning': False})\nbefore importing "
+        "dask.dataframe to disable dask-expr. Support for dask-expr is being worked on; for more information please see "
+        "https://github.com/scverse/spatialdata/pull/570"
+    )
 from importlib.metadata import version

 __version__ = version("spatialdata")
diff --git a/src/spatialdata/_core/_deepcopy.py b/src/spatialdata/_core/_deepcopy.py
index 9493a10e..819d1e2a 100644
--- a/src/spatialdata/_core/_deepcopy.py
+++ b/src/spatialdata/_core/_deepcopy.py
@@ -6,7 +6,7 @@
 from anndata import AnnData
 from dask.array.core import Array as DaskArray
 from dask.array.core import from_array
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from xarray import DataArray
@@ -64,9 +64,9 @@ def _(element: DataArray) -> DataArray:

 @deepcopy.register(DataTree)
 def _(element: DataTree) -> DataTree:
-    # the complexity here is due to the fact that the parsers don't accept MultiscaleSpatialImage types and that we need
-    # to convert the DataTree to a MultiscaleSpatialImage. This will be simplified once we support
-    # multiscale_spatial_image 1.0.0
+    # TODO: now that multiscale_spatial_image 1.0.0 is supported, this code can probably be simplified. Check
+    # https://github.com/scverse/spatialdata/pull/587/files#diff-c74ebf49cb8cbddcfaec213defae041010f2043cfddbded24175025b6764ef79
+    # to understand the original motivation.
     model = get_model(element)
     for key in element:
         ds = element[key].ds
diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py
index ab0ed4a8..e27874d6 100644
--- a/src/spatialdata/_core/_elements.py
+++ b/src/spatialdata/_core/_elements.py
@@ -8,7 +8,7 @@
 from warnings import warn

 from anndata import AnnData
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from geopandas import GeoDataFrame

 from spatialdata._types import Raster_T
diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py
index 813e963c..e0277b91 100644
--- a/src/spatialdata/_core/centroids.py
+++ b/src/spatialdata/_core/centroids.py
@@ -6,7 +6,7 @@
 import dask.array as da
 import pandas as pd
 import xarray as xr
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely import MultiPolygon, Point, Polygon
diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py
index 934bd5b8..eddfd94b 100644
--- a/src/spatialdata/_core/data_extent.py
+++ b/src/spatialdata/_core/data_extent.py
@@ -7,7 +7,7 @@

 import numpy as np
 import pandas as pd
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely import MultiPolygon, Point, Polygon
diff --git a/src/spatialdata/_core/operations/_utils.py b/src/spatialdata/_core/operations/_utils.py
index 9025a5da..0cb0bf10 100644
--- a/src/spatialdata/_core/operations/_utils.py
+++ b/src/spatialdata/_core/operations/_utils.py
@@ -62,7 +62,7 @@ def transform_to_data_extent(
     Notes
     -----
     - The data extent is the smallest rectangle that contains all the images and geometries.
-    - MultiscaleSpatialImage objects will be converted to SpatialImage objects.
+    - DataTree objects (multiscale images) will be converted to DataArray (single-scale images) objects.
     - This helper function will be deprecated when https://github.com/scverse/spatialdata/issues/308 is closed, as
       this function will be easily recovered by `transform_to_coordinate_system()`
     """
diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py
index c250a231..27fad5df 100644
--- a/src/spatialdata/_core/operations/aggregate.py
+++ b/src/spatialdata/_core/operations/aggregate.py
@@ -9,7 +9,7 @@
 import geopandas as gpd
 import numpy as np
 import pandas as pd
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from scipy import sparse
diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py
index 446333af..158402be 100644
--- a/src/spatialdata/_core/operations/rasterize.py
+++ b/src/spatialdata/_core/operations/rasterize.py
@@ -3,8 +3,8 @@
 import dask_image.ndinterp
 import datashader as ds
 import numpy as np
-from dask.array.core import Array as DaskArray
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.array import Array as DaskArray
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely import Point
diff --git a/src/spatialdata/_core/operations/rasterize_bins.py b/src/spatialdata/_core/operations/rasterize_bins.py
index 64e6a98f..b577445e 100644
--- a/src/spatialdata/_core/operations/rasterize_bins.py
+++ b/src/spatialdata/_core/operations/rasterize_bins.py
@@ -10,7 +10,7 @@
 from numpy.random import default_rng
 from scipy.sparse import csc_matrix
 from skimage.transform import estimate_transform
-from spatial_image import SpatialImage
+from xarray import DataArray

 from spatialdata._core.query.relational_query import get_values
 from spatialdata.models import Image2DModel, get_table_keys
@@ -30,7 +30,7 @@ def rasterize_bins(
     col_key: str,
     row_key: str,
     value_key: str | list[str] | None = None,
-) -> SpatialImage:
+) -> DataArray:
     """
     Rasterizes grid-like binned shapes/points annotated by a table (e.g. Visium HD data).
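Note: the guard added in src/spatialdata/__init__.py above only works if it runs before anything else imports dask.dataframe, because the dataframe.query-planning flag is read once, at first import. A minimal sketch of the import order that user code must follow to get the same legacy backend (illustrative only, not part of this changeset; assumes a dask version in which query planning can still be disabled):

    import dask

    dask.config.set({"dataframe.query-planning": False})

    # this import must happen after the config call, otherwise dask-expr is enabled
    import dask.dataframe as dd

    assert not dd.DASK_EXPR_ENABLED  # sanity check, mirrors the guard in spatialdata/__init__.py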
diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py
index fbf6a8a7..7f2b56cc 100644
--- a/src/spatialdata/_core/operations/transform.py
+++ b/src/spatialdata/_core/operations/transform.py
@@ -9,11 +9,10 @@
 import dask_image.ndinterp
 import numpy as np
 from dask.array.core import Array as DaskArray
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely import Point
-from spatial_image import SpatialImage
 from xarray import DataArray

 from spatialdata._core.spatialdata import SpatialData
@@ -177,6 +176,9 @@ def _set_transformation_for_transformed_elements(

     d = get_transformation(element, get_all=True)
     assert isinstance(d, dict)
+    if DEFAULT_COORDINATE_SYSTEM not in d:
+        raise RuntimeError(f"Coordinate system {DEFAULT_COORDINATE_SYSTEM} not found in element")
+
     assert isinstance(d, dict)
     assert len(d) == 1
     assert isinstance(d[DEFAULT_COORDINATE_SYSTEM], Identity)
     remove_transformation(element, remove_all=True)
@@ -389,9 +391,7 @@ def _(
             raster_translation = raster_translation_single_scale
         # we set a dummy empty dict for the transformation that will be replaced with the correct transformation for
         # each scale later in this function, when calling set_transformation()
-        transformed_dict[k] = SpatialImage(
-            transformed_dask, dims=xdata.dims, name=xdata.name, attrs={TRANSFORM_KEY: {}}
-        )
+        transformed_dict[k] = DataArray(transformed_dask, dims=xdata.dims, name=xdata.name, attrs={TRANSFORM_KEY: {}})

     # mypy thinks that schema could be ShapesModel, PointsModel, ...
     transformed_data = DataTree.from_dict(transformed_dict)
@@ -435,6 +435,9 @@ def _(
     transformed = data.drop(columns=list(axes)).copy()
     # dummy transformation that will be replaced by _adjust_transformation()
     transformed.attrs[TRANSFORM_KEY] = {DEFAULT_COORDINATE_SYSTEM: Identity()}
+    # TODO: the following line, used in place of the line before, leads to an incorrect aggregation result. Look into
+    # this! Reported here: ...
+    # transformed.attrs = {TRANSFORM_KEY: {DEFAULT_COORDINATE_SYSTEM: Identity()}}
     assert isinstance(transformed, DaskDataFrame)
     for ax in axes:
         indices = xtransformed["dim"] == ax
diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py
index 5d19d227..afa2b6e3 100644
--- a/src/spatialdata/_core/query/relational_query.py
+++ b/src/spatialdata/_core/query/relational_query.py
@@ -12,7 +12,7 @@
 import numpy as np
 import pandas as pd
 from anndata import AnnData
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from xarray import DataArray
diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py
index 2115a8be..6e2ef277 100644
--- a/src/spatialdata/_core/query/spatial_query.py
+++ b/src/spatialdata/_core/query/spatial_query.py
@@ -9,7 +9,7 @@
 import dask.array as da
 import dask.dataframe as dd
 import numpy as np
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from shapely.geometry import MultiPolygon, Polygon
diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index d66ba725..0f7facdc 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -11,16 +11,14 @@
 import pandas as pd
 import zarr
 from anndata import AnnData
+from dask.dataframe import DataFrame as DaskDataFrame
 from dask.dataframe import read_parquet
-from dask.dataframe.core import DataFrame as DaskDataFrame
 from dask.delayed import Delayed
 from datatree import DataTree
 from geopandas import GeoDataFrame
-from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
 from ome_zarr.io import parse_url
 from ome_zarr.types import JSONDict
 from shapely import MultiPolygon, Polygon
-from spatial_image import SpatialImage
 from xarray import DataArray

 from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
@@ -97,9 +95,7 @@ class SpatialData:
     -----
     The SpatialElements are stored with standard types:

-        - images and labels are stored as :class:`spatial_image.SpatialImage` or
-          :class:`multiscale_spatial_image.MultiscaleSpatialImage` objects, which are respectively equivalent to
-          :class:`xarray.DataArray` and to a :class:`datatree.DataTree` of :class:`xarray.DataArray` objects.
+        - images and labels are stored as :class:`xarray.DataArray` or :class:`datatree.DataTree` objects.
         - points are stored as :class:`dask.dataframe.DataFrame` objects.
         - shapes are stored as :class:`geopandas.GeoDataFrame`.
         - the table are stored as :class:`anndata.AnnData` objects, with the spatial coordinates stored in the obsm
@@ -856,8 +852,8 @@ def transform_element_to_coordinate_system(
         else:
             # When maintaining positioning is true, and if the element has a transformation to target_coordinate_system
             # (this may not be the case because it could be that the element is not directly mapped to that coordinate
-            # system), then the transformation to the target coordinate system is not needed # because the data is now
-            # already transformed; here we remove such transformation.
+            # system), then the transformation to the target coordinate system is not needed
+            # because the data is now already transformed; here we remove such transformation.
             d = get_transformation(transformed, get_all=True)
             assert isinstance(d, dict)
             if target_coordinate_system in d:
@@ -1595,7 +1591,7 @@ def add_image(
     def add_labels(
         self,
         name: str,
-        labels: SpatialImage | MultiscaleSpatialImage,
+        labels: DataArray | DataTree,
         storage_options: JSONDict | list[JSONDict] | None = None,
         overwrite: bool = False,
     ) -> None:
@@ -1743,8 +1739,8 @@ def h(s: str) -> str:
                     descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"shape: {v.shape} (2D shapes)"
                 elif attr == "points":
                     length: int | None = None
-                    if len(v.dask.layers) == 1:
-                        name, layer = v.dask.layers.items().__iter__().__next__()
+                    if len(v.dask) == 1:
+                        name, layer = v.dask.items().__iter__().__next__()
                         if "read-parquet" in name:
                             t = layer.creation_info["args"]
                             assert isinstance(t, tuple)
diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py
index d57b063a..05c650fc 100644
--- a/src/spatialdata/_io/_utils.py
+++ b/src/spatialdata/_io/_utils.py
@@ -6,7 +6,7 @@
 import re
 import tempfile
 import warnings
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from contextlib import contextmanager
 from functools import singledispatch
 from pathlib import Path
@@ -17,8 +17,8 @@
 from anndata import AnnData
 from anndata import read_zarr as read_anndata_zarr
 from anndata.experimental import read_elem
-from dask.array.core import Array as DaskArray
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.array import Array as DaskArray
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from ome_zarr.format import Format
@@ -266,21 +266,57 @@ def _(element: AnnData | GeoDataFrame) -> list[str]:

 def _get_backing_files(element: DaskArray | DaskDataFrame) -> list[str]:
-    files = []
-    for k, v in element.dask.layers.items():
-        if k.startswith("original-from-zarr-"):
-            mapping = v.mapping[k]
-            path = mapping.store.path
-            files.append(os.path.realpath(path))
-        if k.startswith("read-parquet-"):
-            t = v.creation_info["args"]
-            assert isinstance(t, tuple)
-            assert len(t) == 1
-            parquet_file = t[0]
-            files.append(os.path.realpath(parquet_file))
+    files: list[str] = []
+    _search_for_backing_files_recursively(subgraph=element.dask, files=files)
     return files


+def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> None:
+    # see the types allowed for the dask graph here: https://docs.dask.org/en/stable/spec.html
+
+    # search recursively
+    if isinstance(subgraph, Mapping):
+        for k, v in subgraph.items():
+            _search_for_backing_files_recursively(subgraph=k, files=files)
+            _search_for_backing_files_recursively(subgraph=v, files=files)
+    elif isinstance(subgraph, Sequence) and not isinstance(subgraph, str):
+        for v in subgraph:
+            _search_for_backing_files_recursively(subgraph=v, files=files)
+
+    # cases where a backing file is found
+    if isinstance(subgraph, Mapping):
+        for k, v in subgraph.items():
+            name = None
+            if isinstance(k, Sequence) and not isinstance(k, str):
+                name = k[0]
+            elif isinstance(k, str):
+                name = k
+            if name is not None:
+                if name.startswith("original-from-zarr"):
+                    path = v.store.path
+                    files.append(os.path.realpath(path))
+                elif name.startswith("read-parquet") or name.startswith("read_parquet"):
+                    if hasattr(v, "creation_info"):
+                        # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L625
+                        t = v.creation_info["args"]
+                        if not isinstance(t, tuple) or len(t) != 1:
+                            raise ValueError(
+                                f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please "
+                                f"report this bug."
+                            )
+                        parquet_file = t[0]
+                        files.append(os.path.realpath(parquet_file))
+                    elif isinstance(v, tuple) and len(v) > 1 and isinstance(v[1], dict) and "piece" in v[1]:
+                        # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L870
+                        parquet_file, check0, check1 = v[1]["piece"]
+                        if not parquet_file.endswith(".parquet") or check0 is not None or check1 is not None:
+                            raise ValueError(
+                                f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please "
+                                f"report this bug."
+                            )
+                        files.append(os.path.realpath(parquet_file))
+
+
 def _backed_elements_contained_in_path(path: Path, object: SpatialData | SpatialElement | AnnData) -> list[bool]:
     """
     Return the list of boolean values indicating if backing files for an object are child directory of a path.
@@ -333,6 +369,8 @@ def _is_subfolder(parent: Path, child: Path) -> bool:
 def _is_element_self_contained(
     element: DataArray | DataTree | DaskDataFrame | GeoDataFrame | AnnData, element_path: Path
 ) -> bool:
+    if isinstance(element, DaskDataFrame):
+        pass
     return all(_backed_elements_contained_in_path(path=element_path, object=element))
diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py
index 574cf98d..eb3e4c01 100644
--- a/src/spatialdata/datasets.py
+++ b/src/spatialdata/datasets.py
@@ -9,7 +9,7 @@
 import pandas as pd
 import scipy
 from anndata import AnnData
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from numpy.random import default_rng
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py
index 26b8ddd4..943b6693 100644
--- a/src/spatialdata/models/models.py
+++ b/src/spatialdata/models/models.py
@@ -12,9 +12,9 @@
 import numpy as np
 import pandas as pd
 from anndata import AnnData
-from dask.array.core import Array as DaskArray
+from dask.array import Array as DaskArray
 from dask.array.core import from_array
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame, GeoSeries
 from multiscale_spatial_image import to_multiscale
@@ -600,23 +600,23 @@ def _(
         ndim = data.shape[1]
         axes = [X, Y, Z][:ndim]
         index = annotation.index if annotation is not None else None
-        table: DaskDataFrame = dd.from_pandas(pd.DataFrame(data, columns=axes, index=index), **kwargs)  # type: ignore[attr-defined]
+        df_dict = {ax: data[:, i] for i, ax in enumerate(axes)}
+        df_kwargs = {"data": df_dict, "index": index}
+
         if annotation is not None:
             if feature_key is not None:
-                feature_categ = dd.from_pandas(  # type: ignore[attr-defined]
-                    annotation[feature_key].astype(str).astype("category"), **kwargs
-                )
-                table[feature_key] = feature_categ
+                df_dict[feature_key] = annotation[feature_key].astype(str).astype("category")
             if instance_key is not None:
-                table[instance_key] = annotation[instance_key]
+                df_dict[instance_key] = annotation[instance_key]
             if Z not in axes and Z in annotation.columns:
                 logger.info(f"Column `{Z}` in `annotation` will be ignored since the data is 2D.")
             for c in set(annotation.columns) - {feature_key, instance_key, X, Y, Z}:
-                table[c] = dd.from_pandas(annotation[c], **kwargs)  # type: ignore[attr-defined]
-            return cls._add_metadata_and_validate(
-                table, feature_key=feature_key, instance_key=instance_key, transformations=transformations
-            )
-        return cls._add_metadata_and_validate(table, transformations=transformations)
+                df_dict[c] = annotation[c]
+
+        table: DaskDataFrame = dd.from_pandas(pd.DataFrame(**df_kwargs), **kwargs)
+        return cls._add_metadata_and_validate(
+            table, feature_key=feature_key, instance_key=instance_key, transformations=transformations
+        )

     @parse.register(pd.DataFrame)
     @parse.register(DaskDataFrame)
diff --git a/src/spatialdata/transformations/_utils.py b/src/spatialdata/transformations/_utils.py
index 40e90294..98645e96 100644
--- a/src/spatialdata/transformations/_utils.py
+++ b/src/spatialdata/transformations/_utils.py
@@ -4,7 +4,7 @@
 from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from xarray import DataArray
@@ -43,11 +43,20 @@ def _set_transformations_to_dict_container(dict_container: Any, transformations:
     if TRANSFORM_KEY not in dict_container:
         dict_container[TRANSFORM_KEY] = {}
+    # this modifies the dict in place without triggering a setter in the element class. Probably we want to stop using
+    # _set_transformations_to_dict_container and use _set_transformations_to_element instead
     dict_container[TRANSFORM_KEY] = transformations


-def _set_transformations_xarray(e: DataArray, transformations: MappingToCoordinateSystem_t) -> None:
-    _set_transformations_to_dict_container(e.attrs, transformations)
+def _set_transformations_to_element(element: Any, transformations: MappingToCoordinateSystem_t) -> None:
+    from spatialdata.models._utils import TRANSFORM_KEY
+
+    attrs = element.attrs
+    if TRANSFORM_KEY not in attrs:
+        attrs[TRANSFORM_KEY] = {}
+    attrs[TRANSFORM_KEY] = transformations
+    # this calls any setter defined in the element class; modifying the attrs directly would not trigger the setter
+    element.attrs = attrs


 @singledispatch
@@ -71,35 +80,9 @@ def _set_transformations(e: SpatialElement, transformations: MappingToCoordinate
     raise TypeError(f"Unsupported type: {type(e)}")


-@_get_transformations.register(DataArray)
-def _(e: DataArray) -> MappingToCoordinateSystem_t | None:
-    return _get_transformations_xarray(e)
-
-
-@_get_transformations.register(DataTree)
-def _(e: DataTree) -> MappingToCoordinateSystem_t | None:
-    from spatialdata.models._utils import TRANSFORM_KEY
-
-    if TRANSFORM_KEY in e.attrs:
-        raise ValueError(
-            "A multiscale image must not contain a transformation in the outer level; the transformations need to be "
-            "stored in the inner levels."
-        )
-    d = dict(e["scale0"])
-    assert len(d) == 1
-    xdata = d.values().__iter__().__next__()
-    return _get_transformations_xarray(xdata)
-
-
-@_get_transformations.register(GeoDataFrame)
-@_get_transformations.register(DaskDataFrame)
-def _(e: Union[GeoDataFrame, DaskDataFrame]) -> MappingToCoordinateSystem_t | None:
-    return _get_transformations_from_dict_container(e.attrs)
-
-
 @_set_transformations.register(DataArray)
 def _(e: DataArray, transformations: MappingToCoordinateSystem_t) -> None:
-    _set_transformations_xarray(e, transformations)
+    _set_transformations_to_dict_container(e.attrs, transformations)


 @_set_transformations.register(DataTree)
@@ -132,16 +115,43 @@ def _(e: DataTree, transformations: MappingToCoordinateSystem_t) -> None:
                 for k, v in transformations.items():
                     sequence: BaseTransformation = Sequence([scale_transformation, v])
                     new_transformations[k] = sequence
-                _set_transformations_xarray(xdata, new_transformations)
+                _set_transformations(xdata, new_transformations)
             else:
-                _set_transformations_xarray(xdata, transformations)
+                _set_transformations(xdata, transformations)
             old_shape = new_shape


 @_set_transformations.register(GeoDataFrame)
 @_set_transformations.register(DaskDataFrame)
 def _(e: Union[GeoDataFrame, GeoDataFrame], transformations: MappingToCoordinateSystem_t) -> None:
-    _set_transformations_to_dict_container(e.attrs, transformations)
+    _set_transformations_to_element(e, transformations)
+    # _set_transformations_to_dict_container(e.attrs, transformations)
+
+
+@_get_transformations.register(DataArray)
+def _(e: DataArray) -> MappingToCoordinateSystem_t | None:
+    return _get_transformations_xarray(e)
+
+
+@_get_transformations.register(DataTree)
+def _(e: DataTree) -> MappingToCoordinateSystem_t | None:
+    from spatialdata.models._utils import TRANSFORM_KEY
+
+    if TRANSFORM_KEY in e.attrs:
+        raise ValueError(
+            "A multiscale image must not contain a transformation in the outer level; the transformations need to be "
+            "stored in the inner levels."
+        )
+    d = dict(e["scale0"])
+    assert len(d) == 1
+    xdata = d.values().__iter__().__next__()
+    return _get_transformations_xarray(xdata)
+
+
+@_get_transformations.register(GeoDataFrame)
+@_get_transformations.register(DaskDataFrame)
+def _(e: Union[GeoDataFrame, DaskDataFrame]) -> MappingToCoordinateSystem_t | None:
+    return _get_transformations_from_dict_container(e.attrs)


 @singledispatch
diff --git a/tests/conftest.py b/tests/conftest.py
index 8cc14f12..d4d06cda 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,8 @@
 from __future__ import annotations

+import dask
+
+dask.config.set({"dataframe.query-planning": False})
 from collections.abc import Sequence
 from pathlib import Path
 from typing import Any
@@ -10,7 +13,7 @@
 import pandas as pd
 import pytest
 from anndata import AnnData
-from dask.dataframe.core import DataFrame as DaskDataFrame
+from dask.dataframe import DataFrame as DaskDataFrame
 from datatree import DataTree
 from geopandas import GeoDataFrame
 from numpy.random import default_rng
@@ -18,7 +21,7 @@
 from shapely import linearrings, polygons
 from shapely.geometry import MultiPolygon, Point, Polygon
 from skimage import data
-from spatialdata._core._deepcopy import deepcopy as _deepcopy
+from spatialdata._core._deepcopy import deepcopy
 from spatialdata._core.spatialdata import SpatialData
 from spatialdata._types import ArrayLike
 from spatialdata.datasets import BlobsDataset
@@ -305,14 +308,9 @@ def labels_blobs() -> ArrayLike:
 @pytest.fixture()
 def sdata_blobs() -> SpatialData:
     """Create a 2D labels."""
-    from copy import deepcopy
-
     from spatialdata.datasets import blobs

-    sdata = deepcopy(blobs(256, 300, 3))
-    for k, v in sdata.shapes.items():
-        sdata.shapes[k] = _deepcopy(v)
-    return sdata
+    return deepcopy(blobs(256, 300, 3))


 def _make_points(coordinates: np.ndarray) -> DaskDataFrame:
@@ -399,10 +397,12 @@ def _make_sdata_for_testing_querying_and_aggretation() -> SpatialData:
         by_squares.loc[len(by_squares)] = [polygon]
     ShapesModel.validate(by_squares)

-    s = pd.Series(pd.Categorical(["a"] * 9 + ["b"] * 9 + ["c"] * 2))
-    values_points["categorical_in_ddf"] = dd.from_pandas(s, npartitions=1)
-    s = pd.Series(RNG.random(20))
-    values_points["numerical_in_ddf"] = dd.from_pandas(s, npartitions=1)
+    s_cat = pd.Series(pd.Categorical(["a"] * 9 + ["b"] * 9 + ["c"] * 2))
+    s_num = pd.Series(RNG.random(20))
+    # workaround for https://github.com/dask/dask/issues/11147, let's recompute the dataframe (it's a small one)
+    values_points = PointsModel.parse(
+        dd.from_pandas(values_points.compute().assign(categorical_in_ddf=s_cat, numerical_in_ddf=s_num), npartitions=1)
+    )

     sdata = SpatialData(
         points={"points": values_points},
diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py
index 290839c5..9736ce63 100644
--- a/tests/core/operations/test_aggregations.py
+++ b/tests/core/operations/test_aggregations.py
@@ -74,6 +74,7 @@ def test_aggregate_points_by_shapes(sdata_query_aggregation, by_shapes: str, val
     assert np.all(np.isclose(result_adata.X.todense().A, np.array([[s0], [0], [0], [0], [s4]])))

     # id_key can be implicit for points
+    points.attrs[PointsModel.ATTRS_KEY] = {}
     points.attrs[PointsModel.ATTRS_KEY][PointsModel.FEATURE_KEY] = value_key
     result_adata_implicit = aggregate(values=points, by=shapes, agg_func="sum").tables["table"]
     assert_equal(result_adata, result_adata_implicit)
@@ -406,7 +407,7 @@ def test_aggregate_requiring_alignment(sdata_blobs: SpatialData, values, by) ->
     sdata2 = SpatialData.init_from_elements({"values": sdata["values"], "by": transformed_sdata["by"]})
sdata["values"], "by": transformed_sdata["by"]}) # let's take values from the original sdata (non-transformed but aligned to 'other'); let's take by from the # transformed sdata - out3 = aggregate(values=sdata["values"], by=sdata2["by"], target_coordinate_system="other", agg_func="sum").tables[ + out3 = aggregate(values=sdata2["values"], by=sdata2["by"], target_coordinate_system="other", agg_func="sum").tables[ "table" ] assert np.allclose(out0.X.todense().A, out3.X.todense().A) diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 93138633..eb688b27 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -6,7 +6,7 @@ import pytest import xarray from anndata import AnnData -from dask.dataframe.core import DataFrame as DaskDataFrame +from dask.dataframe import DataFrame as DaskDataFrame from datatree import DataTree from geopandas import GeoDataFrame from shapely import Polygon @@ -586,4 +586,3 @@ def test_attributes_are_copied(full_sdata, with_polygon_query: bool, name: str): # check that the attributes of the queried element are not the same as the old ones assert sdata[name].attrs is not queried[name].attrs assert sdata[name].attrs["transform"] is not queried[name].attrs["transform"] - pass diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index e9fe6480..162f81b1 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -259,17 +259,12 @@ def test_incremental_io_table_legacy(self, table_single_annotation: SpatialData) assert len(s3["table"]) == len(s2["table"]) def test_io_and_lazy_loading_points(self, points): - elem_name = list(points.points.keys())[0] with tempfile.TemporaryDirectory() as td: f = os.path.join(td, "data.zarr") - dask0 = points.points[elem_name] points.write(f) - assert all("read-parquet" not in key for key in dask0.dask.layers) assert len(get_dask_backing_files(points)) == 0 sdata2 = SpatialData.read(f) - dask1 = sdata2[elem_name] - assert any("read-parquet" in key for key in dask1.dask.layers) assert len(get_dask_backing_files(sdata2)) > 0 def test_io_and_lazy_loading_raster(self, images, labels): diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 6a913f9a..ad695cfb 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -19,10 +19,13 @@ def test_backing_files_points(points): p1 = points1.points["points_0"] p2 = dd.concat([p0, p1], axis=0) files = get_dask_backing_files(p2) - expected_zarr_locations = [ + expected_zarr_locations_legacy = [ os.path.realpath(os.path.join(f, "points/points_0/points.parquet")) for f in [f0, f1] ] - assert set(files) == set(expected_zarr_locations) + expected_zarr_locations_new = [ + os.path.realpath(os.path.join(f, "points/points_0/points.parquet/part.0.parquet")) for f in [f0, f1] + ] + assert set(files) == set(expected_zarr_locations_legacy) or set(files) == set(expected_zarr_locations_new) def test_backing_files_images(images): @@ -105,8 +108,12 @@ def test_backing_files_combining_points_and_images(points, images): v.compute_chunk_sizes() im2 = v + im1 files = get_dask_backing_files(im2) - expected_zarr_locations = [ + expected_zarr_locations_old = [ os.path.realpath(os.path.join(f0, "points/points_0/points.parquet")), os.path.realpath(os.path.join(f1, "images/image2d")), ] - assert set(files) == set(expected_zarr_locations) + expected_zarr_locations_new = [ + os.path.realpath(os.path.join(f0, "points/points_0/points.parquet/part.0.parquet")), + os.path.realpath(os.path.join(f1, 
"images/image2d")), + ] + assert set(files) == set(expected_zarr_locations_old) or set(files) == set(expected_zarr_locations_new) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 24d5dc80..c3d992ce 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -3,7 +3,6 @@ import os import re import tempfile -from copy import deepcopy from functools import partial from pathlib import Path from typing import Any, Callable @@ -15,13 +14,13 @@ import pytest from anndata import AnnData from dask.array.core import from_array -from dask.dataframe.core import DataFrame as DaskDataFrame +from dask.dataframe import DataFrame as DaskDataFrame +from datatree import DataTree from geopandas import GeoDataFrame -from multiscale_spatial_image import MultiscaleSpatialImage from numpy.random import default_rng from shapely.geometry import MultiPolygon, Point, Polygon from shapely.io import to_ragged_array -from spatial_image import SpatialImage, to_spatial_image +from spatial_image import to_spatial_image from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models._utils import ( @@ -45,13 +44,12 @@ from spatialdata.testing import assert_elements_are_identical from spatialdata.transformations._utils import ( _set_transformations, - _set_transformations_xarray, ) from spatialdata.transformations.operations import ( get_transformation, set_transformation, ) -from spatialdata.transformations.transformations import Scale +from spatialdata.transformations.transformations import Identity, Scale from xarray import DataArray from tests.conftest import ( @@ -79,47 +77,39 @@ def test_validate_axis_name(): class TestModels: def _parse_transformation_from_multiple_places(self, model: Any, element: Any, **kwargs) -> None: # This function seems convoluted but the idea is simple: sometimes the parser creates a whole new object, - # other times (SpatialImage, DataArray, AnnData, GeoDataFrame) the object is enriched in-place. In such - # cases we check that if there was already a transformation in the object we consider it then we are not + # other times (DataArray, AnnData, GeoDataFrame) the object is enriched in-place. In such + # cases we check that if there was already a transformation in the object then we are not # passing it also explicitly in the parser. 
         # This function does that for all the models (it's called by the various tests of the models) and it first
-        # creates clean copies of the element, and then puts the transformation inside it with various methods
-        if any(isinstance(element, t) for t in (SpatialImage, DataArray, AnnData, GeoDataFrame, DaskDataFrame)):
-            element_erased = deepcopy(element)
-            # we are not respecting the function signature (the transform should be not None); it's fine for testing
-            if isinstance(element_erased, DataArray) and not isinstance(element_erased, SpatialImage):
-                # this case is for xarray.DataArray where the user manually updates the transform in attrs,
-                # or when a user takes an image from a MultiscaleSpatialImage
-                _set_transformations_xarray(element_erased, {})
-            else:
-                _set_transformations(element_erased, {})
-            element_copy0 = deepcopy(element_erased)
-            parsed0 = model.parse(element_copy0, **kwargs)
+        if any(isinstance(element, t) for t in (DataArray, GeoDataFrame, DaskDataFrame)):
+            # no transformation in the element, nor passed to the parser (default transformation is added)
+
+            _set_transformations(element, {})
+            parsed0 = model.parse(element, **kwargs)
+            assert get_transformation(parsed0, "global") == Identity()

-            element_copy1 = deepcopy(element_erased)
+            # no transformation in the element, but passed to the parser
+            _set_transformations(element, {})
             t = Scale([1.0, 1.0], axes=("x", "y"))
-            parsed1 = model.parse(element_copy1, transformations={"global": t}, **kwargs)
-            assert get_transformation(parsed0, "global") != get_transformation(parsed1, "global")
+            parsed1 = model.parse(element, transformations={"global": t}, **kwargs)
+            assert get_transformation(parsed1, "global") == t

-            element_copy2 = deepcopy(element_erased)
-            if isinstance(element_copy2, DataArray) and not isinstance(element_copy2, SpatialImage):
-                _set_transformations_xarray(element_copy2, {"global": t})
-            else:
-                set_transformation(element_copy2, t, "global")
-            parsed2 = model.parse(element_copy2, **kwargs)
-            assert get_transformation(parsed1, "global") == get_transformation(parsed2, "global")
+            # transformation in the element, but not passed to the parser
+            _set_transformations(element, {})
+            set_transformation(element, t, "global")
+            parsed2 = model.parse(element, **kwargs)
+            assert get_transformation(parsed2, "global") == t

+            # transformation in the element, and passed to the parser
             with pytest.raises(ValueError):
-                element_copy3 = deepcopy(element_erased)
-                if isinstance(element_copy3, DataArray) and not isinstance(element_copy3, SpatialImage):
-                    _set_transformations_xarray(element_copy3, {"global": t})
-                else:
-                    set_transformation(element_copy3, t, "global")
-                model.parse(element_copy3, transformations={"global": t}, **kwargs)
+                _set_transformations(element, {})
+                set_transformation(element, t, "global")
+                model.parse(element, transformations={"global": t}, **kwargs)
         elif any(
             isinstance(element, t)
             for t in (
-                MultiscaleSpatialImage,
+                DataTree,
                 str,
                 np.ndarray,
                 dask.array.core.Array,
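Note: the _set_transformations_to_element helper introduced in src/spatialdata/transformations/_utils.py reassigns element.attrs instead of mutating the dict in place, so that a setter defined on the element class has a chance to run. A minimal sketch of the difference, using a hypothetical Element class (illustrative only, not part of this changeset):

    class Element:
        # hypothetical element with a validating `attrs` property
        def __init__(self) -> None:
            self._attrs: dict = {}

        @property
        def attrs(self) -> dict:
            return self._attrs

        @attrs.setter
        def attrs(self, value: dict) -> None:
            # validation/bookkeeping runs only on reassignment
            self._attrs = dict(value)

    e = Element()
    e.attrs["transform"] = {}  # in-place mutation: the setter never runs
    attrs = e.attrs
    attrs["transform"] = {"global": "Identity"}
    e.attrs = attrs  # reassignment: the setter (and its bookkeeping) runs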