diff --git a/src/rapids_singlecell/preprocessing/_sparse_pca/_dask_sparse_pca.py b/src/rapids_singlecell/preprocessing/_sparse_pca/_dask_sparse_pca.py index 2a59cbfe..41001eb7 100644 --- a/src/rapids_singlecell/preprocessing/_sparse_pca/_dask_sparse_pca.py +++ b/src/rapids_singlecell/preprocessing/_sparse_pca/_dask_sparse_pca.py @@ -57,7 +57,9 @@ def fit(self, x): def transform(self, X): def _transform(X_part, mean_, components_): pre_mean = mean_ @ components_.T - mean_impact = cp.ones((X_part.shape[0], 1)) @ pre_mean.reshape(1, -1) + mean_impact = cp.ones( + (X_part.shape[0], 1), dtype=X_part.dtype + ) @ pre_mean.reshape(1, -1) X_transformed = X_part.dot(components_.T) - mean_impact return X_transformed diff --git a/src/rapids_singlecell/preprocessing/_sparse_pca/_sparse_pca.py b/src/rapids_singlecell/preprocessing/_sparse_pca/_sparse_pca.py index ac2b0030..2f9d5117 100644 --- a/src/rapids_singlecell/preprocessing/_sparse_pca/_sparse_pca.py +++ b/src/rapids_singlecell/preprocessing/_sparse_pca/_sparse_pca.py @@ -52,7 +52,9 @@ def fit(self, x): def transform(self, X): precomputed_mean_impact = self.mean_ @ self.components_.T - mean_impact = cp.ones((X.shape[0], 1)) @ precomputed_mean_impact.reshape(1, -1) + mean_impact = cp.ones( + (X.shape[0], 1), dtype=cp.float32 + ) @ precomputed_mean_impact.reshape(1, -1) X_transformed = X.dot(self.components_.T) - mean_impact # X = X - self.mean_ # X_transformed = X.dot(self.components_.T)