From 555322d4c0ef3423bf286fabfff4bfe4bbd51b72 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 09:21:10 +0000 Subject: [PATCH 01/10] Add TERMINAL_JOB_STATUSES --- ray_provider/constants.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 ray_provider/constants.py diff --git a/ray_provider/constants.py b/ray_provider/constants.py new file mode 100644 index 0000000..d5d3703 --- /dev/null +++ b/ray_provider/constants.py @@ -0,0 +1,3 @@ +from ray.job_submission import JobStatus + +TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} \ No newline at end of file From 9aad4de4f89a57c81ff91beaa942f9a00de22090 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 09:33:35 +0000 Subject: [PATCH 02/10] Stop catching generic `Exception` in triggerer By catching `Exception`, we run into the risk of hitting an unexpected exception that the program can't recover from, or worse, swallowing an important exception without properly logging it - a huge headache when trying to debug programs that are failing in weird ways. This was identified during #81 development. 
--- ray_provider/triggers.py | 77 ++++++++++++++++++++-------------------- tests/test_triggers.py | 75 +++++++++++--------------------------- 2 files changed, 58 insertions(+), 94 deletions(-) diff --git a/ray_provider/triggers.py b/ray_provider/triggers.py index 3252c99..151333a 100644 --- a/ray_provider/triggers.py +++ b/ray_provider/triggers.py @@ -5,8 +5,10 @@ from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent +from kubernetes.client.exceptions import ApiException from ray.job_submission import JobStatus +from ray_provider.constants import TERMINAL_JOB_STATUSES from ray_provider.hooks import RayHook @@ -43,6 +45,7 @@ def __init__( self.gpu_device_plugin_yaml = gpu_device_plugin_yaml self.fetch_logs = fetch_logs self.poll_interval = poll_interval + self._job_status = None | JobStatus def serialize(self) -> tuple[str, dict[str, Any]]: """ @@ -81,22 +84,22 @@ async def cleanup(self) -> None: resources are not deleted. """ - try: - if self.ray_cluster_yaml: - self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}") - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml - ) - self.log.info("Ray cluster deletion process completed") - else: - self.log.info("No Ray cluster YAML provided, skipping cluster deletion") - except Exception as e: - self.log.error(f"Unexpected error during cleanup: {str(e)}") + if self.ray_cluster_yaml: + self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}") + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml + ) + self.log.info("Ray cluster deletion process completed") + else: + self.log.info("No Ray cluster YAML provided, skipping cluster deletion") async def _poll_status(self) -> None: - while not self._is_terminal_state(): + 
self._job_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) + while self._job_status not in TERMINAL_JOB_STATUSES: + self.log.info(f"Status of job {self.job_id} is: {self._job_status}") await asyncio.sleep(self.poll_interval) + self._job_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) async def _stream_logs(self) -> None: """ @@ -111,46 +114,42 @@ async def _stream_logs(self) -> None: async def run(self) -> AsyncIterator[TriggerEvent]: """ - Asynchronously polls the job status and yields events based on the job's state. + Asynchronously polls the Ray job status and yields events based on the job's state. This method gets job status at each poll interval and streams logs if available. It yields a TriggerEvent upon job completion, cancellation, or failure. :yield: TriggerEvent containing the status, message, and job ID related to the job. """ - try: - self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...") + self.log.info(f"::group:: Trigger 1/2: Checking the job status") + self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...") + try: tasks = [self._poll_status()] if self.fetch_logs: tasks.append(self._stream_logs()) - await asyncio.gather(*tasks) + except ApiException as e: + error_msg = str(e) + self.log.info(f"::endgroup::") + self.log.error("::group:: Trigger unable to poll job status") + self.log.error("Exception details:", exc_info=True) + self.log.info("Attempting to clean up...") + await self.cleanup() + self.log.info("Cleanup completed!") + self.log.info(f"::endgroup::") + + yield TriggerEvent({"status": "EXCEPTION", "message": error_msg, "job_id": self.job_id}) + else: + self.log.info(f"::endgroup::") + self.log.info(f"::group:: Trigger 2/2: Job reached a terminal state") + self.log.info(f"Status of completed job {self.job_id} is: {self._job_status}") + self.log.info(f"::endgroup::") - completed_status = self.hook.get_ray_job_status(self.dashboard_url, 
self.job_id) - self.log.info(f"Status of completed job {self.job_id} is: {completed_status}") yield TriggerEvent( { - "status": completed_status, - "message": f"Job {self.job_id} completed with status {completed_status}", + "status": self._job_status, + "message": f"Job {self.job_id} completed with status {self._job_status}", "job_id": self.job_id, } ) - except Exception as e: - self.log.error(f"Error occurred: {str(e)}") - await self.cleanup() - yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id}) - - def _is_terminal_state(self) -> bool: - """ - Checks if the Ray job is in a terminal state. - - A terminal state is one of the following: SUCCEEDED, STOPPED, or FAILED. - - :return: True if the job is in a terminal state, False otherwise. - """ - return self.hook.get_ray_job_status(self.dashboard_url, self.job_id) in ( - JobStatus.SUCCEEDED, - JobStatus.STOPPED, - JobStatus.FAILED, - ) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index f97611e..035a12f 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -3,6 +3,7 @@ import pytest from airflow.triggers.base import TriggerEvent +from kubernetes.client.exceptions import ApiException from ray.job_submission import JobStatus from ray_provider.triggers import RayJobTrigger @@ -22,11 +23,9 @@ def trigger(self): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", return_value=JobStatus.FAILED) @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_no_job_id(self, mock_hook, mock_is_terminal): - mock_is_terminal.return_value = True - mock_hook.get_ray_job_status.return_value = JobStatus.FAILED + async def test_run_no_job_id(self, mock_hook, mock_job_status): trigger = RayJobTrigger( job_id="", poll_interval=1, @@ -42,11 +41,9 @@ async def test_run_no_job_id(self, mock_hook, mock_is_terminal): ) @pytest.mark.asyncio - 
@patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.SUCCEEDED]) @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED + async def test_run_job_succeeded(self, mock_hook, mock_job_status): trigger = RayJobTrigger( job_id="test_job_id", poll_interval=1, @@ -66,12 +63,9 @@ async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.STOPPED]) @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.STOPPED - + async def test_run_job_stopped(self, mock_hook, mock_job_status, trigger): generator = trigger.run() event = await generator.asend(None) @@ -84,12 +78,9 @@ async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.FAILED]) @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.FAILED - + async def test_run_job_failed(self, mock_hook, mock_job_status, trigger): generator = trigger.run() event = await generator.asend(None) @@ -102,12 +93,10 @@ async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): ) 
@pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.SUCCEEDED]) @patch("ray_provider.triggers.RayJobTrigger.hook") @patch("ray_provider.triggers.RayJobTrigger._stream_logs") - async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED + async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_job_status, trigger): mock_stream_logs.return_value = None generator = trigger.run() @@ -156,19 +145,6 @@ def test_serialize(self, trigger): }, ) - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_is_terminal_state(self, mock_hook, trigger): - mock_hook.get_ray_job_status.side_effect = [ - JobStatus.PENDING, - JobStatus.RUNNING, - JobStatus.SUCCEEDED, - ] - - assert not trigger._is_terminal_state() - assert not trigger._is_terminal_state() - assert trigger._is_terminal_state() - @pytest.mark.asyncio @patch.object(RayJobTrigger, "hook") @patch.object(logging.Logger, "info") @@ -200,41 +176,30 @@ async def test_cleanup_without_cluster_yaml(self, mock_log_info): mock_log_info.assert_called_once_with("No Ray cluster YAML provided, skipping cluster deletion") - @pytest.mark.asyncio - @patch.object(RayJobTrigger, "hook") - @patch.object(logging.Logger, "error") - async def test_cleanup_with_exception(self, mock_log_error, mock_hook, trigger): - mock_hook.delete_ray_cluster.side_effect = Exception("Test exception") - - await trigger.cleanup() - - mock_log_error.assert_called_once_with("Unexpected error during cleanup: Test exception") - @pytest.mark.asyncio @patch("asyncio.sleep", new_callable=AsyncMock) - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - async def test_poll_status(self, mock_is_terminal, mock_sleep, 
trigger): - mock_is_terminal.side_effect = [False, False, True] - + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, None, JobStatus.SUCCEEDED]) + @patch("ray_provider.triggers.RayJobTrigger.hook") + async def test_poll_status(self, mock_hook, mock_job_status, mock_sleep, trigger): await trigger._poll_status() assert mock_sleep.call_count == 2 mock_sleep.assert_called_with(1) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=ApiException("Failed to get job.") + ) @patch("ray_provider.triggers.RayJobTrigger.hook") @patch("ray_provider.triggers.RayJobTrigger.cleanup") - async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = Exception("Test exception") - + async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_job_status, trigger): generator = trigger.run() event = await generator.asend(None) assert event == TriggerEvent( { - "status": str(JobStatus.FAILED), - "message": "Test exception", + "status": "EXCEPTION", + "message": "(Failed to get job.)\nReason: None\n", "job_id": "test_job_id", } ) From 9eea5c6c14c6b0093c34d0d419b83165e5416886 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 11:09:29 +0000 Subject: [PATCH 03/10] Update tests/test_triggers.py --- tests/test_triggers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 035a12f..32a3185 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -3,6 +3,7 @@ import pytest from airflow.triggers.base import TriggerEvent +from __future__ import annotations from kubernetes.client.exceptions import ApiException from ray.job_submission import JobStatus From f0169f888ce10dcb42618ebd057c9737ce0ef1a7 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 11:09:41 +0000 
Subject: [PATCH 04/10] Update tests/test_triggers.py --- tests/test_triggers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 32a3185..992bcd0 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -42,7 +42,7 @@ async def test_run_no_job_id(self, mock_hook, mock_job_status): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.SUCCEEDED]) + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED]) @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_succeeded(self, mock_hook, mock_job_status): trigger = RayJobTrigger( From 8239cc6c3de33714e4651f3fc6cac5ec1e6aab08 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 11:10:03 +0000 Subject: [PATCH 05/10] Apply suggestions from code review --- tests/test_triggers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 992bcd0..6008a4c 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -64,7 +64,7 @@ async def test_run_job_succeeded(self, mock_hook, mock_job_status): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.STOPPED]) + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.STOPPED]) @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_stopped(self, mock_hook, mock_job_status, trigger): generator = trigger.run() @@ -79,7 +79,7 @@ async def test_run_job_stopped(self, mock_hook, mock_job_status, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.FAILED]) + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", 
side_effect=[JobStatus.RUNNING, JobStatus.FAILED]) @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_failed(self, mock_hook, mock_job_status, trigger): generator = trigger.run() @@ -94,7 +94,7 @@ async def test_run_job_failed(self, mock_hook, mock_job_status, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, JobStatus.SUCCEEDED]) + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED]) @patch("ray_provider.triggers.RayJobTrigger.hook") @patch("ray_provider.triggers.RayJobTrigger._stream_logs") async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_job_status, trigger): @@ -179,7 +179,7 @@ async def test_cleanup_without_cluster_yaml(self, mock_log_info): @pytest.mark.asyncio @patch("asyncio.sleep", new_callable=AsyncMock) - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[None, None, JobStatus.SUCCEEDED]) + @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.SUCCEEDED]) @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_poll_status(self, mock_hook, mock_job_status, mock_sleep, trigger): await trigger._poll_status() From ccbca8c9647ab724ebfa4edf22d4f512ecc89fe5 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 11:12:38 +0000 Subject: [PATCH 06/10] Fix import position --- tests/test_triggers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 6008a4c..055d715 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -1,9 +1,9 @@ +from __future__ import annotations import logging from unittest.mock import AsyncMock, call, patch import pytest from airflow.triggers.base import TriggerEvent -from __future__ import annotations from kubernetes.client.exceptions 
import ApiException from ray.job_submission import JobStatus From 1cdae2a1a3bfeebf7c9212a53afb8bd3418e0caa Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 11:27:17 +0000 Subject: [PATCH 07/10] Fix issue in the CI --- ray_provider/triggers.py | 2 +- tests/test_triggers.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ray_provider/triggers.py b/ray_provider/triggers.py index 151333a..f6d9358 100644 --- a/ray_provider/triggers.py +++ b/ray_provider/triggers.py @@ -45,7 +45,7 @@ def __init__( self.gpu_device_plugin_yaml = gpu_device_plugin_yaml self.fetch_logs = fetch_logs self.poll_interval = poll_interval - self._job_status = None | JobStatus + self._job_status: None | JobStatus = None def serialize(self) -> tuple[str, dict[str, Any]]: """ diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 055d715..86cc383 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -1,4 +1,5 @@ from __future__ import annotations + import logging from unittest.mock import AsyncMock, call, patch From b236d81d49affd7c3f7c66a2c451341702b51081 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 11:29:07 +0000 Subject: [PATCH 08/10] Add blank line --- ray_provider/constants.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ray_provider/constants.py b/ray_provider/constants.py index d5d3703..3242e4a 100644 --- a/ray_provider/constants.py +++ b/ray_provider/constants.py @@ -1,3 +1,4 @@ from ray.job_submission import JobStatus -TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} \ No newline at end of file +TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} + From 2d7cdc2a789e5a11da168322afdd9a745150fac6 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 12:16:08 +0000 Subject: [PATCH 09/10] Fix static checks --- Makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Makefile b/Makefile index 
9a1c12b..c0434cf 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,9 @@ build-whl: setup-dev ## Build installable whl file # Delete any previous wheels, so different versions don't conflict rm dev/include/* cd dev + # delete potential previous versions, otherwise there will be a conflict + # during installation + rm include/* python3 -m build --outdir dev/include/ .PHONY: docker-run From 866cbc95f159a517e540a4abf2f863c4a396da30 Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 12:21:37 +0000 Subject: [PATCH 10/10] Fix static checks --- ray_provider/constants.py | 1 - tests/test_triggers.py | 24 +++++++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/ray_provider/constants.py b/ray_provider/constants.py index 3242e4a..9e37f6a 100644 --- a/ray_provider/constants.py +++ b/ray_provider/constants.py @@ -1,4 +1,3 @@ from ray.job_submission import JobStatus TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} - diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 86cc383..fff3fee 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -43,7 +43,10 @@ async def test_run_no_job_id(self, mock_hook, mock_job_status): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED]) + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", + side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED], + ) @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_succeeded(self, mock_hook, mock_job_status): trigger = RayJobTrigger( @@ -65,7 +68,10 @@ async def test_run_job_succeeded(self, mock_hook, mock_job_status): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.STOPPED]) + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", + 
side_effect=[JobStatus.RUNNING, JobStatus.STOPPED], + ) @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_stopped(self, mock_hook, mock_job_status, trigger): generator = trigger.run() @@ -80,7 +86,9 @@ async def test_run_job_stopped(self, mock_hook, mock_job_status, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.FAILED]) + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.FAILED] + ) @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_failed(self, mock_hook, mock_job_status, trigger): generator = trigger.run() @@ -95,7 +103,10 @@ async def test_run_job_failed(self, mock_hook, mock_job_status, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED]) + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", + side_effect=[JobStatus.RUNNING, JobStatus.SUCCEEDED], + ) @patch("ray_provider.triggers.RayJobTrigger.hook") @patch("ray_provider.triggers.RayJobTrigger._stream_logs") async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_job_status, trigger): @@ -180,7 +191,10 @@ async def test_cleanup_without_cluster_yaml(self, mock_log_info): @pytest.mark.asyncio @patch("asyncio.sleep", new_callable=AsyncMock) - @patch("ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", side_effect=[JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.SUCCEEDED]) + @patch( + "ray_provider.triggers.RayJobTrigger.hook.get_ray_job_status", + side_effect=[JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.SUCCEEDED], + ) @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_poll_status(self, mock_hook, mock_job_status, mock_sleep, trigger): await trigger._poll_status()