diff --git a/src/rapids_singlecell/preprocessing/_neighbors.py b/src/rapids_singlecell/preprocessing/_neighbors.py index 54b10084..cd6eefb9 100644 --- a/src/rapids_singlecell/preprocessing/_neighbors.py +++ b/src/rapids_singlecell/preprocessing/_neighbors.py @@ -15,7 +15,10 @@ from rapids_singlecell.tools._utils import _choose_representation -CUVS_SWITCH = parse_version(pylibraft.__version__) > parse_version("24.10") + +def _cuvs_switch(): + return parse_version(pylibraft.__version__) > parse_version("24.10") + if TYPE_CHECKING: from collections.abc import Mapping @@ -87,7 +90,7 @@ def _brute_knn( def _cagra_knn( X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping ) -> tuple[cp.ndarray, cp.ndarray]: - if not CUVS_SWITCH: + if not _cuvs_switch(): try: from pylibraft.neighbors import cagra except ImportError: @@ -138,7 +141,7 @@ def _cagra_knn( def _ivf_flat_knn( X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping ) -> tuple[cp.ndarray, cp.ndarray]: - if not CUVS_SWITCH: + if not _cuvs_switch(): from pylibraft.neighbors import ivf_flat resources = DeviceResources() @@ -170,7 +173,7 @@ def _ivf_flat_knn( def _ivf_pq_knn( X: cp.ndarray, Y: cp.ndarray, k: int, metric: _Metrics, metric_kwds: Mapping ) -> tuple[cp.ndarray, cp.ndarray]: - if not CUVS_SWITCH: + if not _cuvs_switch(): from pylibraft.neighbors import ivf_pq resources = DeviceResources()