diff --git a/anndata/_core/raw.py b/anndata/_core/raw.py index f18f75bf8..f248b99be 100644 --- a/anndata/_core/raw.py +++ b/anndata/_core/raw.py @@ -29,7 +29,11 @@ def __init__( self._n_obs = adata.n_obs # construct manually if adata.isbacked == (X is None): - self._X = X + # Move from GPU to CPU since it's large and not always used + if isinstance(X, (CupyArray, CupySparseMatrix)): + self._X = X.get() + else: + self._X = X self._var = _gen_dataframe(var, self.X.shape[1], ["var_names"]) self._varm = AxisArrays(self, 1, varm) elif X is None: # construct from adata diff --git a/anndata/tests/test_gpu.py b/anndata/tests/test_gpu.py index a5f820bd8..434567ca8 100644 --- a/anndata/tests/test_gpu.py +++ b/anndata/tests/test_gpu.py @@ -1,4 +1,7 @@ import pytest +from scipy import sparse + +from anndata import AnnData, Raw @pytest.mark.gpu @@ -9,3 +12,27 @@ def test_gpu(): import cupy # This test shouldn't run if cupy isn't installed cupy.ones(1) + + +@pytest.mark.gpu +def test_adata_raw_gpu(): + from cupyx.scipy import sparse as cupy_sparse + import cupy as cp + + adata = AnnData( + X=cupy_sparse.random(500, 50, density=0.01, format="csr", dtype=cp.float32) + ) + adata.raw = adata + assert isinstance(adata.raw.X, sparse.csr_matrix) + + +@pytest.mark.gpu +def test_raw_gpu(): + from cupyx.scipy import sparse as cupy_sparse + import cupy as cp + + adata = AnnData( + X=cupy_sparse.random(500, 50, density=0.01, format="csr", dtype=cp.float32) + ) + araw = Raw(adata) + assert isinstance(araw.X, sparse.csr_matrix)