From f0548b3977249599eb84ef3fbabbc0fb14e3c081 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Mon, 2 Dec 2024 13:48:37 +0000 Subject: [PATCH] Fix retrieving namespace from user-defined YAML config --- ray_provider/constants.py | 1 + ray_provider/hooks.py | 21 +++++++++++++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/ray_provider/constants.py b/ray_provider/constants.py index 9e37f6a..457c4e2 100644 --- a/ray_provider/constants.py +++ b/ray_provider/constants.py @@ -1,3 +1,4 @@ from ray.job_submission import JobStatus +DEFAULT_K8S_NAMESPACE = "default" TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} diff --git a/ray_provider/hooks.py b/ray_provider/hooks.py index 38c9132..cc29e7a 100644 --- a/ray_provider/hooks.py +++ b/ray_provider/hooks.py @@ -5,6 +5,7 @@ import subprocess import tempfile import time +from functools import cached_property from typing import Any, AsyncIterator import requests @@ -15,6 +16,8 @@ from kubernetes import client, config from ray.job_submission import JobStatus, JobSubmissionClient +from ray_provider.constants import DEFAULT_K8S_NAMESPACE + class RayHook(KubernetesHook): # type: ignore """ @@ -31,8 +34,6 @@ class RayHook(KubernetesHook): # type: ignore conn_type = "ray" hook_name = "Ray" - DEFAULT_NAMESPACE = "default" - @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: """ @@ -92,7 +93,7 @@ def __init__( self.verify = self._get_field("verify") or False self.ray_client_instance = None - self.namespace = self.get_namespace() or self.DEFAULT_NAMESPACE + self.default_namespace = self.get_namespace() or DEFAULT_K8S_NAMESPACE self.kubeconfig: str | None = None self.in_cluster: bool | None = None self.client_configuration = None @@ -106,6 +107,14 @@ def __init__( self.kubeconfig_content = self._get_field("kube_config") self._setup_kubeconfig(self.kubeconfig_path, self.kubeconfig_content, self.cluster_context) + self.ray_cluster_yaml: None | str = None + + @cached_property + def namespace(self) -> str: + if self.ray_cluster_yaml is None: + return self.default_namespace + cluster_spec = self.load_yaml_content(self.ray_cluster_yaml) + return cluster_spec["metadata"].get("namespace") or self.default_namespace def _setup_kubeconfig( self, kubeconfig_path: str | None, kubeconfig_content: str | None, cluster_context: str | None @@ -331,11 +340,13 @@ def _wait_for_load_balancer( :raises AirflowException: If the LoadBalancer does not become ready within the specified retries. """ for attempt in range(1, max_retries + 1): - self.log.info(f"Attempt {attempt}: Checking LoadBalancer status...") + self.log.info(f"Attempt {attempt}: Checking LoadBalancer status {service_name} in {namespace}...") try: service: client.V1Service = self._get_service(service_name, namespace) + self.log.debug(f"Load balancer service {service}") lb_details: dict[str, Any] | None = self._get_load_balancer_details(service) + self.log.debug(f"Load balancer details {lb_details}") if not lb_details: self.log.info("LoadBalancer details not available yet.") @@ -435,6 +446,7 @@ def _setup_load_balancer(self, name: str, namespace: str, context: Context) -> N :param namespace: The namespace where the cluster is deployed. :param context: The Airflow task context. """ + lb_details: dict[str, Any] = self._wait_for_load_balancer(service_name=f"{name}-head-svc", namespace=namespace) if lb_details: @@ -466,6 +478,7 @@ def setup_ray_cluster( """ try: self._validate_yaml_file(ray_cluster_yaml) + self.ray_cluster_yaml = ray_cluster_yaml self.log.info("::group::Add KubeRay operator") self.install_kuberay_operator(version=kuberay_version)