From ea57084a4052a156120c21a90f9ceda3714de162 Mon Sep 17 00:00:00 2001 From: Intron7 Date: Mon, 14 Oct 2024 15:09:00 +0200 Subject: [PATCH] update typing --- src/rapids_singlecell/_compat.py | 24 +--------------------- src/rapids_singlecell/_utils/__init__.py | 7 +++++++ src/rapids_singlecell/preprocessing/_qc.py | 14 +++++++++---- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/src/rapids_singlecell/_compat.py b/src/rapids_singlecell/_compat.py index ecb20551..f569504b 100644 --- a/src/rapids_singlecell/_compat.py +++ b/src/rapids_singlecell/_compat.py @@ -2,29 +2,7 @@ import cupy as cp from cupyx.scipy.sparse import csr_matrix - -try: - from dask.array import Array as DaskArray -except ImportError: - - class DaskArray: - pass - - -try: - from dask.distributed import Client as DaskClient -except ImportError: - - class DaskClient: - pass - - -def _get_dask_client(client=None): - from dask.distributed import default_client - - if client is None or not isinstance(client, DaskClient): - return default_client() - return client +from dask.array import Array as DaskArray # noqa: F401 def _meta_dense(dtype): diff --git a/src/rapids_singlecell/_utils/__init__.py b/src/rapids_singlecell/_utils/__init__.py index 201809b6..8076b7f1 100644 --- a/src/rapids_singlecell/_utils/__init__.py +++ b/src/rapids_singlecell/_utils/__init__.py @@ -2,6 +2,13 @@ from typing import TYPE_CHECKING, Union +import cupy as cp import numpy as np +from cupyx.scipy.sparse import csc_matrix, csr_matrix +from dask.array import Array as DaskArray AnyRandom = Union[int, np.random.RandomState, None] # noqa: UP007 + + +ArrayTypes = Union[cp.ndarray, csc_matrix, csr_matrix] # noqa: UP007 +ArrayTypesDask = Union[cp.ndarray, csc_matrix, csr_matrix, DaskArray] # noqa: UP007 diff --git a/src/rapids_singlecell/preprocessing/_qc.py b/src/rapids_singlecell/preprocessing/_qc.py index c095e314..bd6bdd2c 100644 --- a/src/rapids_singlecell/preprocessing/_qc.py +++ b/src/rapids_singlecell/preprocessing/_qc.py @@ -15,6 +15,8 @@ if TYPE_CHECKING: from anndata import AnnData + from rapids_singlecell._utils import ArrayTypesDask + def calculate_qc_metrics( adata: AnnData, @@ -110,7 +112,9 @@ def calculate_qc_metrics( ) -def _first_pass_qc(X): +def _first_pass_qc( + X: ArrayTypesDask, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]: if isinstance(X, DaskArray): return _first_pass_qc_dask(X) @@ -181,7 +185,9 @@ def _first_pass_qc(X): @with_cupy_rmm -def _first_pass_qc_dask(X): +def _first_pass_qc_dask( + X: DaskArray, +) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]: import dask import dask.array as da @@ -337,7 +343,7 @@ def __qc_calc_2(X_part): ) -def _second_pass_qc(X, mask): +def _second_pass_qc(X: ArrayTypesDask, mask: cp.ndarray) -> cp.ndarray: if isinstance(X, DaskArray): return _second_pass_qc_dask(X, mask) sums_cells_sub = cp.zeros(X.shape[0], dtype=X.dtype) @@ -381,7 +387,7 @@ def _second_pass_qc(X, mask): @with_cupy_rmm -def _second_pass_qc_dask(X, mask): +def _second_pass_qc_dask(X: DaskArray, mask: cp.ndarray) -> cp.ndarray: if isinstance(X._meta, sparse.csr_matrix): from ._kernels._qc_kernels import _sparse_qc_csr_sub