From 13ccc3bf0e3d0f2e2a5f5644170aac9d5d67f4da Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 29 Nov 2024 08:22:18 +0000 Subject: [PATCH] Re order things so it is easier for code review --- ray_provider/hooks.py | 437 +++++++++++++++++++----------------------- 1 file changed, 195 insertions(+), 242 deletions(-) diff --git a/ray_provider/hooks.py b/ray_provider/hooks.py index 196358d..560a325 100644 --- a/ray_provider/hooks.py +++ b/ray_provider/hooks.py @@ -5,6 +5,7 @@ import subprocess import tempfile import time +from functools import cached_property from typing import Any, AsyncIterator import requests @@ -17,6 +18,8 @@ from ray_provider.constants import TERMINAL_JOB_STATUSES +DEFAULT_NAMESPACE = "default" + class RayHook(KubernetesHook): # type: ignore """ @@ -33,7 +36,43 @@ class RayHook(KubernetesHook): # type: ignore conn_type = "ray" hook_name = "Ray" - DEFAULT_NAMESPACE = "default" + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """ + Return custom field behaviour for the connection form. + + :return: A dictionary specifying custom field behaviour. + """ + return { + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], + "relabeling": {}, + } + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """ + Return connection widgets to add to connection form. + + :return: A dictionary of connection form widgets. + """ + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import BooleanField, PasswordField, StringField + + return { + "address": StringField(lazy_gettext("Ray dashboard url"), widget=BS3TextFieldWidget()), + # "create_cluster_if_needed": BooleanField(lazy_gettext("Create cluster if needed")), + "cookies": StringField(lazy_gettext("Cookies"), widget=BS3TextFieldWidget()), + "metadata": StringField(lazy_gettext("Metadata"), widget=BS3TextFieldWidget()), + "headers": StringField(lazy_gettext("Headers"), widget=BS3TextFieldWidget()), + "verify": BooleanField(lazy_gettext("Verify")), + "kube_config_path": StringField(lazy_gettext("Kube config path"), widget=BS3TextFieldWidget()), + "kube_config": PasswordField(lazy_gettext("Kube config (JSON format)"), widget=BS3PasswordFieldWidget()), + "namespace": StringField(lazy_gettext("Namespace"), widget=BS3TextFieldWidget()), + "cluster_context": StringField(lazy_gettext("Cluster context"), widget=BS3TextFieldWidget()), + "disable_verify_ssl": BooleanField(lazy_gettext("Disable SSL")), + "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")), + } def __init__( self, @@ -56,7 +95,7 @@ def __init__( self.verify = self._get_field("verify") or False self.ray_client_instance = None - self.default_namespace = self.get_namespace() or self.DEFAULT_NAMESPACE + self.default_namespace = self.get_namespace() or DEFAULT_NAMESPACE self.kubeconfig: str | None = None self.in_cluster: bool | None = None self.client_configuration = None @@ -68,23 +107,23 @@ def __init__( self.cluster_context = self._get_field("cluster_context") self.kubeconfig_path = self._get_field("kube_config_path") self.kubeconfig_content = self._get_field("kube_config") - self.ray_cluster_yaml = None + self.ray_cluster_yaml: None | str = None self._setup_kubeconfig(self.kubeconfig_path, self.kubeconfig_content, self.cluster_context) - @property # TODO: cached property + # Create a PR for this + @cached_property def namespace(self): if self.ray_cluster_yaml is None: return self.default_namespace cluster_spec = self.load_yaml_content(self.ray_cluster_yaml) return cluster_spec["metadata"].get("namespace") or self.default_namespace + # Create another PR for this def test_connection(self): job_client = self.ray_client(self.address) - job_id = job_client.submit_job( - entrypoint="import ray; ray.init(); print(ray.cluster_resources())" - ) + job_id = job_client.submit_job(entrypoint="import ray; ray.init(); print(ray.cluster_resources())") self.log.info(f"Ray test connection: Submitted job with ID: {job_id}") job_completed = False @@ -97,52 +136,12 @@ def test_connection(self): job_completed = True connection_attempt -= 1 - if job_status != JobStatus.SUCCEEDED: return False, f"Ray test connection failed: Job {job_id} status {job_status}" return True, job_status # TODO: check webserver logs - @classmethod - def get_ui_field_behaviour(cls) -> dict[str, Any]: - """ - Return custom field behaviour for the connection form. - - :return: A dictionary specifying custom field behaviour. - """ - return { - "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], - "relabeling": {}, - } - - @classmethod - def get_connection_form_widgets(cls) -> dict[str, Any]: - """ - Return connection widgets to add to connection form. - - :return: A dictionary of connection form widgets. - """ - from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget - from flask_babel import lazy_gettext - from wtforms import BooleanField, PasswordField, StringField - - return { - "address": StringField(lazy_gettext("Ray dashboard url"), widget=BS3TextFieldWidget()), - # "create_cluster_if_needed": BooleanField(lazy_gettext("Create cluster if needed")), - "cookies": StringField(lazy_gettext("Cookies"), widget=BS3TextFieldWidget()), - "metadata": StringField(lazy_gettext("Metadata"), widget=BS3TextFieldWidget()), - "headers": StringField(lazy_gettext("Headers"), widget=BS3TextFieldWidget()), - "verify": BooleanField(lazy_gettext("Verify")), - "kube_config_path": StringField(lazy_gettext("Kube config path"), widget=BS3TextFieldWidget()), - "kube_config": PasswordField(lazy_gettext("Kube config (JSON format)"), widget=BS3PasswordFieldWidget()), - "namespace": StringField(lazy_gettext("Namespace"), widget=BS3TextFieldWidget()), - "cluster_context": StringField(lazy_gettext("Cluster context"), widget=BS3TextFieldWidget()), - "disable_verify_ssl": BooleanField(lazy_gettext("Disable SSL")), - "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")), - } - - def _setup_kubeconfig( self, kubeconfig_path: str | None, kubeconfig_content: str | None, cluster_context: str | None ) -> None: @@ -187,19 +186,16 @@ def ray_client(self, dashboard_url: str | None = None) -> JobSubmissionClient: :raises AirflowException: If the connection fails. """ if not self.ray_client_instance: - try: - self.log.info(f"Address URL is: {self.address}") - self.log.info(f"Dashboard URL is: {dashboard_url}") - self.ray_client_instance = JobSubmissionClient( - address=dashboard_url or self.address, - create_cluster_if_needed=self.create_cluster_if_needed, - cookies=self.cookies, - metadata=self.metadata, - headers=self.headers, - verify=self.verify, - ) - except Exception as e: - raise AirflowException(f"Failed to create Ray JobSubmissionClient: {e}") + self.log.info(f"Address URL is: {self.address}") + self.log.info(f"Dashboard URL is: {dashboard_url}") + self.ray_client_instance = JobSubmissionClient( + address=dashboard_url or self.address, + create_cluster_if_needed=self.create_cluster_if_needed, + cookies=self.cookies, + metadata=self.metadata, + headers=self.headers, + verify=self.verify, + ) return self.ray_client_instance def submit_ray_job( @@ -341,7 +337,6 @@ def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | No ip: str | None = lb_details["ip"] hostname: str | None = lb_details["hostname"] - self.log.info(f"ports: {lb_details['ports']}") for port_info in lb_details["ports"]: port = port_info["port"] if ip and self._is_port_open(ip, port): @@ -351,123 +346,6 @@ def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | No return None - def _get_node_ip(self) -> str: - """ - Retrieve the IP address of a Kubernetes node. - - :return: The IP address of a node in the Kubernetes cluster. - """ - # Example: Retrieve the first node's IP (adjust based on your cluster setup) - nodes = self.core_v1_client.list_node().items - self.log.info(f"Nodes: {nodes}") - for node in nodes: - self.log.info(f"Node address: {node.status.addresses}") - for address in node.status.addresses: - if address.type == "ExternalIP": - return address.address - - for node in nodes: - self.log.info(f"Node address: {node.status.addresses}") - for address in node.status.addresses: - if address.type == "InternalIP": - return address.address - - raise AirflowException("No valid node IP found in the cluster.") - - def _setup_node_port(self, name: str, namespace: str, context: dict) -> None: - """ - Set up the NodePort service and push URLs to XCom. - - :param name: The name of the Ray cluster. - :param namespace: The namespace where the cluster is deployed. - :param context: The Airflow task context. - """ - node_port_details: dict[str, Any] = self._wait_for_node_port_service( - service_name=f"{name}-head-svc", namespace=namespace - ) - - if node_port_details: - self.log.info(node_port_details) - - node_ports = node_port_details["node_ports"] - # Example: Assuming `node_ip` is provided as an environment variable or a known cluster node. - node_ip = self._get_node_ip() # Implement this method to return a valid node IP or DNS. - - for port in node_ports: - url = f"http://{node_ip}:{port['port']}" - context["task_instance"].xcom_push(key=port["name"], value=url) - self.log.info(f"Pushed URL to XCom: {url}") - else: - self.log.info("No NodePort URLs to push to XCom.") - - def _wait_for_node_port_service( - self, - service_name: str, - namespace: str = "default", - max_retries: int = 30, - retry_interval: int = 10, - ) -> dict[str, Any]: - """ - Wait for the NodePort service to be ready and return its details. - - :param service_name: The name of the NodePort service. - :param namespace: The namespace of the service. - :param max_retries: Maximum number of retries. - :param retry_interval: Interval between retries in seconds. - :return: A dictionary containing NodePort service details. - :raises AirflowException: If the service does not become ready within the specified retries. - """ - for attempt in range(1, max_retries + 1): - self.log.info(f"Attempt {attempt}: Checking NodePort service status...") - - try: - service: client.V1Service = self._get_service(service_name, namespace) - service_details: dict[str, Any] | None = self._get_node_port_details(service) - - if service_details: - self.log.info("NodePort service is ready.") - return service_details - - self.log.info("NodePort details not available yet. Retrying...") - except AirflowException: - self.log.info("Service is not available yet.") - - time.sleep(retry_interval) - - raise AirflowException(f"Service did not become ready after {max_retries} attempts") - - def _get_node_port_details(self, service: client.V1Service) -> dict[str, Any] | None: - """ - Extract NodePort details from the service. - - :param service: The Kubernetes service object. - :return: A dictionary containing NodePort details if available, None otherwise. - """ - node_ports = [] - for port in service.spec.ports: - if port.node_port: - node_ports.append({"name": port.name, "port": port.node_port}) - - if node_ports: - return {"node_ports": node_ports} - - return None - - def _check_node_port_connectivity(self, node_ports: list[dict[str, Any]]) -> bool: - """ - Check if the NodePort is reachable. - - :param node_ports: List of NodePort details. - :return: True if at least one NodePort is accessible, False otherwise. - """ - for port_info in node_ports: - # Replace with actual logic to test connectivity if needed. - self.log.info(f"Checking connectivity for NodePort {port_info['port']}") - # Example: Simulate readiness check. - if self._is_port_open("example-node-ip", port_info["port"]): - return True - return False - def _wait_for_load_balancer( self, service_name: str, @@ -515,41 +393,6 @@ def _wait_for_load_balancer( raise AirflowException(f"LoadBalancer did not become ready after {max_retries} attempts") - def _get_load_balancer_details(self, service: client.V1Service) -> dict[str, Any] | None: - """ - Extract LoadBalancer details from the service. - - :param service: The Kubernetes service object. - :return: A dictionary containing LoadBalancer details if available, None otherwise. - """ - if service.status.load_balancer.ingress: - ingress: client.V1LoadBalancerIngress = service.status.load_balancer.ingress[0] - ip: str | None = ingress.ip - hostname: str | None = ingress.hostname - if ip or hostname: - ports: list[dict[str, Any]] = [{"name": port.name, "port": port.port} for port in service.spec.ports] - return {"ip": ip, "hostname": hostname, "ports": ports} - return None - - def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | None: - """ - Check if the LoadBalancer is ready by testing port connectivity. - - :param lb_details: Dictionary containing LoadBalancer details. - :return: The working address (IP or hostname) if ready, None otherwise. - """ - ip: str | None = lb_details["ip"] - hostname: str | None = lb_details["hostname"] - - for port_info in lb_details["ports"]: - port = port_info["port"] - if ip and self._is_port_open(ip, port): - return ip - if hostname and self._is_port_open(hostname, port): - return hostname - - return None - def _validate_yaml_file(self, yaml_file: str) -> None: """ Validate the existence and format of the YAML file. @@ -590,16 +433,13 @@ def _create_or_update_cluster( :param cluster_spec: The specification of the Ray cluster. :raises AirflowException: If there's an error accessing or creating the Ray cluster. """ - """self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) if update_if_exists: + self.log.info(f"Updating existing Ray cluster: {name}") + self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) self.custom_object_client.patch_namespaced_custom_object( group=group, version=version, namespace=namespace, plural=plural, name=name, body=cluster_spec ) - - except client.exceptions.ApiException as e: - if e.status == 404: - """ self.log.info(f"Creating new Ray cluster: {name}") @@ -608,13 +448,16 @@ def _create_or_update_cluster( ) self.log.info(f"Resource created. Response: {response}") + # TODO: may go to a different PR start_time = time.time() wait_timeout = 300 poll_interval = 5 while time.time() - start_time < wait_timeout: try: - resource = self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) + resource = self.get_custom_object( + group=group, version=version, plural=plural, name=name, namespace=namespace + ) except client.exceptions.ApiException as e: self.log.warning(f"Error fetching resource status: {e}") else: @@ -626,12 +469,9 @@ def _create_or_update_cluster( time.sleep(poll_interval) - raise TimeoutError(f"Resource {name} of group {group} did not reach the desired state within {wait_timeout} seconds.") - - """ - else: - raise AirflowException(f"Error accessing Ray cluster '{name}': {e}") - """ + raise TimeoutError( + f"Resource {name} of group {group} did not reach the desired state within {wait_timeout} seconds." + ) def _setup_gpu_driver(self, gpu_device_plugin_yaml: str) -> None: """ @@ -685,7 +525,6 @@ def setup_ray_cluster( :param update_if_exists: Whether to update the cluster if it already exists. :raises AirflowException: If there's an error setting up the Ray cluster. """ - #try: self._validate_yaml_file(ray_cluster_yaml) self.ray_cluster_yaml = ray_cluster_yaml @@ -716,24 +555,20 @@ def setup_ray_cluster( ) except TimeoutError as e: self._delete_ray_cluster_crd(ray_cluster_yaml) - raise AirflowException(e) + raise e self.log.info("::endgroup::") - #self._setup_gpu_driver(gpu_device_plugin_yaml=gpu_device_plugin_yaml) + self._setup_gpu_driver(gpu_device_plugin_yaml=gpu_device_plugin_yaml) - #self.log.info("::group:: (Step 3/3) Setup Node Port service") - #self._setup_node_port(name, namespace, context) - #self.log.info("::endgroup::") + # TODO: separate PR + # self.log.info("::group:: (Step 3/3) Setup Node Port service") + # self._setup_node_port(name, namespace, context) + # self.log.info("::endgroup::") self.log.info("::group:: (Setup 3/3) Setup Load Balancer service") self._setup_load_balancer(name, namespace, context) self.log.info("::endgroup::") - #except Exception as e: - # self.log.error(f"Error setting up Ray cluster: {e}") - # raise AirflowException(f"Failed to set up Ray cluster: {e}") - - def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: """ Delete the Ray cluster based on the cluster specification. @@ -763,7 +598,6 @@ def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: self.delete_custom_object(group=group, version=version, name=name, namespace=namespace, plural=plural) self.log.info(f"Deleted Ray cluster: {name}") - def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) -> None: """ Execute the operator to delete the Ray cluster. @@ -772,11 +606,10 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML. Defaults to NVIDIA's plugin :raises AirflowException: If there's an error deleting the Ray cluster. """ - #try: self._validate_yaml_file(ray_cluster_yaml) if gpu_device_plugin_yaml: - #Delete the NVIDIA GPU device plugin DaemonSet if it exists. + # Delete the NVIDIA GPU device plugin DaemonSet if it exists. gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml) gpu_driver_name = gpu_driver["metadata"]["name"] @@ -793,7 +626,7 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) self.log.info("::group:: Delete Kuberay operator") self.uninstall_kuberay_operator() self.log.info("::endgroup::") - #except Exception as e: + # except Exception as e: # self.log.error(f"Error deleting Ray cluster: {e}") # raise AirflowException(f"Failed to delete Ray cluster: {e}") @@ -886,7 +719,6 @@ def create_daemon_set(self, name: str, body: dict[str, Any]) -> client.V1DaemonS :param body: The body of the DaemonSet for the create action. :return: The created DaemonSet resource if successful, None otherwise. """ - self.log.warning("Trying to create create_daemon_set %s", name) if not body: self.log.error("Body must be provided for create action.") return None @@ -906,7 +738,6 @@ def delete_daemon_set(self, name: str) -> client.V1Status | None: :param name: The name of the DaemonSet. :return: The status of the delete operation if successful, None otherwise. """ - self.log.info("Trying to delete_daemon_set %s", name) try: delete_response = self.apps_v1_client.delete_namespaced_daemon_set(name=name, namespace=self.namespace) self.log.info(f"DaemonSet {name} deleted.") @@ -914,3 +745,125 @@ def delete_daemon_set(self, name: str) -> client.V1Status | None: except client.exceptions.ApiException as e: self.log.error(f"Exception when deleting DaemonSet: {e}") return None + + # Add this to yet another PR + def _get_node_ip(self) -> str: + """ + Retrieve the IP address of a Kubernetes node. + + :return: The IP address of a node in the Kubernetes cluster. + """ + # Example: Retrieve the first node's IP (adjust based on your cluster setup) + nodes = self.core_v1_client.list_node().items + self.log.info(f"Nodes: {nodes}") + for node in nodes: + self.log.info(f"Node address: {node.status.addresses}") + for address in node.status.addresses: + if address.type == "ExternalIP": + return address.address + + for node in nodes: + self.log.info(f"Node address: {node.status.addresses}") + for address in node.status.addresses: + if address.type == "InternalIP": + return address.address + + raise AirflowException("No valid node IP found in the cluster.") + + # Add this to yet another PR + def _setup_node_port(self, name: str, namespace: str, context: dict) -> None: + """ + Set up the NodePort service and push URLs to XCom. + + :param name: The name of the Ray cluster. + :param namespace: The namespace where the cluster is deployed. + :param context: The Airflow task context. + """ + node_port_details: dict[str, Any] = self._wait_for_node_port_service( + service_name=f"{name}-head-svc", namespace=namespace + ) + + if node_port_details: + self.log.info(node_port_details) + + node_ports = node_port_details["node_ports"] + # Example: Assuming `node_ip` is provided as an environment variable or a known cluster node. + node_ip = self._get_node_ip() # Implement this method to return a valid node IP or DNS. + + for port in node_ports: + url = f"http://{node_ip}:{port['port']}" + context["task_instance"].xcom_push(key=port["name"], value=url) + self.log.info(f"Pushed URL to XCom: {url}") + else: + self.log.info("No NodePort URLs to push to XCom.") + + # Add this to yet another PR + def _wait_for_node_port_service( + self, + service_name: str, + namespace: str = "default", + max_retries: int = 30, + retry_interval: int = 10, + ) -> dict[str, Any]: + """ + Wait for the NodePort service to be ready and return its details. + + :param service_name: The name of the NodePort service. + :param namespace: The namespace of the service. + :param max_retries: Maximum number of retries. + :param retry_interval: Interval between retries in seconds. + :return: A dictionary containing NodePort service details. + :raises AirflowException: If the service does not become ready within the specified retries. + """ + for attempt in range(1, max_retries + 1): + self.log.info(f"Attempt {attempt}: Checking NodePort service status...") + + try: + service: client.V1Service = self._get_service(service_name, namespace) + service_details: dict[str, Any] | None = self._get_node_port_details(service) + + if service_details: + self.log.info("NodePort service is ready.") + return service_details + + self.log.info("NodePort details not available yet. Retrying...") + except AirflowException: + self.log.info("Service is not available yet.") + + time.sleep(retry_interval) + + raise AirflowException(f"Service did not become ready after {max_retries} attempts") + + # Add this to yet another PR + def _get_node_port_details(self, service: client.V1Service) -> dict[str, Any] | None: + """ + Extract NodePort details from the service. + + :param service: The Kubernetes service object. + :return: A dictionary containing NodePort details if available, None otherwise. + """ + node_ports = [] + for port in service.spec.ports: + if port.node_port: + node_ports.append({"name": port.name, "port": port.node_port}) + + if node_ports: + return {"node_ports": node_ports} + + return None + + # Add this to yet another PR + def _check_node_port_connectivity(self, node_ports: list[dict[str, Any]]) -> bool: + """ + Check if the NodePort is reachable. + + :param node_ports: List of NodePort details. + :return: True if at least one NodePort is accessible, False otherwise. + """ + for port_info in node_ports: + # Replace with actual logic to test connectivity if needed. + self.log.info(f"Checking connectivity for NodePort {port_info['port']}") + # Example: Simulate readiness check. + if self._is_port_open("example-node-ip", port_info["port"]): + return True + return False