Skip to content

Commit

Permalink
Merge pull request #232 from melonora/xarray_datatree
Browse files (browse the repository at this point in the history)
import datatree from xarray
  • Loading branch information
LucaMarconato authored Nov 26, 2024
2 parents 3c6be23 + ed60d9f commit 8a41c13
Show file tree
Hide file tree
Showing 20 changed files with 46 additions and 51 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ ignore =
D400
# First line should be in imperative mood; try rephrasing
D401
# Abstract base class without abstractmethod.
B024
exclude = .git,__pycache__,build,docs/_build,dist
per-file-ignores =
tests/*: D
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python: ["3.9", "3.10"]
python: ["3.10", "3.12"]
os: [ubuntu-latest]

env:
Expand Down Expand Up @@ -52,15 +52,15 @@ jobs:
pip install --pre -e ".[dev,test]"
- name: Download artifact of test data
if: matrix.python == '3.10'
if: matrix.python == '3.12'
uses: dawidd6/action-download-artifact@v2
with:
workflow: prepare_test_data.yaml
name: data
path: ./data

- name: List the data directory
if: matrix.python == '3.10'
if: matrix.python == '3.12'
run: |
ls -l ./data
pwd
Expand Down
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
python_version = 3.9
python_version = 3.10
plugins = numpy.typing.mypy_plugin

ignore_errors = False
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ fail_fast: false
default_language_version:
python: python3
default_stages:
- commit
- push
- pre-commit
- pre-push
minimum_pre_commit_version: 2.16.0
repos:
- repo: https://github.com/psf/black
Expand Down Expand Up @@ -73,7 +73,7 @@ repos:
rev: v3.17.0
hooks:
- id: pyupgrade
args: [--py3-plus, --py39-plus, --keep-runtime-typing]
args: [--py3-plus, --py310-plus, --keep-runtime-typing]
- repo: local
hooks:
- id: forbid-to-commit
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ version: 2
build:
os: ubuntu-20.04
tools:
python: "3.9"
python: "3.10"
sphinx:
configuration: docs/conf.py
fail_on_warning: true
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
}

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"anndata": ("https://anndata.readthedocs.io/en/stable/", None),
"spatialdata": ("https://scverse-spatialdata.readthedocs.io/en/latest/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dynamic= [
]
description = "SpatialData IO for common techs"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "scverse"},
Expand Down Expand Up @@ -83,7 +83,7 @@ skip_glob = ["docs/*"]

[tool.black]
line-length = 120
target-version = ['py39']
target-version = ['py310']
include = '\.pyi?$'
exclude = '''
(
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_io/_constants/_enum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, ABCMeta
from collections.abc import Callable
from enum import Enum, EnumMeta
from functools import wraps
from typing import Any, Callable
from typing import Any


class PrettyEnum(Enum):
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_io/_docs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from textwrap import dedent
from typing import Any, Callable
from typing import Any


def inject_docs(**kwargs: Any) -> Callable[..., Any]: # noqa: D103
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import functools
import warnings
from typing import Any, Callable, TypeVar
from collections.abc import Callable
from typing import Any, TypeVar

RT = TypeVar("RT")

Expand Down
8 changes: 4 additions & 4 deletions src/spatialdata_io/readers/_utils/_read_10x_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
# code below taken from https://github.com/scverse/scanpy/blob/master/scanpy/readwrite.py

from pathlib import Path
from typing import Any, Optional, Union
from typing import Any

import h5py
import numpy as np
Expand All @@ -40,8 +40,8 @@


def _read_10x_h5(
filename: Union[str, Path],
genome: Optional[str] = None,
filename: str | Path,
genome: str | None = None,
gex_only: bool = True,
) -> AnnData:
"""
Expand Down Expand Up @@ -96,7 +96,7 @@ def _read_10x_h5(
return adata


def _read_v3_10x_h5(filename: Union[str, Path], *, start: Optional[Any] = None) -> AnnData:
def _read_v3_10x_h5(filename: str | Path, *, start: Any | None = None) -> AnnData:
"""Read hdf5 file from Cell Ranger v3 or later versions."""
with h5py.File(str(filename), "r") as f:
try:
Expand Down
13 changes: 2 additions & 11 deletions src/spatialdata_io/readers/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,20 @@
import os
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Union

import numpy as np
from anndata import AnnData, read_text
from h5py import File

from spatialdata_io.readers._utils._read_10x_h5 import _read_10x_h5

PathLike = Union[os.PathLike, str] # type:ignore[type-arg]

try:
from numpy.typing import NDArray

NDArrayA = NDArray[Any]
except (ImportError, TypeError):
NDArray = np.ndarray
NDArrayA = np.ndarray


def _read_counts(
path: str | Path,
counts_file: str,
library_id: Optional[str] = None,
library_id: str | None = None,
**kwargs: Any,
) -> tuple[AnnData, str]:
path = Path(path)
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata_io/readers/cosmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Any, Optional
from typing import Any

import dask.array as da
import numpy as np
Expand Down Expand Up @@ -38,7 +38,7 @@
@inject_docs(cx=CosmxKeys)
def cosmx(
path: str | Path,
dataset_id: Optional[str] = None,
dataset_id: str | None = None,
transcripts: bool = True,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
Expand Down
15 changes: 7 additions & 8 deletions src/spatialdata_io/readers/dbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import re
from pathlib import Path
from re import Pattern
from typing import Optional, Union

import anndata as ad
import numpy as np
Expand All @@ -27,9 +26,9 @@ def _check_path(
path: Path,
pattern: Pattern[str],
key: DbitKeys,
path_specific: Optional[str | Path] = None,
path_specific: str | Path | None = None,
optional_arg: bool = False,
) -> tuple[Union[Path, None], bool]:
) -> tuple[Path | None, bool]:
"""
Check that the path is valid and match a regex pattern.
Expand Down Expand Up @@ -218,11 +217,11 @@ def _xy2edges(xy: list[int], scale: float = 1.0, border: bool = True, border_sca

@inject_docs(vx=DbitKeys)
def dbit(
path: Optional[str | Path] = None,
anndata_path: Optional[str] = None,
barcode_position: Optional[str] = None,
image_path: Optional[str] = None,
dataset_id: Optional[str] = None,
path: str | Path | None = None,
anndata_path: str | None = None,
barcode_position: str | None = None,
image_path: str | None = None,
dataset_id: str | None = None,
border: bool = True,
border_scale: float = 1,
) -> SpatialData:
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata_io/readers/merscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import re
import warnings
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Any, Callable, Literal
from typing import Any, Literal

import anndata
import dask.dataframe as dd
Expand Down
6 changes: 3 additions & 3 deletions src/spatialdata_io/readers/steinbock.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Any, Literal, Union
from typing import Any, Literal

import anndata as ad
from dask_image.imread import imread
Expand Down Expand Up @@ -95,7 +95,7 @@ def _get_images(
sample: str,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> Union[SpatialImage, MultiscaleSpatialImage]:
) -> SpatialImage | MultiscaleSpatialImage:
image = imread(path / SteinbockKeys.IMAGES_DIR / f"{sample}{SteinbockKeys.IMAGE_SUFFIX}", **imread_kwargs)
return Image2DModel.parse(data=image, transformations={sample: Identity()}, rgb=None, **image_models_kwargs)

Expand All @@ -106,6 +106,6 @@ def _get_labels(
labels_kind: str,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> Union[SpatialImage, MultiscaleSpatialImage]:
) -> SpatialImage | MultiscaleSpatialImage:
image = imread(path / labels_kind / f"{sample}{SteinbockKeys.LABEL_SUFFIX}", **imread_kwargs).squeeze()
return Labels2DModel.parse(data=image, transformations={sample: Identity()}, **image_models_kwargs)
4 changes: 2 additions & 2 deletions src/spatialdata_io/readers/stereoseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Any, Union
from typing import Any

import anndata as ad
import h5py
Expand All @@ -29,7 +29,7 @@
@inject_docs(xx=SK)
def stereoseq(
path: str | Path,
dataset_id: Union[str, None] = None,
dataset_id: str | None = None,
read_square_bin: bool = True,
optional_tif: bool = False,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata_io/readers/visium_hd.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def visium_hd(
image_models_kwargs
Keyword arguments for :class:`spatialdata.models.Image2DModel`.
anndata_kwargs
Keyword arguments for :func:`anndata.read_h5ad`.
Keyword arguments for :func:`anndata.io.read_h5ad`.
Returns
-------
Expand Down
7 changes: 3 additions & 4 deletions src/spatialdata_io/readers/xenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
from typing import Any, Optional
from typing import Any

import dask.array as da
import numpy as np
Expand All @@ -22,7 +22,6 @@
from anndata import AnnData
from dask.dataframe import read_parquet
from dask_image.imread import imread
from datatree.datatree import DataTree
from geopandas import GeoDataFrame
from joblib import Parallel, delayed
from pyarrow import Table
Expand All @@ -38,7 +37,7 @@
TableModel,
)
from spatialdata.transformations.transformations import Affine, Identity, Scale
from xarray import DataArray
from xarray import DataArray, DataTree

from spatialdata_io._constants._constants import XeniumKeys
from spatialdata_io._docs import inject_docs
Expand Down Expand Up @@ -364,7 +363,7 @@ def _decode_cell_id_column(cell_id_column: pd.Series) -> pd.Series:


def _get_polygons(
path: Path, file: str, specs: dict[str, Any], n_jobs: int, idx: Optional[ArrayLike] = None
path: Path, file: str, specs: dict[str, Any], n_jobs: int, idx: ArrayLike | None = None
) -> GeoDataFrame:
def _poly(arr: ArrayLike) -> Polygon:
return Polygon(arr[:-1])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_xenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_roundtrip_with_data_limits() -> None:
# pointing to "data".
# The GitHub workflow "prepare_test_data.yaml" takes care of downloading the datasets and uploading an artifact for the
# tests to use
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Test requires Python 3.10 or higher")
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Test requires Python 3.12 or higher")
@pytest.mark.parametrize(
"dataset,expected",
[
Expand Down

0 comments on commit 8a41c13

Please sign in to comment.