Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop catching generic Exception in triggerer #98

Merged
merged 10 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ build-whl: setup-dev ## Build installable whl file
# Delete any previous wheels, so different versions don't conflict
rm dev/include/*
cd dev
# delete potential previous versions, otherwise there will be a conflict
# during installation
rm include/*
python3 -m build --outdir dev/include/

.PHONY: docker-run
Expand Down
3 changes: 3 additions & 0 deletions ray_provider/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ray.job_submission import JobStatus

# Ray job statuses treated as terminal (SUCCEEDED, STOPPED, FAILED); status
# pollers compare against this set to decide when to stop waiting.
# A frozenset is used so this shared module-level constant cannot be mutated
# by accident; membership tests behave exactly as with a plain set.
TERMINAL_JOB_STATUSES = frozenset({JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED})
77 changes: 38 additions & 39 deletions ray_provider/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
from kubernetes.client.exceptions import ApiException
from ray.job_submission import JobStatus

from ray_provider.constants import TERMINAL_JOB_STATUSES
from ray_provider.hooks import RayHook


Expand Down Expand Up @@ -43,6 +45,7 @@ def __init__(
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.fetch_logs = fetch_logs
self.poll_interval = poll_interval
self._job_status: None | JobStatus = None

def serialize(self) -> tuple[str, dict[str, Any]]:
"""
Expand Down Expand Up @@ -81,22 +84,22 @@ async def cleanup(self) -> None:
resources are not deleted.

"""
try:
if self.ray_cluster_yaml:
self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}")
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml
)
self.log.info("Ray cluster deletion process completed")
else:
self.log.info("No Ray cluster YAML provided, skipping cluster deletion")
except Exception as e:
self.log.error(f"Unexpected error during cleanup: {str(e)}")
if self.ray_cluster_yaml:
self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}")
loop = asyncio.get_running_loop()
await loop.run_in_executor(
None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml
)
self.log.info("Ray cluster deletion process completed")
else:
self.log.info("No Ray cluster YAML provided, skipping cluster deletion")

async def _poll_status(self) -> None:
while not self._is_terminal_state():
self._job_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
while self._job_status not in TERMINAL_JOB_STATUSES:
self.log.info(f"Status of job {self.job_id} is: {self._job_status}")
await asyncio.sleep(self.poll_interval)
self._job_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)

async def _stream_logs(self) -> None:
"""
Expand All @@ -111,46 +114,42 @@ async def _stream_logs(self) -> None:

async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Asynchronously polls the job status and yields events based on the job's state.
Asynchronously polls the Ray job status and yields events based on the job's state.

This method gets job status at each poll interval and streams logs if available.
It yields a TriggerEvent upon job completion, cancellation, or failure.

:yield: TriggerEvent containing the status, message, and job ID related to the job.
"""
try:
self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...")
self.log.info(f"::group:: Trigger 1/2: Checking the job status")
self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...")

try:
tasks = [self._poll_status()]
if self.fetch_logs:
tasks.append(self._stream_logs())

await asyncio.gather(*tasks)
except ApiException as e:
error_msg = str(e)
self.log.info(f"::endgroup::")
self.log.error("::group:: Trigger unable to poll job status")
self.log.error("Exception details:", exc_info=True)
self.log.info("Attempting to clean up...")
await self.cleanup()
tatiana marked this conversation as resolved.
Show resolved Hide resolved
self.log.info("Cleanup completed!")
self.log.info(f"::endgroup::")

yield TriggerEvent({"status": "EXCEPTION", "message": error_msg, "job_id": self.job_id})
else:
self.log.info(f"::endgroup::")
self.log.info(f"::group:: Trigger 2/2: Job reached a terminal state")
self.log.info(f"Status of completed job {self.job_id} is: {self._job_status}")
self.log.info(f"::endgroup::")

completed_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Status of completed job {self.job_id} is: {completed_status}")
yield TriggerEvent(
{
"status": completed_status,
"message": f"Job {self.job_id} completed with status {completed_status}",
"status": self._job_status,
"message": f"Job {self.job_id} completed with status {self._job_status}",
"job_id": self.job_id,
}
)
except Exception as e:
self.log.error(f"Error occurred: {str(e)}")
await self.cleanup()
yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id})

def _is_terminal_state(self) -> bool:
    """
    Report whether the Ray job has reached a terminal state.

    The job's current status is fetched through the hook on every call and
    compared against the terminal statuses: SUCCEEDED, STOPPED and FAILED.

    :return: True if the fetched status is terminal, False otherwise.
    """
    # Fetch first, then compare, so the hook is queried exactly once per call.
    current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
    terminal_statuses = (JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED)
    return current_status in terminal_statuses
91 changes: 36 additions & 55 deletions tests/test_triggers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import logging
from unittest.mock import AsyncMock, call, patch

import pytest
from airflow.triggers.base import TriggerEvent
from kubernetes.client.exceptions import ApiException
tatiana marked this conversation as resolved.
Show resolved Hide resolved
from ray.job_submission import JobStatus

from ray_provider.triggers import RayJobTrigger
Expand All @@ -22,11 +25,9 @@ def trigger(self):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", return_value=JobStatus.FAILED)
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
mock_is_terminal.return_value = True
mock_hook.get_ray_job_status.return_value = JobStatus.FAILED
async def test_run_no_job_id(self, mock_hook, mock_job_status):
trigger = RayJobTrigger(
job_id="",
poll_interval=1,
Expand All @@ -42,11 +43,12 @@ async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch(
"ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status",
side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED],
)
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_run_job_succeeded(self, mock_hook, mock_is_terminal):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED
async def test_run_job_succeeded(self, mock_hook, mock_job_status):
trigger = RayJobTrigger(
job_id="test_job_id",
poll_interval=1,
Expand All @@ -66,12 +68,12 @@ async def test_run_job_succeeded(self, mock_hook, mock_is_terminal):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch(
"ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status",
side_effect=[JobStatus.RUNNING, JobStatus.STOPPED],
)
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.STOPPED

async def test_run_job_stopped(self, mock_hook, mock_job_status, trigger):
generator = trigger.run()
event = await generator.asend(None)

Expand All @@ -84,12 +86,11 @@ async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch(
"ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.FAILED]
)
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.FAILED

async def test_run_job_failed(self, mock_hook, mock_job_status, trigger):
generator = trigger.run()
event = await generator.asend(None)

Expand All @@ -102,12 +103,13 @@ async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger):
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch(
"ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status",
side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED],
)
@patch("ray_provider.triggers.RayJobTrigger.hook")
@patch("ray_provider.triggers.RayJobTrigger._stream_logs")
async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is_terminal, trigger):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED
async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_job_status, trigger):
mock_stream_logs.return_value = None

generator = trigger.run()
Expand Down Expand Up @@ -156,19 +158,6 @@ def test_serialize(self, trigger):
},
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_is_terminal_state(self, mock_hook, trigger):
    """PENDING and RUNNING must not count as terminal; SUCCEEDED must."""
    # The mocked hook hands out one status per _is_terminal_state() call, in order.
    mock_hook.get_ray_job_status.side_effect = [JobStatus.PENDING, JobStatus.RUNNING, JobStatus.SUCCEEDED]

    # Expected verdicts, aligned one-to-one with the statuses above.
    for expect_terminal in (False, False, True):
        assert bool(trigger._is_terminal_state()) is expect_terminal

@pytest.mark.asyncio
@patch.object(RayJobTrigger, "hook")
@patch.object(logging.Logger, "info")
Expand Down Expand Up @@ -200,41 +189,33 @@ async def test_cleanup_without_cluster_yaml(self, mock_log_info):

mock_log_info.assert_called_once_with("No Ray cluster YAML provided, skipping cluster deletion")

@pytest.mark.asyncio
@patch.object(RayJobTrigger, "hook")
@patch.object(logging.Logger, "error")
async def test_cleanup_with_exception(self, mock_log_error, mock_hook, trigger):
    """cleanup() should log a failing cluster deletion once instead of raising it."""
    # Make the mocked hook fail when cleanup() tries to delete the cluster.
    failure = Exception("Test exception")
    mock_hook.delete_ray_cluster.side_effect = failure

    await trigger.cleanup()

    expected_message = "Unexpected error during cleanup: Test exception"
    mock_log_error.assert_called_once_with(expected_message)

@pytest.mark.asyncio
@patch("asyncio.sleep", new_callable=AsyncMock)
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger):
mock_is_terminal.side_effect = [False, False, True]

@patch(
"ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status",
side_effect=[JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.SUCCEEDED],
)
@patch("ray_provider.triggers.RayJobTrigger.hook")
async def test_poll_status(self, mock_hook, mock_job_status, mock_sleep, trigger):
await trigger._poll_status()

assert mock_sleep.call_count == 2
mock_sleep.assert_called_with(1)

@pytest.mark.asyncio
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state")
@patch(
"ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=ApiException("Failed to get job.")
)
@patch("ray_provider.triggers.RayJobTrigger.hook")
@patch("ray_provider.triggers.RayJobTrigger.cleanup")
async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_is_terminal, trigger):
mock_is_terminal.side_effect = Exception("Test exception")

async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_job_status, trigger):
generator = trigger.run()
event = await generator.asend(None)

assert event == TriggerEvent(
{
"status": str(JobStatus.FAILED),
"message": "Test exception",
"status": "EXCEPTION",
"message": "(Failed to get job.)\nReason: None\n",
"job_id": "test_job_id",
}
)
Expand Down