diff --git a/Makefile b/Makefile index 9a1c12b..c0434cf 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,9 @@ build-whl: setup-dev ## Build installable whl file # Delete any previous wheels, so different versions don't conflict rm dev/include/* cd dev + # delete potential previous versions, otherwise there will be a conflict + # during installation + rm include/* python3 -m build --outdir dev/include/ .PHONY: docker-run diff --git a/ray_provider/constants.py b/ray_provider/constants.py new file mode 100644 index 0000000..9e37f6a --- /dev/null +++ b/ray_provider/constants.py @@ -0,0 +1,3 @@ +from ray.job_submission import JobStatus + +TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} diff --git a/ray_provider/triggers.py b/ray_provider/triggers.py index 3252c99..f6d9358 100644 --- a/ray_provider/triggers.py +++ b/ray_provider/triggers.py @@ -5,8 +5,10 @@ from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent +from kubernetes.client.exceptions import ApiException from ray.job_submission import JobStatus +from ray_provider.constants import TERMINAL_JOB_STATUSES from ray_provider.hooks import RayHook @@ -43,6 +45,7 @@ def __init__( self.gpu_device_plugin_yaml = gpu_device_plugin_yaml self.fetch_logs = fetch_logs self.poll_interval = poll_interval + self._job_status: None | JobStatus = None def serialize(self) -> tuple[str, dict[str, Any]]: """ @@ -81,22 +84,22 @@ async def cleanup(self) -> None: resources are not deleted. """ - try: - if self.ray_cluster_yaml: - self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}") - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml - ) - self.log.info("Ray cluster deletion process completed") - else: - self.log.info("No Ray cluster YAML provided, skipping cluster deletion") - except Exception as e: - self.log.error(f"Unexpected error during cleanup: {str(e)}") + if self.ray_cluster_yaml: + self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}") + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml + ) + self.log.info("Ray cluster deletion process completed") + else: + self.log.info("No Ray cluster YAML provided, skipping cluster deletion") async def _poll_status(self) -> None: - while not self._is_terminal_state(): + self._job_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) + while self._job_status not in TERMINAL_JOB_STATUSES: + self.log.info(f"Status of job {self.job_id} is: {self._job_status}") await asyncio.sleep(self.poll_interval) + self._job_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) async def _stream_logs(self) -> None: """ @@ -111,46 +114,42 @@ async def _stream_logs(self) -> None: async def run(self) -> AsyncIterator[TriggerEvent]: """ - Asynchronously polls the job status and yields events based on the job's state. + Asynchronously polls the Ray job status and yields events based on the job's state. This method gets job status at each poll interval and streams logs if available. It yields a TriggerEvent upon job completion, cancellation, or failure. :yield: TriggerEvent containing the status, message, and job ID related to the job. """ - try: - self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...") + self.log.info(f"::group:: Trigger 1/2: Checking the job status") + self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...") + try: tasks = [self._poll_status()] if self.fetch_logs: tasks.append(self._stream_logs()) - await asyncio.gather(*tasks) + except ApiException as e: + error_msg = str(e) + self.log.info(f"::endgroup::") + self.log.error("::group:: Trigger unable to poll job status") + self.log.error("Exception details:", exc_info=True) + self.log.info("Attempting to clean up...") + await self.cleanup() + self.log.info("Cleanup completed!") + self.log.info(f"::endgroup::") + + yield TriggerEvent({"status": "EXCEPTION", "message": error_msg, "job_id": self.job_id}) + else: + self.log.info(f"::endgroup::") + self.log.info(f"::group:: Trigger 2/2: Job reached a terminal state") + self.log.info(f"Status of completed job {self.job_id} is: {self._job_status}") + self.log.info(f"::endgroup::") - completed_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) - self.log.info(f"Status of completed job {self.job_id} is: {completed_status}") yield TriggerEvent( { - "status": completed_status, - "message": f"Job {self.job_id} completed with status {completed_status}", + "status": self._job_status, + "message": f"Job {self.job_id} completed with status {self._job_status}", "job_id": self.job_id, } ) - except Exception as e: - self.log.error(f"Error occurred: {str(e)}") - await self.cleanup() - yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id}) - - def _is_terminal_state(self) -> bool: - """ - Checks if the Ray job is in a terminal state. - - A terminal state is one of the following: SUCCEEDED, STOPPED, or FAILED. - - :return: True if the job is in a terminal state, False otherwise. - """ - return self.hook.get_ray_job_status(self.dashboard_url, self.job_id) in ( - JobStatus.SUCCEEDED, - JobStatus.STOPPED, - JobStatus.FAILED, - ) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index f97611e..fff3fee 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import logging from unittest.mock import AsyncMock, call, patch import pytest from airflow.triggers.base import TriggerEvent +from kubernetes.client.exceptions import ApiException from ray.job_submission import JobStatus from ray_provider.triggers import RayJobTrigger @@ -22,11 +25,9 @@ def trigger(self): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", return_value=JobStatus.FAILED) @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_no_job_id(self, mock_hook, mock_is_terminal): - mock_is_terminal.return_value = True - mock_hook.get_ray_job_status.return_value = JobStatus.FAILED + async def test_run_no_job_id(self, mock_hook, mock_job_status): trigger = RayJobTrigger( job_id="", poll_interval=1, @@ -42,11 +43,12 @@ async def test_run_no_job_id(self, mock_hook, mock_is_terminal): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", + side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED], + ) @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED + async def test_run_job_succeeded(self, mock_hook, mock_job_status): trigger = RayJobTrigger( job_id="test_job_id", poll_interval=1, @@ -66,12 +68,12 @@ async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", + side_effect=[JobStatus.RUNNING, JobStatus.STOPPED], + ) @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.STOPPED - + async def test_run_job_stopped(self, mock_hook, mock_job_status, trigger): generator = trigger.run() event = await generator.asend(None) @@ -84,12 +86,11 @@ async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.FAILED] + ) @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.FAILED - + async def test_run_job_failed(self, mock_hook, mock_job_status, trigger): generator = trigger.run() event = await generator.asend(None) @@ -102,12 +103,13 @@ async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", + side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED], + ) @patch("ray_provider.triggers.RayJobTrigger.hook") @patch("ray_provider.triggers.RayJobTrigger._stream_logs") - async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED + async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_job_status, trigger): mock_stream_logs.return_value = None generator = trigger.run() @@ -156,19 +158,6 @@ def test_serialize(self, trigger): }, ) - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_is_terminal_state(self, mock_hook, trigger): - mock_hook.get_ray_job_status.side_effect = [ - JobStatus.PENDING, - JobStatus.RUNNING, - JobStatus.SUCCEEDED, - ] - - assert not trigger._is_terminal_state() - assert not trigger._is_terminal_state() - assert trigger._is_terminal_state() - @pytest.mark.asyncio @patch.object(RayJobTrigger, "hook") @patch.object(logging.Logger, "info") @@ -200,41 +189,33 @@ async def test_cleanup_without_cluster_yaml(self, mock_log_info): mock_log_info.assert_called_once_with("No Ray cluster YAML provided, skipping cluster deletion") - @pytest.mark.asyncio - @patch.object(RayJobTrigger, "hook") - @patch.object(logging.Logger, "error") - async def test_cleanup_with_exception(self, mock_log_error, mock_hook, trigger): - mock_hook.delete_ray_cluster.side_effect = Exception("Test exception") - - await trigger.cleanup() - - mock_log_error.assert_called_once_with("Unexpected error during cleanup: Test exception") - @pytest.mark.asyncio @patch("asyncio.sleep", new_callable=AsyncMock) - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger): - mock_is_terminal.side_effect = [False, False, True] - + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", + side_effect=[JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.SUCCEEDED], + ) + @patch("ray_provider.triggers.RayJobTrigger.hook") + async def test_poll_status(self, mock_hook, mock_job_status, mock_sleep, trigger): await trigger._poll_status() assert mock_sleep.call_count == 2 mock_sleep.assert_called_with(1) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=ApiException("Failed to get job.") + ) @patch("ray_provider.triggers.RayJobTrigger.hook") @patch("ray_provider.triggers.RayJobTrigger.cleanup") - async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = Exception("Test exception") - + async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_job_status, trigger): generator = trigger.run() event = await generator.asend(None) assert event == TriggerEvent( { - "status": str(JobStatus.FAILED), - "message": "Test exception", + "status": "EXCEPTION", + "message": "(Failed to get job.)\nReason: None\n", "job_id": "test_job_id", } )