Skip to content

Commit

Permalink
Wait for job completion feature (#20)
Browse files Browse the repository at this point in the history
* first changes

* updates

* unit tests updated

* bug fix
  • Loading branch information
venkatajagannath authored Jul 21, 2024
1 parent d3241ff commit 51ce370
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 81 deletions.
1 change: 0 additions & 1 deletion ray_provider/hooks/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ def get_ray_job_logs(self, job_id: str) -> str:
"""
client = self.ray_client
logs = client.get_job_logs(job_id=job_id)
self.log.info(f"Logs for job {job_id}: {logs}")
return str(logs)

async def get_ray_tail_logs(self, job_id: str) -> AsyncIterator[str]:
Expand Down
70 changes: 42 additions & 28 deletions ray_provider/operators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,11 @@ def _setup_load_balancer(self, name: str, namespace: str, context: Context) -> N
def execute(self, context: Context) -> None:
"""Execute the operator to set up the Ray cluster."""
try:
self.log.info("::group::Add KubeRay operator")
self.hook.install_kuberay_operator(version=self.kuberay_version)
self.log.info("::endgroup::")

self.log.info("::group::Create Ray Cluster")
self.log.info("Loading yaml content for Ray cluster CRD...")
cluster_spec = self.hook.load_yaml_content(self.ray_cluster_yaml)

Expand All @@ -123,11 +126,14 @@ def execute(self, context: Context) -> None:
group, version = api_version.split("/") if "/" in api_version else ("", api_version)

self._create_or_update_cluster(group, version, plural, name, namespace, cluster_spec)
self.log.info("::endgroup::")

if self.use_gpu:
self._setup_gpu_driver()

self.log.info("::group::Setup Load Balancer service")
self._setup_load_balancer(name, namespace, context)
self.log.info("::endgroup::")

except Exception as e:
self.log.error(f"Error setting up Ray cluster: {e}")
Expand Down Expand Up @@ -210,7 +216,9 @@ def execute(self, context: Context) -> None:
try:
if self.use_gpu:
self._delete_gpu_daemonset()
self.log.info("::group:: Delete Ray Cluster")
self._delete_ray_cluster()
self.log.info("::endgroup::")
self.hook.uninstall_kuberay_operator()
except Exception as e:
self.log.error(f"Error deleting Ray cluster: {e}")
Expand All @@ -231,7 +239,7 @@ class SubmitRayJob(BaseOperator):
:param num_gpus: Number of GPUs required for the job. Defaults to 0.
:param memory: Amount of memory required for the job. Defaults to 0.
:param resources: Additional resources required for the job. Defaults to None.
:param timeout: Maximum time to wait for job completion in seconds. Defaults to 600 seconds.
:param job_timeout_seconds: Maximum time to wait for job completion in seconds. Defaults to 600 seconds.
:param poll_interval: Interval between job status checks in seconds. Defaults to 60 seconds.
:param xcom_task_key: XCom key to retrieve dashboard URL. Defaults to None.
"""
Expand All @@ -248,7 +256,9 @@ def __init__(
num_gpus: int | float = 0,
memory: int | float = 0,
resources: dict[str, Any] | None = None,
timeout: int = 600,
fetch_logs: bool = True,
wait_for_completion: bool = True,
job_timeout_seconds: int = 600,
poll_interval: int = 60,
xcom_task_key: str | None = None,
**kwargs: Any,
Expand All @@ -261,7 +271,9 @@ def __init__(
self.num_gpus = num_gpus
self.memory = memory
self.ray_resources = resources
self.timeout = timeout
self.fetch_logs = fetch_logs
self.wait_for_completion = wait_for_completion
self.job_timeout_seconds = job_timeout_seconds
self.poll_interval = poll_interval
self.xcom_task_key = xcom_task_key
self.dashboard_url: str | None = None
Expand Down Expand Up @@ -303,29 +315,31 @@ def execute(self, context: Context) -> str:
)
self.log.info(f"Ray job submitted with id: {self.job_id}")

current_status = self.hook.get_ray_job_status(self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_state:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`")
if self.wait_for_completion:
current_status = self.hook.get_ray_job_status(self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_state:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`")

return self.job_id

Expand All @@ -337,10 +351,10 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
:param event: The event containing the job execution result.
:raises AirflowException: If the job execution fails or is cancelled.
"""
if event["status"] in ["error", "cancelled"]:
if event["status"] in [JobStatus.STOPPED, JobStatus.FAILED]:
self.log.info(f"Ray job {self.job_id} execution not completed...")
raise AirflowException(event["message"])
elif event["status"] == "success":
elif event["status"] == JobStatus.SUCCEEDED:
self.log.info(f"Ray job {self.job_id} execution succeeded ...")
else:
raise AirflowException(f"Unexpected event status: {event['status']}")
58 changes: 28 additions & 30 deletions ray_provider/triggers/ray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from functools import cached_property
from functools import cached_property, partial
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand All @@ -24,11 +24,19 @@ class RayJobTrigger(BaseTrigger):
:param poll_interval: The interval in seconds at which to poll the job status. Defaults to 30 seconds.
"""

def __init__(self, job_id: str, conn_id: str, xcom_dashboard_url: str | None, poll_interval: int = 30):
def __init__(
self,
job_id: str,
conn_id: str,
xcom_dashboard_url: str | None,
poll_interval: int = 30,
fetch_logs: bool = True,
):
super().__init__() # type: ignore[no-untyped-call]
self.job_id = job_id
self.conn_id = conn_id
self.dashboard_url = xcom_dashboard_url
self.fetch_logs = fetch_logs
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
Expand All @@ -43,6 +51,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"job_id": self.job_id,
"conn_id": self.conn_id,
"xcom_dashboard_url": self.dashboard_url,
"fetch_logs": self.fetch_logs,
"poll_interval": self.poll_interval,
},
)
Expand Down Expand Up @@ -71,38 +80,27 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
while not self._is_terminal_state():
await asyncio.sleep(self.poll_interval)

# Stream logs if available
async for multi_line in self.hook.get_ray_tail_logs(self.job_id):
self.log.info(multi_line)
self.log.info(f"Fetch logs flag is set to : {self.fetch_logs}")
if self.fetch_logs:
# Stream logs if available
loop = asyncio.get_event_loop()
logs = await loop.run_in_executor(None, partial(self.hook.get_ray_job_logs, job_id=self.job_id))
self.log.info(f"::group::{self.job_id} logs")
for log in logs.split("\n"):
self.log.info(log)
self.log.info("::endgroup::")

completed_status = self.hook.get_ray_job_status(self.job_id)
self.log.info(f"Status of completed job {self.job_id} is: {completed_status}")
if completed_status == JobStatus.SUCCEEDED:
yield TriggerEvent(
{
"status": "success",
"message": f"Job run {self.job_id} has completed successfully.",
"job_id": self.job_id,
}
)
elif completed_status == JobStatus.STOPPED:
yield TriggerEvent(
{
"status": "cancelled",
"message": f"Job run {self.job_id} has been stopped.",
"job_id": self.job_id,
}
)
else:
yield TriggerEvent(
{
"status": "error",
"message": f"Job run {self.job_id} has failed.",
"job_id": self.job_id,
}
)
yield TriggerEvent(
{
"status": completed_status,
"message": f"Job {self.job_id} completed with status {completed_status}",
"job_id": self.job_id,
}
)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e), "job_id": self.job_id})
yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id})

def _is_terminal_state(self) -> bool:
"""
Expand Down
11 changes: 6 additions & 5 deletions tests/operators/test_ray_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from ray.job_submission import JobStatus

from ray_provider.operators.ray import SubmitRayJob

Expand All @@ -13,7 +14,7 @@
num_gpus = 1
memory = 1024
resources = {"CPU": 2}
timeout = 600
job_timeout_seconds = 600
context = MagicMock()


Expand All @@ -27,7 +28,7 @@ def operator():
num_gpus=num_gpus,
memory=memory,
resources=resources,
timeout=timeout,
job_timeout_seconds=job_timeout_seconds,
task_id="Testcases",
)

Expand All @@ -42,7 +43,7 @@ def test_init(self, operator):
assert operator.num_gpus == num_gpus
assert operator.memory == memory
# assert operator.resources == resources
assert operator.timeout == timeout
assert operator.job_timeout_seconds == job_timeout_seconds

@patch("ray_provider.operators.ray.SubmitRayJob.hook")
def test_execute(self, mock_hook, operator):
Expand All @@ -65,13 +66,13 @@ def test_on_kill(self, mock_hook, operator):
mock_hook.delete_ray_job.assert_called_once_with("job_12345")

def test_execute_complete_success(self, operator):
event = {"status": "success", "message": "Job completed successfully"}
event = {"status": JobStatus.SUCCEEDED, "message": "Job completed successfully"}
operator.job_id = "job_12345"

assert operator.execute_complete(context, event) is None

def test_execute_complete_failure(self, operator):
event = {"status": "error", "message": "Job failed"}
event = {"status": JobStatus.FAILED, "message": "Job failed"}
operator.job_id = "job_12345"

with pytest.raises(AirflowException, match="Job failed"):
Expand Down
34 changes: 17 additions & 17 deletions tests/triggers/test_ray_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,38 @@

import pytest
from airflow.triggers.base import TriggerEvent
from ray.dashboard.modules.job.sdk import JobStatus
from ray.job_submission import JobStatus

from ray_provider.triggers.ray import RayJobTrigger


class TestRayJobTrigger:

@pytest.mark.asyncio
@patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.ray.RayJobTrigger.hook")
async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
mock_is_terminal.return_value = True
mock_hook.get_ray_job_status.return_value = JobStatus.FAILED
trigger = RayJobTrigger(job_id="", poll_interval=1, conn_id="test", xcom_dashboard_url="test")

generator = trigger.run()
event = await generator.asend(None)
assert event == TriggerEvent({"status": "error", "message": "Job run has failed.", "job_id": ""})
assert event == TriggerEvent(
{"status": JobStatus.FAILED, "message": "Job completed with status FAILED", "job_id": ""}
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.ray.RayJobTrigger.hook")
async def test_run_job_succeeded(self, mock_hook):
trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test")

async def test_run_job_succeeded(self, mock_hook, mock_is_terminal):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED

trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test")
generator = trigger.run()
async for event in generator:
assert event == TriggerEvent(
{
"status": "success",
"message": "Job run test_job_id has completed successfully.",
"job_id": "test_job_id",
}
)
break # Stop after the first event for testing purposes
event = await generator.asend(None)
assert event == TriggerEvent(
{
"status": JobStatus.SUCCEEDED,
"message": f"Job test_job_id completed with status {JobStatus.SUCCEEDED}",
"job_id": "test_job_id",
}
)

0 comments on commit 51ce370

Please sign in to comment.