From 51ce370dc4478bfe74fd9ec26e28e3827843f6d1 Mon Sep 17 00:00:00 2001
From: Venkata Jagannath <160357297+venkatajagannath@users.noreply.github.com>
Date: Sun, 21 Jul 2024 17:57:16 -0400
Subject: [PATCH] Wait for job completion feature (#20)

* first changes

* updates

* unit tests updated

* bug fix
---
 ray_provider/hooks/ray.py             |  1 -
 ray_provider/operators/ray.py         | 70 ++++++++++++++++-----------
 ray_provider/triggers/ray.py          | 58 +++++++++++-----------
 tests/operators/test_ray_operators.py | 11 +++--
 tests/triggers/test_ray_triggers.py   | 34 ++++++-------
 5 files changed, 93 insertions(+), 81 deletions(-)

diff --git a/ray_provider/hooks/ray.py b/ray_provider/hooks/ray.py
index 22a3a13..5c79ba8 100644
--- a/ray_provider/hooks/ray.py
+++ b/ray_provider/hooks/ray.py
@@ -186,7 +186,6 @@ def get_ray_job_logs(self, job_id: str) -> str:
         """
         client = self.ray_client
         logs = client.get_job_logs(job_id=job_id)
-        self.log.info(f"Logs for job {job_id}: {logs}")
         return str(logs)
 
     async def get_ray_tail_logs(self, job_id: str) -> AsyncIterator[str]:
diff --git a/ray_provider/operators/ray.py b/ray_provider/operators/ray.py
index bcfcb2a..6f80947 100644
--- a/ray_provider/operators/ray.py
+++ b/ray_provider/operators/ray.py
@@ -110,8 +110,11 @@ def _setup_load_balancer(self, name: str, namespace: str, context: Context) -> N
     def execute(self, context: Context) -> None:
         """Execute the operator to set up the Ray cluster."""
         try:
+            self.log.info("::group::Add KubeRay operator")
             self.hook.install_kuberay_operator(version=self.kuberay_version)
+            self.log.info("::endgroup::")
 
+            self.log.info("::group::Create Ray Cluster")
             self.log.info("Loading yaml content for Ray cluster CRD...")
             cluster_spec = self.hook.load_yaml_content(self.ray_cluster_yaml)
 
@@ -123,11 +126,14 @@ def execute(self, context: Context) -> None:
             group, version = api_version.split("/") if "/" in api_version else ("", api_version)
 
             self._create_or_update_cluster(group, version, plural, name, namespace, cluster_spec)
+            self.log.info("::endgroup::")
 
             if self.use_gpu:
                 self._setup_gpu_driver()
 
+            self.log.info("::group::Setup Load Balancer service")
             self._setup_load_balancer(name, namespace, context)
+            self.log.info("::endgroup::")
 
         except Exception as e:
             self.log.error(f"Error setting up Ray cluster: {e}")
@@ -210,7 +216,9 @@ def execute(self, context: Context) -> None:
         try:
             if self.use_gpu:
                 self._delete_gpu_daemonset()
+            self.log.info("::group:: Delete Ray Cluster")
             self._delete_ray_cluster()
+            self.log.info("::endgroup::")
             self.hook.uninstall_kuberay_operator()
         except Exception as e:
             self.log.error(f"Error deleting Ray cluster: {e}")
@@ -231,7 +239,7 @@ class SubmitRayJob(BaseOperator):
     :param num_gpus: Number of GPUs required for the job. Defaults to 0.
     :param memory: Amount of memory required for the job. Defaults to 0.
     :param resources: Additional resources required for the job. Defaults to None.
-    :param timeout: Maximum time to wait for job completion in seconds. Defaults to 600 seconds.
+    :param job_timeout_seconds: Maximum time to wait for job completion in seconds. Defaults to 600 seconds.
     :param poll_interval: Interval between job status checks in seconds. Defaults to 60 seconds.
     :param xcom_task_key: XCom key to retrieve dashboard URL. Defaults to None.
     """
@@ -248,7 +256,9 @@ def __init__(
         num_gpus: int | float = 0,
         memory: int | float = 0,
         resources: dict[str, Any] | None = None,
-        timeout: int = 600,
+        fetch_logs: bool = True,
+        wait_for_completion: bool = True,
+        job_timeout_seconds: int = 600,
         poll_interval: int = 60,
         xcom_task_key: str | None = None,
         **kwargs: Any,
@@ -261,7 +271,9 @@ def __init__(
         self.num_gpus = num_gpus
         self.memory = memory
         self.ray_resources = resources
-        self.timeout = timeout
+        self.fetch_logs = fetch_logs
+        self.wait_for_completion = wait_for_completion
+        self.job_timeout_seconds = job_timeout_seconds
         self.poll_interval = poll_interval
         self.xcom_task_key = xcom_task_key
         self.dashboard_url: str | None = None
@@ -303,29 +315,31 @@ def execute(self, context: Context) -> str:
         )
         self.log.info(f"Ray job submitted with id: {self.job_id}")
 
-        current_status = self.hook.get_ray_job_status(self.job_id)
-        self.log.info(f"Current job status for {self.job_id} is: {current_status}")
-
-        if current_status not in self.terminal_state:
-            self.log.info("Deferring the polling to RayJobTrigger...")
-            self.defer(
-                timeout=timedelta(seconds=self.timeout),
-                trigger=RayJobTrigger(
-                    job_id=self.job_id,
-                    conn_id=self.conn_id,
-                    xcom_dashboard_url=self.dashboard_url,
-                    poll_interval=self.poll_interval,
-                ),
-                method_name="execute_complete",
-            )
-        elif current_status == JobStatus.SUCCEEDED:
-            self.log.info("Job %s completed successfully", self.job_id)
-        elif current_status == JobStatus.FAILED:
-            raise AirflowException(f"Job failed:\n{self.job_id}")
-        elif current_status == JobStatus.STOPPED:
-            raise AirflowException(f"Job was cancelled:\n{self.job_id}")
-        else:
-            raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`")
+        if self.wait_for_completion:
+            current_status = self.hook.get_ray_job_status(self.job_id)
+            self.log.info(f"Current job status for {self.job_id} is: {current_status}")
+
+            if current_status not in self.terminal_state:
+                self.log.info("Deferring the polling to RayJobTrigger...")
+                self.defer(
+                    trigger=RayJobTrigger(
+                        job_id=self.job_id,
+                        conn_id=self.conn_id,
+                        xcom_dashboard_url=self.dashboard_url,
+                        poll_interval=self.poll_interval,
+                        fetch_logs=self.fetch_logs,
+                    ),
+                    method_name="execute_complete",
+                    timeout=timedelta(seconds=self.job_timeout_seconds),
+                )
+            elif current_status == JobStatus.SUCCEEDED:
+                self.log.info("Job %s completed successfully", self.job_id)
+            elif current_status == JobStatus.FAILED:
+                raise AirflowException(f"Job failed:\n{self.job_id}")
+            elif current_status == JobStatus.STOPPED:
+                raise AirflowException(f"Job was cancelled:\n{self.job_id}")
+            else:
+                raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`")
 
         return self.job_id
 
@@ -337,10 +351,10 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         :param event: The event containing the job execution result.
         :raises AirflowException: If the job execution fails or is cancelled.
         """
-        if event["status"] in ["error", "cancelled"]:
+        if event["status"] in [JobStatus.STOPPED, JobStatus.FAILED]:
             self.log.info(f"Ray job {self.job_id} execution not completed...")
             raise AirflowException(event["message"])
-        elif event["status"] == "success":
+        elif event["status"] == JobStatus.SUCCEEDED:
             self.log.info(f"Ray job {self.job_id} execution succeeded ...")
         else:
             raise AirflowException(f"Unexpected event status: {event['status']}")
diff --git a/ray_provider/triggers/ray.py b/ray_provider/triggers/ray.py
index 060fd2d..0807a35 100644
--- a/ray_provider/triggers/ray.py
+++ b/ray_provider/triggers/ray.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import asyncio
-from functools import cached_property
+from functools import cached_property, partial
 from typing import Any, AsyncIterator
 
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -24,11 +24,19 @@ class RayJobTrigger(BaseTrigger):
     :param poll_interval: The interval in seconds at which to poll the job status. Defaults to 30 seconds.
     """
 
-    def __init__(self, job_id: str, conn_id: str, xcom_dashboard_url: str | None, poll_interval: int = 30):
+    def __init__(
+        self,
+        job_id: str,
+        conn_id: str,
+        xcom_dashboard_url: str | None,
+        poll_interval: int = 30,
+        fetch_logs: bool = True,
+    ):
         super().__init__()  # type: ignore[no-untyped-call]
         self.job_id = job_id
         self.conn_id = conn_id
         self.dashboard_url = xcom_dashboard_url
+        self.fetch_logs = fetch_logs
         self.poll_interval = poll_interval
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -43,6 +51,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "job_id": self.job_id,
                 "conn_id": self.conn_id,
                 "xcom_dashboard_url": self.dashboard_url,
+                "fetch_logs": self.fetch_logs,
                 "poll_interval": self.poll_interval,
             },
         )
@@ -71,38 +80,27 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
             while not self._is_terminal_state():
                 await asyncio.sleep(self.poll_interval)
 
-            # Stream logs if available
-            async for multi_line in self.hook.get_ray_tail_logs(self.job_id):
-                self.log.info(multi_line)
+            self.log.info(f"Fetch logs flag is set to : {self.fetch_logs}")
+            if self.fetch_logs:
+                # Stream logs if available
+                loop = asyncio.get_event_loop()
+                logs = await loop.run_in_executor(None, partial(self.hook.get_ray_job_logs, job_id=self.job_id))
+                self.log.info(f"::group::{self.job_id} logs")
+                for log in logs.split("\n"):
+                    self.log.info(log)
+                self.log.info("::endgroup::")
 
             completed_status = self.hook.get_ray_job_status(self.job_id)
             self.log.info(f"Status of completed job {self.job_id} is: {completed_status}")
-            if completed_status == JobStatus.SUCCEEDED:
-                yield TriggerEvent(
-                    {
-                        "status": "success",
-                        "message": f"Job run {self.job_id} has completed successfully.",
-                        "job_id": self.job_id,
-                    }
-                )
-            elif completed_status == JobStatus.STOPPED:
-                yield TriggerEvent(
-                    {
-                        "status": "cancelled",
-                        "message": f"Job run {self.job_id} has been stopped.",
-                        "job_id": self.job_id,
-                    }
-                )
-            else:
-                yield TriggerEvent(
-                    {
-                        "status": "error",
-                        "message": f"Job run {self.job_id} has failed.",
-                        "job_id": self.job_id,
-                    }
-                )
+            yield TriggerEvent(
+                {
+                    "status": completed_status,
+                    "message": f"Job {self.job_id} completed with status {completed_status}",
+                    "job_id": self.job_id,
+                }
+            )
         except Exception as e:
-            yield TriggerEvent({"status": "error", "message": str(e), "job_id": self.job_id})
+            yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id})
 
     def _is_terminal_state(self) -> bool:
         """
diff --git a/tests/operators/test_ray_operators.py b/tests/operators/test_ray_operators.py
index 89ca4a6..2dd25c4 100644
--- a/tests/operators/test_ray_operators.py
+++ b/tests/operators/test_ray_operators.py
@@ -2,6 +2,7 @@
 
 import pytest
 from airflow.exceptions import AirflowException, TaskDeferred
+from ray.job_submission import JobStatus
 
 from ray_provider.operators.ray import SubmitRayJob
 
@@ -13,7 +14,7 @@
 num_gpus = 1
 memory = 1024
 resources = {"CPU": 2}
-timeout = 600
+job_timeout_seconds = 600
 context = MagicMock()
 
 
@@ -27,7 +28,7 @@ def operator():
         num_gpus=num_gpus,
         memory=memory,
         resources=resources,
-        timeout=timeout,
+        job_timeout_seconds=job_timeout_seconds,
         task_id="Testcases",
     )
 
@@ -42,7 +43,7 @@ def test_init(self, operator):
         assert operator.num_gpus == num_gpus
         assert operator.memory == memory
         # assert operator.resources == resources
-        assert operator.timeout == timeout
+        assert operator.job_timeout_seconds == job_timeout_seconds
 
     @patch("ray_provider.operators.ray.SubmitRayJob.hook")
     def test_execute(self, mock_hook, operator):
@@ -65,13 +66,13 @@ def test_on_kill(self, mock_hook, operator):
         mock_hook.delete_ray_job.assert_called_once_with("job_12345")
 
     def test_execute_complete_success(self, operator):
-        event = {"status": "success", "message": "Job completed successfully"}
+        event = {"status": JobStatus.SUCCEEDED, "message": "Job completed successfully"}
         operator.job_id = "job_12345"
 
         assert operator.execute_complete(context, event) is None
 
     def test_execute_complete_failure(self, operator):
-        event = {"status": "error", "message": "Job failed"}
+        event = {"status": JobStatus.FAILED, "message": "Job failed"}
         operator.job_id = "job_12345"
 
         with pytest.raises(AirflowException, match="Job failed"):
diff --git a/tests/triggers/test_ray_triggers.py b/tests/triggers/test_ray_triggers.py
index c9e98aa..6bf02f7 100644
--- a/tests/triggers/test_ray_triggers.py
+++ b/tests/triggers/test_ray_triggers.py
@@ -2,38 +2,38 @@
 
 import pytest
 from airflow.triggers.base import TriggerEvent
-from ray.dashboard.modules.job.sdk import JobStatus
+from ray.job_submission import JobStatus
 
 from ray_provider.triggers.ray import RayJobTrigger
 
 
 class TestRayJobTrigger:
-
     @pytest.mark.asyncio
     @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state")
     @patch("ray_provider.triggers.ray.RayJobTrigger.hook")
     async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
         mock_is_terminal.return_value = True
+        mock_hook.get_ray_job_status.return_value = JobStatus.FAILED
         trigger = RayJobTrigger(job_id="", poll_interval=1, conn_id="test", xcom_dashboard_url="test")
-
         generator = trigger.run()
         event = await generator.asend(None)
-        assert event == TriggerEvent({"status": "error", "message": "Job run has failed.", "job_id": ""})
+        assert event == TriggerEvent(
+            {"status": JobStatus.FAILED, "message": "Job completed with status FAILED", "job_id": ""}
+        )
 
     @pytest.mark.asyncio
+    @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state")
     @patch("ray_provider.triggers.ray.RayJobTrigger.hook")
-    async def test_run_job_succeeded(self, mock_hook):
-        trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test")
-
+    async def test_run_job_succeeded(self, mock_hook, mock_is_terminal):
+        mock_is_terminal.side_effect = [False, True]
         mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED
-
+        trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test")
         generator = trigger.run()
-        async for event in generator:
-            assert event == TriggerEvent(
-                {
-                    "status": "success",
-                    "message": "Job run test_job_id has completed successfully.",
-                    "job_id": "test_job_id",
-                }
-            )
-            break  # Stop after the first event for testing purposes
+        event = await generator.asend(None)
+        assert event == TriggerEvent(
+            {
+                "status": JobStatus.SUCCEEDED,
+                "message": f"Job test_job_id completed with status {JobStatus.SUCCEEDED}",
+                "job_id": "test_job_id",
+            }
+        )