Bugfix: Better exception handling and cluster clean up (#68)
* initial changes

* update

* Updated

* removing on_failure_callback and on_success_callback

* Added comment
venkatajagannath authored Sep 25, 2024
1 parent 48b9785 commit b7dc197
Showing 4 changed files with 354 additions and 65 deletions.
94 changes: 50 additions & 44 deletions ray_provider/operators/ray.py
@@ -185,8 +185,7 @@ def on_kill(self) -> None:
if hasattr(self, "hook") and self.job_id:
self.log.info(f"Deleting Ray job {self.job_id} due to task kill.")
self.hook.delete_ray_job(self.dashboard_url, self.job_id)
if self.ray_cluster_yaml:
self._delete_cluster()
self._delete_cluster()

@cached_property
def hook(self) -> PodOperatorHookProtocol:
@@ -262,48 +261,55 @@ def execute(self, context: Context) -> str:
:raises AirflowException: If the job fails, is cancelled, or reaches an unexpected state.
"""

self._setup_cluster(context=context)

self.dashboard_url = self._get_dashboard_url(context)

self.job_id = self.hook.submit_ray_job(
dashboard_url=self.dashboard_url,
entrypoint=self.entrypoint,
runtime_env=self.runtime_env,
entrypoint_num_cpus=self.num_cpus,
entrypoint_num_gpus=self.num_gpus,
entrypoint_memory=self.memory,
entrypoint_resources=self.ray_resources,
)
self.log.info(f"Ray job submitted with id: {self.job_id}")

if self.wait_for_completion:
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_states:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`")

return self.job_id
try:
self._setup_cluster(context=context)

self.dashboard_url = self._get_dashboard_url(context)

self.job_id = self.hook.submit_ray_job(
dashboard_url=self.dashboard_url,
entrypoint=self.entrypoint,
runtime_env=self.runtime_env,
entrypoint_num_cpus=self.num_cpus,
entrypoint_num_gpus=self.num_gpus,
entrypoint_memory=self.memory,
entrypoint_resources=self.ray_resources,
)
self.log.info(f"Ray job submitted with id: {self.job_id}")

if self.wait_for_completion:
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_states:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
ray_cluster_yaml=self.ray_cluster_yaml,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(
f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`"
)
return self.job_id
except Exception as e:
self._delete_cluster()
raise AirflowException(f"SubmitRayJob operator failed due to {e}. Cleaning up resources...")

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
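For context, the sketch below shows one way the reworked execute() path could be exercised from a DAG. It is not part of this commit: the DAG id, connection id, YAML paths, entrypoint, and runtime_env are placeholders; the operator parameters themselves are taken from the constructor calls in the tests further down.

# Hypothetical usage sketch (not from this commit). With wait_for_completion=True a
# non-terminal job defers to RayJobTrigger; if cluster setup or job submission raises,
# execute() now calls _delete_cluster() before re-raising as an AirflowException.
from datetime import datetime

from airflow import DAG
from ray_provider.operators.ray import SubmitRayJob

with DAG(dag_id="ray_job_cleanup_demo", start_date=datetime(2024, 9, 1), schedule=None):
    submit = SubmitRayJob(
        task_id="submit_ray_job",
        conn_id="ray_conn",                        # placeholder Airflow connection
        entrypoint="python script.py",             # placeholder entrypoint
        runtime_env={"pip": ["requests"]},         # placeholder runtime env
        ray_cluster_yaml="cluster.yaml",           # cluster created by _setup_cluster()
        gpu_device_plugin_yaml="gpu_plugin.yaml",
        wait_for_completion=True,
        poll_interval=30,
        fetch_logs=True,
        job_timeout_seconds=600,
    )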
30 changes: 30 additions & 0 deletions ray_provider/triggers/ray.py
@@ -30,13 +30,17 @@ def __init__(
job_id: str,
conn_id: str,
xcom_dashboard_url: str | None,
ray_cluster_yaml: str | None,
gpu_device_plugin_yaml: str,
poll_interval: int = 30,
fetch_logs: bool = True,
):
super().__init__() # type: ignore[no-untyped-call]
self.job_id = job_id
self.conn_id = conn_id
self.dashboard_url = xcom_dashboard_url
self.ray_cluster_yaml = ray_cluster_yaml
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.fetch_logs = fetch_logs
self.poll_interval = poll_interval

@@ -52,6 +56,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"job_id": self.job_id,
"conn_id": self.conn_id,
"xcom_dashboard_url": self.dashboard_url,
"ray_cluster_yaml": self.ray_cluster_yaml,
"gpu_device_plugin_yaml": self.gpu_device_plugin_yaml,
"fetch_logs": self.fetch_logs,
"poll_interval": self.poll_interval,
},
@@ -66,6 +72,28 @@ def hook(self) -> RayHook:
"""
return RayHook(conn_id=self.conn_id)

async def cleanup(self) -> None:
"""
Cleanup method to ensure resources are properly deleted. This will be called when the trigger encounters an exception.
Example scenario: A job is submitted using the @ray.task decorator with a Ray specification. After the cluster is started
and the job is submitted, the trigger begins tracking its progress. However, if the job is stopped through the UI at this stage, the cluster
resources are not deleted.
"""
try:
if self.ray_cluster_yaml:
self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}")
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml
)
self.log.info("Ray cluster deletion process completed")
else:
self.log.info("No Ray cluster YAML provided, skipping cluster deletion")
except Exception as e:
self.log.error(f"Unexpected error during cleanup: {str(e)}")

async def _poll_status(self) -> None:
while not self._is_terminal_state():
await asyncio.sleep(self.poll_interval)
@@ -109,6 +137,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
}
)
except Exception as e:
self.log.error(f"Error occurred: {str(e)}")
await self.cleanup()
yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id})

def _is_terminal_state(self) -> bool:
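One reason serialize() now carries ray_cluster_yaml and gpu_device_plugin_yaml is that the Airflow triggerer rebuilds the trigger from its serialized kwargs, so cleanup() needs those values to survive the round trip. A minimal sketch of that round trip, with placeholder values (not part of this commit):

# Hypothetical round-trip sketch: the serialized kwargs must be enough to
# reconstruct the trigger, including the new cluster-cleanup fields.
from ray_provider.triggers.ray import RayJobTrigger

trigger = RayJobTrigger(
    job_id="raysubmit_123",                    # placeholder job id
    conn_id="ray_conn",                        # placeholder connection
    xcom_dashboard_url="http://dashboard.url",
    ray_cluster_yaml="cluster.yaml",
    gpu_device_plugin_yaml="gpu_plugin.yaml",
    poll_interval=30,
    fetch_logs=True,
)

classpath, kwargs = trigger.serialize()
rehydrated = RayJobTrigger(**kwargs)           # roughly what the triggerer does
assert rehydrated.ray_cluster_yaml == "cluster.yaml"
assert rehydrated.gpu_device_plugin_yaml == "gpu_plugin.yaml"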
182 changes: 181 additions & 1 deletion tests/operators/test_ray_operators.py
@@ -1,10 +1,11 @@
from unittest.mock import MagicMock, Mock, patch

import pytest
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, TaskDeferred
from ray.job_submission import JobStatus

from ray_provider.operators.ray import DeleteRayCluster, SetupRayCluster, SubmitRayJob
from ray_provider.triggers.ray import RayJobTrigger


class TestSetupRayCluster:
@@ -388,3 +389,182 @@ def test_delete_cluster_exception(self, mock_ray_hook):

assert str(exc_info.value) == "Cluster deletion failed"
mock_hook.delete_ray_cluster.assert_called_once()

@pytest.mark.parametrize(
"xcom_task_key, expected_task, expected_key",
[
("task.key", "task", "key"),
("single_key", None, "single_key"),
],
)
def test_get_dashboard_url_xcom_variants(self, operator, context, xcom_task_key, expected_task, expected_key):
operator.xcom_task_key = xcom_task_key
context["ti"].xcom_pull.return_value = "http://dashboard.url"

result = operator._get_dashboard_url(context)

assert result == "http://dashboard.url"
if expected_task:
context["ti"].xcom_pull.assert_called_once_with(task_ids=expected_task, key=expected_key)
else:
context["ti"].xcom_pull.assert_called_once_with(task_ids=context["task"].task_id, key=expected_key)

def test_execute_job_unexpected_state(self, mock_hook, context):
operator = SubmitRayJob(
task_id="test_task",
conn_id="test_conn",
entrypoint="python script.py",
runtime_env={},
wait_for_completion=True,
)
mock_hook.submit_ray_job.return_value = "test_job_id"
mock_hook.get_ray_job_status.return_value = "UNEXPECTED_STATE"

with patch.object(operator, "_setup_cluster"), pytest.raises(TaskDeferred) as exc_info:
operator.execute(context)

assert isinstance(exc_info.value.trigger, RayJobTrigger)

@pytest.mark.parametrize("dashboard_url", [None, "http://dashboard.url"])
def test_execute_defer(self, mock_hook, context, dashboard_url):
operator = SubmitRayJob(
task_id="test_task",
conn_id="test_conn",
entrypoint="python script.py",
runtime_env={},
wait_for_completion=True,
ray_cluster_yaml="cluster.yaml",
gpu_device_plugin_yaml="gpu_plugin.yaml",
poll_interval=30,
fetch_logs=True,
job_timeout_seconds=600,
)
mock_hook.submit_ray_job.return_value = "test_job_id"
mock_hook.get_ray_job_status.return_value = JobStatus.PENDING

with patch.object(operator, "_setup_cluster"), patch.object(
operator, "_get_dashboard_url", return_value=dashboard_url
), pytest.raises(TaskDeferred) as exc_info:
operator.execute(context)

trigger = exc_info.value.trigger
assert isinstance(trigger, RayJobTrigger)
assert trigger.job_id == "test_job_id"
assert trigger.conn_id == "test_conn"
assert trigger.dashboard_url == dashboard_url
assert trigger.ray_cluster_yaml == "cluster.yaml"
assert trigger.gpu_device_plugin_yaml == "gpu_plugin.yaml"
assert trigger.poll_interval == 30
assert trigger.fetch_logs is True

def test_execute_complete_unexpected_status(self, operator):
event = {"status": "UNEXPECTED", "message": "Unexpected status"}
with patch.object(operator, "_delete_cluster"), pytest.raises(AirflowException) as exc_info:
operator.execute_complete({}, event)

assert "Unexpected event status" in str(exc_info.value)

def test_execute_complete_cleanup_on_exception(self, operator):
event = {"status": JobStatus.FAILED, "message": "Job failed"}
with patch.object(operator, "_delete_cluster") as mock_delete_cluster, pytest.raises(AirflowException):
operator.execute_complete({}, event)

mock_delete_cluster.assert_called_once()

def test_execute_exception_handling(self, mock_hook, context):
operator = SubmitRayJob(
task_id="test_task",
conn_id="test_conn",
entrypoint="python script.py",
runtime_env={},
ray_cluster_yaml="cluster.yaml",
)

mock_hook.submit_ray_job.side_effect = Exception("Job submission failed")

with patch.object(operator, "_setup_cluster"), patch.object(
operator, "_delete_cluster"
) as mock_delete_cluster, pytest.raises(AirflowException) as exc_info:
operator.execute(context)

assert "SubmitRayJob operator failed due to Job submission failed" in str(exc_info.value)
mock_delete_cluster.assert_called_once()

def test_execute_cluster_setup_exception(self, mock_hook, context):
operator = SubmitRayJob(
task_id="test_task",
conn_id="test_conn",
entrypoint="python script.py",
runtime_env={},
ray_cluster_yaml="cluster.yaml",
)

with patch.object(operator, "_setup_cluster", side_effect=Exception("Cluster setup failed")), patch.object(
operator, "_delete_cluster"
) as mock_delete_cluster, pytest.raises(AirflowException) as exc_info:
operator.execute(context)

assert "SubmitRayJob operator failed due to Cluster setup failed" in str(exc_info.value)
mock_delete_cluster.assert_called_once()

def test_execute_with_wait_and_defer(self, mock_hook, context):
operator = SubmitRayJob(
task_id="test_task",
conn_id="test_conn",
entrypoint="python script.py",
runtime_env={},
wait_for_completion=True,
poll_interval=30,
fetch_logs=True,
job_timeout_seconds=600,
)

mock_hook.submit_ray_job.return_value = "test_job_id"
mock_hook.get_ray_job_status.return_value = JobStatus.PENDING

with patch.object(operator, "_setup_cluster"), patch.object(operator, "defer") as mock_defer:
operator.execute(context)

mock_defer.assert_called_once()
args, kwargs = mock_defer.call_args
assert isinstance(kwargs["trigger"], RayJobTrigger)
assert kwargs["method_name"] == "execute_complete"
assert kwargs["timeout"].total_seconds() == 600

def test_execute_complete_with_cleanup(self, operator):
operator.job_id = "test_job_id"
event = {"status": JobStatus.FAILED, "message": "Job failed"}

with patch.object(operator, "_delete_cluster") as mock_delete_cluster, pytest.raises(AirflowException):
operator.execute_complete({}, event)

mock_delete_cluster.assert_called_once()

def test_execute_without_wait_no_cleanup(self, mock_hook, context):
operator = SubmitRayJob(
task_id="test_task",
conn_id="test_conn",
entrypoint="python script.py",
runtime_env={},
wait_for_completion=False,
)

mock_hook.submit_ray_job.return_value = "test_job_id"

with patch.object(operator, "_setup_cluster") as mock_setup_cluster, patch.object(
operator, "_delete_cluster"
) as mock_delete_cluster:
result = operator.execute(context)

mock_setup_cluster.assert_called_once_with(context=context)
assert result == "test_job_id"
mock_hook.submit_ray_job.assert_called_once_with(
dashboard_url=None,
entrypoint="python script.py",
runtime_env={},
entrypoint_num_cpus=0,
entrypoint_num_gpus=0,
entrypoint_memory=0,
entrypoint_resources=None,
)
mock_delete_cluster.assert_not_called()