From 1eaa56cc8611ef5961176f63979b1f08262a5cb5 Mon Sep 17 00:00:00 2001 From: Severin Dicks <37635888+Intron7@users.noreply.github.com> Date: Tue, 29 Aug 2023 13:44:43 +0200 Subject: [PATCH] forces Raw move to CPU (#1107) * forces move to CPU * added raw GPU test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use qualified function names, plus avoid future scipy deprecations --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Isaac Virshup --- anndata/_core/raw.py | 6 +++++- anndata/tests/test_gpu.py | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/anndata/_core/raw.py b/anndata/_core/raw.py index f18f75bf8..f248b99be 100644 --- a/anndata/_core/raw.py +++ b/anndata/_core/raw.py @@ -29,7 +29,11 @@ def __init__( self._n_obs = adata.n_obs # construct manually if adata.isbacked == (X is None): - self._X = X + # Move from GPU to CPU since it's large and not always used + if isinstance(X, (CupyArray, CupySparseMatrix)): + self._X = X.get() + else: + self._X = X self._var = _gen_dataframe(var, self.X.shape[1], ["var_names"]) self._varm = AxisArrays(self, 1, varm) elif X is None: # construct from adata diff --git a/anndata/tests/test_gpu.py b/anndata/tests/test_gpu.py index a5f820bd8..434567ca8 100644 --- a/anndata/tests/test_gpu.py +++ b/anndata/tests/test_gpu.py @@ -1,4 +1,7 @@ import pytest +from scipy import sparse + +from anndata import AnnData, Raw @pytest.mark.gpu @@ -9,3 +12,27 @@ def test_gpu(): import cupy # This test shouldn't run if cupy isn't installed cupy.ones(1) + + +@pytest.mark.gpu +def test_adata_raw_gpu(): + from cupyx.scipy import sparse as cupy_sparse + import cupy as cp + + adata = AnnData( + X=cupy_sparse.random(500, 50, density=0.01, format="csr", dtype=cp.float32) + ) + adata.raw = adata + assert isinstance(adata.raw.X, sparse.csr_matrix) + + +@pytest.mark.gpu +def test_raw_gpu(): + from cupyx.scipy import sparse as cupy_sparse + import cupy as cp + + adata = AnnData( + X=cupy_sparse.random(500, 50, density=0.01, format="csr", dtype=cp.float32) + ) + araw = Raw(adata) + assert isinstance(araw.X, sparse.csr_matrix)