diff --git a/dask_kubernetes/operator/kubecluster/kubecluster.py b/dask_kubernetes/operator/kubecluster/kubecluster.py index 8e1a9dfe8..738c12e4c 100644 --- a/dask_kubernetes/operator/kubecluster/kubecluster.py +++ b/dask_kubernetes/operator/kubecluster/kubecluster.py @@ -22,6 +22,7 @@ import pykube.exceptions import kubernetes_asyncio as kubernetes import yaml +import kr8s import dask.config from distributed.core import Status, rpc @@ -43,9 +44,13 @@ ) from dask_kubernetes.common.utils import get_current_namespace from dask_kubernetes.aiopykube import HTTPClient, KubeConfig -from dask_kubernetes.aiopykube.dask import DaskCluster, DaskWorkerGroup +from dask_kubernetes.aiopykube.dask import ( + DaskCluster, + DaskWorkerGroup as AIODaskWorkerGroup, +) from dask_kubernetes.aiopykube.objects import Pod, Service from dask_kubernetes.exceptions import CrashLoopBackOffError, SchedulerStartupError +from dask_kubernetes.operator._objects import DaskWorkerGroup, DaskAutoscaler logger = logging.getLogger(__name__) @@ -541,7 +546,7 @@ async def _watch_component_status(self): # Get DaskWorkerGroup status with suppress(pykube.exceptions.ObjectDoesNotExist): - await DaskWorkerGroup.objects( + await AIODaskWorkerGroup.objects( self.k8s_api, namespace=self.namespace ).get_by_name(self.name + "-default") self._startup_component_status["workergroup"] = "Created" @@ -799,31 +804,18 @@ def scale(self, n, worker_group="default"): return self.sync(self._scale, n, worker_group) async def _scale(self, n, worker_group="default"): - async with kubernetes.client.api_client.ApiClient() as api_client: - custom_objects_api = kubernetes.client.CustomObjectsApi(api_client) - custom_objects_api.api_client.set_default_header( - "content-type", "application/merge-patch+json" - ) - # Disable adaptivity if enabled - with suppress(kubernetes.client.ApiException): - await custom_objects_api.delete_namespaced_custom_object( - group="kubernetes.dask.org", - version="v1", - plural="daskautoscalers", - namespace=self.namespace, - name=self.name, - ) - await custom_objects_api.patch_namespaced_custom_object_scale( - group="kubernetes.dask.org", - version="v1", - plural="daskworkergroups", - namespace=self.namespace, - name=f"{self.name}-{worker_group}", - body={"spec": {"replicas": n}}, - ) - for instance in self._instances: - if instance.name == self.name: - instance.scheduler_info = self.scheduler_info + # Disable adaptivity if enabled + with suppress(kr8s.NotFoundError): + autoscaler = await DaskAutoscaler(self.name, self.namespace) + await autoscaler.delete() + + wg = await DaskWorkerGroup( + f"{self.name}-{worker_group}", namespace=self.namespace + ) + await wg.scale(n) + for instance in self._instances: + if instance.name == self.name: + instance.scheduler_info = self.scheduler_info def adapt(self, minimum=None, maximum=None): """Turn on adaptivity @@ -843,41 +835,27 @@ def adapt(self, minimum=None, maximum=None): return self.sync(self._adapt, minimum, maximum) async def _adapt(self, minimum=None, maximum=None): - async with kubernetes.client.api_client.ApiClient() as api_client: - custom_objects_api = kubernetes.client.CustomObjectsApi(api_client) - custom_objects_api.api_client.set_default_header( - "content-type", "application/merge-patch+json" - ) - try: - await custom_objects_api.patch_namespaced_custom_object_scale( - group="kubernetes.dask.org", - version="v1", - plural="daskautoscalers", - namespace=self.namespace, - name=self.name, - body={"spec": {"minimum": minimum, "maximum": maximum}}, - ) - except kubernetes.client.ApiException: - await custom_objects_api.create_namespaced_custom_object( - group="kubernetes.dask.org", - version="v1", - plural="daskautoscalers", - namespace=self.namespace, - body={ - "apiVersion": "kubernetes.dask.org/v1", - "kind": "DaskAutoscaler", - "metadata": { - "name": self.name, - "dask.org/cluster-name": self.name, - "dask.org/component": "autoscaler", - }, - "spec": { - "cluster": self.name, - "minimum": minimum, - "maximum": maximum, - }, - }, - ) + autoscaler = await DaskAutoscaler( + { + "apiVersion": "kubernetes.dask.org/v1", + "kind": "DaskAutoscaler", + "metadata": { + "name": self.name, + "dask.org/cluster-name": self.name, + "dask.org/component": "autoscaler", + }, + "spec": { + "cluster": self.name, + "minimum": minimum, + "maximum": maximum, + }, + }, + self.namespace, + ) + try: + await autoscaler.patch({"spec": {"minimum": minimum, "maximum": maximum}}) + except kr8s.NotFoundError: + await autoscaler.create() def __enter__(self): return self