Skip to content

Commit

Permalink
update typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed Oct 14, 2024
1 parent 0201658 commit ea57084
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 27 deletions.
24 changes: 1 addition & 23 deletions src/rapids_singlecell/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,7 @@

import cupy as cp
from cupyx.scipy.sparse import csr_matrix

try:
from dask.array import Array as DaskArray
except ImportError:

class DaskArray:
pass


try:
from dask.distributed import Client as DaskClient
except ImportError:

class DaskClient:
pass


def _get_dask_client(client=None):
from dask.distributed import default_client

if client is None or not isinstance(client, DaskClient):
return default_client()
return client
from dask.array import Array as DaskArray # noqa: F401


def _meta_dense(dtype):
Expand Down
7 changes: 7 additions & 0 deletions src/rapids_singlecell/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

from typing import TYPE_CHECKING, Union

import cupy as cp
import numpy as np
from cupyx.scipy.sparse import csc_matrix, csr_matrix
from dask.array import Array as DaskArray

AnyRandom = Union[int, np.random.RandomState, None] # noqa: UP007


ArrayTypes = Union[cp.ndarray, csc_matrix, csr_matrix] # noqa: UP007
ArrayTypesDask = Union[cp.ndarray, csc_matrix, csr_matrix, DaskArray] # noqa: UP007
14 changes: 10 additions & 4 deletions src/rapids_singlecell/preprocessing/_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
if TYPE_CHECKING:
from anndata import AnnData

from rapids_singlecell._utils import ArrayTypesDask


def calculate_qc_metrics(
adata: AnnData,
Expand Down Expand Up @@ -110,7 +112,9 @@ def calculate_qc_metrics(
)


def _first_pass_qc(X):
def _first_pass_qc(
X: ArrayTypesDask,
) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]:
if isinstance(X, DaskArray):
return _first_pass_qc_dask(X)

Expand Down Expand Up @@ -181,7 +185,9 @@ def _first_pass_qc(X):


@with_cupy_rmm
def _first_pass_qc_dask(X):
def _first_pass_qc_dask(
X: DaskArray,
) -> tuple[cp.ndarray, cp.ndarray, cp.ndarray, cp.ndarray]:
import dask
import dask.array as da

Expand Down Expand Up @@ -337,7 +343,7 @@ def __qc_calc_2(X_part):
)


def _second_pass_qc(X, mask):
def _second_pass_qc(X: ArrayTypesDask, mask: cp.ndarray) -> cp.ndarray:
if isinstance(X, DaskArray):
return _second_pass_qc_dask(X, mask)
sums_cells_sub = cp.zeros(X.shape[0], dtype=X.dtype)
Expand Down Expand Up @@ -381,7 +387,7 @@ def _second_pass_qc(X, mask):


@with_cupy_rmm
def _second_pass_qc_dask(X, mask):
def _second_pass_qc_dask(X: DaskArray, mask: cp.ndarray) -> cp.ndarray:
if isinstance(X._meta, sparse.csr_matrix):
from ._kernels._qc_kernels import _sparse_qc_csr_sub

Expand Down

0 comments on commit ea57084

Please sign in to comment.