From fadce492c5c63baf8cb06190f6389521c2babaf0 Mon Sep 17 00:00:00 2001
From: Tatiana Al-Chueyr
Date: Fri, 29 Nov 2024 11:32:56 +0000
Subject: [PATCH] Stop catching generic Exception in operators (#100)

By catching `Exception`, we risk hitting an unexpected exception that the
program can't recover from, or worse, swallowing an important exception
without properly logging it - a massive headache when debugging programs
that fail in weird ways.

If this change raises any exceptions that should be caught, we'll have the
opportunity to understand which exceptions to capture and handle them
gracefully.

This was identified during the development of #81.
---
 ray_provider/exceptions.py |   2 +
 ray_provider/operators.py  | 163 +++++++++++---------
 tests/test_operators.py    | 304 +++++++++++++++++++------------------
 3 files changed, 244 insertions(+), 225 deletions(-)
 create mode 100644 ray_provider/exceptions.py

diff --git a/ray_provider/exceptions.py b/ray_provider/exceptions.py
new file mode 100644
index 0000000..5a7a101
--- /dev/null
+++ b/ray_provider/exceptions.py
@@ -0,0 +1,2 @@
+class RayAirflowException(Exception):
+    pass
diff --git a/ray_provider/operators.py b/ray_provider/operators.py
index b6a38d6..3a0776e 100644
--- a/ray_provider/operators.py
+++ b/ray_provider/operators.py
@@ -4,12 +4,14 @@
 from functools import cached_property
 from typing import Any
 
-from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
 from airflow.utils.context import Context
+from kubernetes.client.exceptions import ApiException
 from ray.job_submission import JobStatus
 
+from ray_provider.constants import TERMINAL_JOB_STATUSES
+from ray_provider.exceptions import RayAirflowException
 from ray_provider.hooks import RayHook
 from ray_provider.triggers import RayJobTrigger
 
@@ -41,7 +43,7 @@ def __init__(
         self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
         self.update_if_exists = update_if_exists
 
-    @cached_property
+    @property
     def hook(self) -> RayHook:
         """Lazily initialize and return the RayHook."""
         return RayHook(conn_id=self.conn_id)
 
@@ -52,6 +54,7 @@ def execute(self, context: Context) -> None:
 
         :param context: The context in which the operator is being executed.
         """
+        self.log.info(f"Trying to set up the Ray cluster defined in {self.ray_cluster_yaml}")
         self.hook.setup_ray_cluster(
             context=context,
             ray_cluster_yaml=self.ray_cluster_yaml,
@@ -59,6 +62,7 @@ def execute(self, context: Context) -> None:
             gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
             update_if_exists=self.update_if_exists,
         )
+        self.log.info("Finished setting up the Ray cluster.")
 
 
 class DeleteRayCluster(BaseOperator):
@@ -82,7 +86,7 @@ def __init__(
         self.ray_cluster_yaml = ray_cluster_yaml
         self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
 
-    @cached_property
+    @property
     def hook(self) -> PodOperatorHookProtocol:
         """Lazily initialize and return the RayHook."""
         return RayHook(conn_id=self.conn_id)
 
@@ -93,7 +97,9 @@ def execute(self, context: Context) -> None:
 
         :param context: The context in which the operator is being executed.
""" + self.log.info(f"Trying to delete the ray cluster defined in {self.ray_cluster_yaml}") self.hook.delete_ray_cluster(self.ray_cluster_yaml, self.gpu_device_plugin_yaml) + self.log.info("Finished deleting the ray cluster.") class SubmitRayJob(BaseOperator): @@ -173,7 +179,6 @@ def __init__( self.xcom_task_key = xcom_task_key self.dashboard_url: str | None = None self.job_id = "" - self.terminal_states = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} def on_kill(self) -> None: """ @@ -226,28 +231,36 @@ def _setup_cluster(self, context: Context) -> None: Set up the Ray cluster if a cluster YAML is provided. :param context: The context in which the task is being executed. - :raises Exception: If there's an error during cluster setup. """ if self.ray_cluster_yaml: - self.hook.setup_ray_cluster( - context=context, - ray_cluster_yaml=self.ray_cluster_yaml, - kuberay_version=self.kuberay_version, - gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, - update_if_exists=self.update_if_exists, - ) + try: + self.hook.setup_ray_cluster( + context=context, + ray_cluster_yaml=self.ray_cluster_yaml, + kuberay_version=self.kuberay_version, + gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, + update_if_exists=self.update_if_exists, + ) + except ApiException as e: + self.log.info(f"Unable to setup the Ray cluster using {self.ray_cluster_yaml}") + self.log.error("Exception details:", exc_info=True) + self.log.info("Trying to delete any parts of the RayCluster that may have been spun up...") + self._delete_cluster() + raise e + else: + self.log.info(f"Skipping setting up a Ray cluster because no `ray_cluster_yaml` was given.") def _delete_cluster(self) -> None: """ Delete the Ray cluster if a cluster YAML is provided. - - :raises Exception: If there's an error during cluster deletion. """ if self.ray_cluster_yaml: self.hook.delete_ray_cluster( ray_cluster_yaml=self.ray_cluster_yaml, gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, ) + else: + self.log.info(f"Skipping deleting the Ray cluster because no `ray_cluster_yaml` was given.") def execute(self, context: Context) -> str: """ @@ -258,58 +271,51 @@ def execute(self, context: Context) -> str: :param context: The context in which the task is being executed. :return: The job ID of the submitted Ray job. - :raises AirflowException: If the job fails, is cancelled, or reaches an unexpected state. 
""" - try: - self._setup_cluster(context=context) - - self.dashboard_url = self._get_dashboard_url(context) - - self.job_id = self.hook.submit_ray_job( - dashboard_url=self.dashboard_url, - entrypoint=self.entrypoint, - runtime_env=self.runtime_env, - entrypoint_num_cpus=self.num_cpus, - entrypoint_num_gpus=self.num_gpus, - entrypoint_memory=self.memory, - entrypoint_resources=self.ray_resources, - ) - self.log.info(f"Ray job submitted with id: {self.job_id}") - - if self.wait_for_completion: - current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) - self.log.info(f"Current job status for {self.job_id} is: {current_status}") - - if current_status not in self.terminal_states: - self.log.info("Deferring the polling to RayJobTrigger...") - self.defer( - trigger=RayJobTrigger( - job_id=self.job_id, - conn_id=self.conn_id, - xcom_dashboard_url=self.dashboard_url, - ray_cluster_yaml=self.ray_cluster_yaml, - gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, - poll_interval=self.poll_interval, - fetch_logs=self.fetch_logs, - ), - method_name="execute_complete", - timeout=self.job_timeout_seconds, - ) - elif current_status == JobStatus.SUCCEEDED: - self.log.info("Job %s completed successfully", self.job_id) - elif current_status == JobStatus.FAILED: - raise AirflowException(f"Job failed:\n{self.job_id}") - elif current_status == JobStatus.STOPPED: - raise AirflowException(f"Job was cancelled:\n{self.job_id}") - else: - raise AirflowException( - f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`" - ) - return self.job_id - except Exception as e: - self._delete_cluster() - raise AirflowException(f"SubmitRayJob operator failed due to {e}. Cleaning up resources...") + self.log.info("::group:: (SubmitJob 1/5) Setup Cluster") + self._setup_cluster(context=context) + self.log.info("::endgroup::") + + self.log.info("::group:: (SubmitJob 2/5) Identify Dashboard URL") + self.dashboard_url = self._get_dashboard_url(context) + self.log.info("::endgroup::") + + self.log.info("::group:: (SubmitJob 3/5) Submit job") + self.log.info(f"Ray job with id {self.job_id} submitted") + self.job_id = self.hook.submit_ray_job( + dashboard_url=self.dashboard_url, + entrypoint=self.entrypoint, + runtime_env=self.runtime_env, + entrypoint_num_cpus=self.num_cpus, + entrypoint_num_gpus=self.num_gpus, + entrypoint_memory=self.memory, + entrypoint_resources=self.ray_resources, + ) + self.log.info("::endgroup::") + + self.log.info("::group:: (SubmitJob 4/5) Wait for completion") + if self.wait_for_completion: + current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) + self.log.info(f"Current job status for {self.job_id} is: {current_status}") + + if current_status not in TERMINAL_JOB_STATUSES: + self.log.info("Deferring the polling to RayJobTrigger...") + self.defer( + trigger=RayJobTrigger( + job_id=self.job_id, + conn_id=self.conn_id, + xcom_dashboard_url=self.dashboard_url, + ray_cluster_yaml=self.ray_cluster_yaml, + gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, + poll_interval=self.poll_interval, + fetch_logs=self.fetch_logs, + ), + method_name="execute_complete", + timeout=self.job_timeout_seconds, + ) + + return self.job_id def execute_complete(self, context: Context, event: dict[str, Any]) -> None: """ @@ -320,15 +326,24 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: :param context: The context in which the task is being executed. :param event: The event containing the job execution result. 
-        :raises AirflowException: If the job execution fails, is cancelled, or reaches an unexpected state.
+        :raises RayAirflowException: If the job execution fails, is cancelled, or reaches an unexpected state.
         """
-        try:
-            if event["status"] in [JobStatus.STOPPED, JobStatus.FAILED]:
-                self.log.info(f"Ray job {self.job_id} execution not completed successfully...")
-                raise AirflowException(f"Job {self.job_id} {event['status'].lower()}: {event['message']}")
-            elif event["status"] == JobStatus.SUCCEEDED:
-                self.log.info(f"Ray job {self.job_id} execution succeeded.")
+        self.log.info("::endgroup::")
+        self.log.info("::group:: (SubmitJob 5/5) Execution completed")
+
+        self._delete_cluster()
+
+        job_status = event["status"]
+        if job_status == JobStatus.SUCCEEDED:
+            self.log.info("Job %s completed successfully", self.job_id)
+            self.log.info("::endgroup::")
+            return
+        else:
+            self.log.info(f"Ray job {self.job_id} execution not completed successfully...")
+            if job_status in (JobStatus.FAILED, JobStatus.STOPPED):
+                msg = f"Job {self.job_id} {job_status.lower()}: {event['message']}"
             else:
-                raise AirflowException(f"Unexpected event status for job {self.job_id}: {event['status']}")
-        finally:
-            self._delete_cluster()
+                msg = f"Encountered unexpected state `{job_status}` for job_id `{self.job_id}`"
+
+        self.log.info("::endgroup::")
+
+        raise RayAirflowException(msg)
diff --git a/tests/test_operators.py b/tests/test_operators.py
index 72df6b5..a351f3b 100644
--- a/tests/test_operators.py
+++ b/tests/test_operators.py
@@ -2,22 +2,16 @@
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.exceptions import TaskDeferred
+from kubernetes.client.exceptions import ApiException
 from ray.job_submission import JobStatus
 
+from ray_provider.exceptions import RayAirflowException
 from ray_provider.operators import DeleteRayCluster, SetupRayCluster, SubmitRayJob
 from ray_provider.triggers import RayJobTrigger
 
 
 class TestSetupRayCluster:
-    @pytest.fixture
-    def mock_hook(self):
-        with patch("ray_provider.operators.RayHook") as mock:
-            yield mock.return_value
-
-    @pytest.fixture
-    def operator(self):
-        return SetupRayCluster(task_id="test_setup_ray_cluster", conn_id="test_conn", ray_cluster_yaml="cluster.yaml")
 
     def test_init(self):
         operator = SetupRayCluster(
@@ -47,16 +41,25 @@ def test_init_default_values(self):
         )
         assert operator.update_if_exists is False
 
-    def test_hook_property(self, operator):
-        with patch("ray_provider.operators.RayHook") as mock_ray_hook:
-            hook = operator.hook
-            mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id)
-            assert hook == mock_ray_hook.return_value
+    @patch("ray_provider.operators.RayHook")
+    def test_hook_property(self, mock_ray_hook):
+        operator = SetupRayCluster(
+            task_id="test_setup_ray_cluster", conn_id="test_conn", ray_cluster_yaml="cluster.yaml"
+        )
+        operator.hook
+        mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id)
+
+    @patch("ray_provider.operators.SetupRayCluster.hook.setup_ray_cluster")
+    @patch("ray_provider.operators.SetupRayCluster.hook")
+    def test_execute(self, mock_ray_hook, mock_setup_ray_cluster):
+        operator = SetupRayCluster(
+            task_id="test_setup_ray_cluster", conn_id="test_conn", ray_cluster_yaml="cluster.yaml"
+        )
 
-    def test_execute(self, operator, mock_hook):
         context = MagicMock()
         operator.execute(context)
-        mock_hook.setup_ray_cluster.assert_called_once_with(
+
+        mock_setup_ray_cluster.assert_called_once_with(
             context=context,
             ray_cluster_yaml=operator.ray_cluster_yaml,
             kuberay_version=operator.kuberay_version,
@@ -66,14 +69,6 @@ def test_execute(self, operator, mock_hook):
 
 
 class TestDeleteRayCluster:
-    @pytest.fixture
-    def mock_hook(self):
-        with patch("ray_provider.operators.RayHook") as mock:
-            yield mock.return_value
-
-    @pytest.fixture
-    def operator(self):
-        return DeleteRayCluster(task_id="test_delete_ray_cluster", conn_id="test_conn", ray_cluster_yaml="cluster.yaml")
 
     def test_init(self):
         operator = DeleteRayCluster(
@@ -97,13 +92,19 @@ def test_init_default_gpu_plugin(self):
             == "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml"
         )
 
-    def test_hook_property(self, operator):
-        with patch("ray_provider.operators.RayHook") as mock_ray_hook:
-            hook = operator.hook
-            mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id)
-            assert hook == mock_ray_hook.return_value
+    @patch("ray_provider.operators.RayHook")
+    def test_hook_property(self, mock_ray_hook):
+        operator = DeleteRayCluster(
+            task_id="test_delete_ray_cluster", conn_id="test_conn", ray_cluster_yaml="cluster.yaml"
+        )
+        operator.hook
+        mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id)
 
-    def test_execute(self, operator, mock_hook):
+    @patch("ray_provider.operators.DeleteRayCluster.hook")
+    def test_execute(self, mock_hook):
+        operator = DeleteRayCluster(
+            task_id="test_delete_ray_cluster", conn_id="test_conn", ray_cluster_yaml="cluster.yaml"
+        )
         context = MagicMock()
         operator.execute(context)
         mock_hook.delete_ray_cluster.assert_called_once_with(operator.ray_cluster_yaml, operator.gpu_device_plugin_yaml)
@@ -111,15 +112,6 @@ def test_execute(self, operator, mock_hook):
 
 
 class TestSubmitRayJob:
-    @pytest.fixture
-    def mock_hook(self):
-        with patch("ray_provider.operators.RayHook") as mock:
-            yield mock.return_value
-
-    @pytest.fixture
-    def operator(self):
-        return SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={})
-
     @pytest.fixture
     def task_instance(self):
         return Mock()
@@ -188,18 +180,24 @@ def test_init_no_timeout(self):
         )
         assert operator.job_timeout_seconds is None
 
-    def test_on_kill(self, mock_hook):
-        operator = SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={})
+    @patch("ray_provider.operators.SubmitRayJob._delete_cluster")
+    @patch("ray_provider.operators.SubmitRayJob.hook.delete_ray_job")
+    @patch("ray_provider.operators.SubmitRayJob.hook")
+    def test_on_kill(self, mock_hook, mock_delete_ray_job, mock_delete_cluster):
+        operator = SubmitRayJob(
+            task_id="test_task",
+            conn_id="test_conn",
+            entrypoint="python script.py",
+            runtime_env={},
+            ray_cluster_yaml="cluster.yaml",
+        )
         operator.job_id = "test_job_id"
-        operator.hook = mock_hook
         operator.dashboard_url = "http://dashboard.url"
-        operator.ray_cluster_yaml = "cluster.yaml"
 
-        with patch.object(operator, "_delete_cluster") as mock_delete_cluster:
-            operator.on_kill()
+        operator.on_kill()
 
-        mock_hook.delete_ray_job.assert_called_once_with("http://dashboard.url", "test_job_id")
-        mock_delete_cluster.assert_called_once()
+        mock_delete_ray_job.assert_called_once_with("http://dashboard.url", "test_job_id")
+        mock_delete_cluster.assert_called_once()
 
     def test_get_dashboard_url_with_xcom(self, context, task_instance):
         operator = SubmitRayJob(
@@ -241,12 +239,11 @@ def test_setup_cluster(self, mock_ray_hook, context):
             gpu_device_plugin_yaml="https://example.com/plugin.yml",
         )
 
-        mock_hook = mock_ray_hook.return_value
-        operator.hook = mock_hook
+        operator.hook = mock_ray_hook.return_value
 
         operator._setup_cluster(context)
 
-        mock_hook.setup_ray_cluster.assert_called_once_with(
+        mock_ray_hook.return_value.setup_ray_cluster.assert_called_once_with(
             context=context,
             ray_cluster_yaml="cluster.yaml",
             kuberay_version="1.0.0",
@@ -254,7 +251,7 @@ def test_setup_cluster(self, mock_ray_hook, context):
             update_if_exists=True,
         )
 
-    @patch("ray_provider.operators.RayHook")
+    @patch("ray_provider.operators.SubmitRayJob.hook")
     def test_delete_cluster(self, mock_ray_hook):
         operator = SubmitRayJob(
             task_id="test_task",
@@ -264,18 +261,18 @@ def test_delete_cluster(self, mock_ray_hook):
             ray_cluster_yaml="cluster.yaml",
             gpu_device_plugin_yaml="https://example.com/plugin.yml",
         )
-
-        mock_hook = mock_ray_hook.return_value
-        operator.hook = mock_hook
-
         operator._delete_cluster()
-        mock_hook.delete_ray_cluster.assert_called_once_with(
+        mock_ray_hook.delete_ray_cluster.assert_called_once_with(
             ray_cluster_yaml="cluster.yaml",
             gpu_device_plugin_yaml="https://example.com/plugin.yml",
         )
 
-    def test_execute_without_wait(self, mock_hook, context):
+    @patch("ray_provider.operators.SubmitRayJob._setup_cluster")
+    @patch("ray_provider.operators.SubmitRayJob.hook.submit_ray_job", return_value="test_job_id")
+    @patch("ray_provider.operators.SubmitRayJob.hook")
+    def test_execute_without_wait(self, mock_hook, mock_submit_ray_job, mock_setup_cluster, context):
         operator = SubmitRayJob(
             task_id="test_task",
             conn_id="test_conn",
@@ -284,22 +281,20 @@ def test_execute_without_wait(self, mock_hook, context):
             wait_for_completion=False,
         )
 
-        mock_hook.submit_ray_job.return_value = "test_job_id"
+        result = operator.execute(context)
+        assert result == "test_job_id"
 
-        with patch.object(operator, "_setup_cluster") as mock_setup_cluster:
-            result = operator.execute(context)
+        mock_setup_cluster.assert_called_once_with(context=context)
 
-            mock_setup_cluster.assert_called_once_with(context=context)
-            assert result == "test_job_id"
-            mock_hook.submit_ray_job.assert_called_once_with(
-                dashboard_url=None,
-                entrypoint="python script.py",
-                runtime_env={},
-                entrypoint_num_cpus=0,
-                entrypoint_num_gpus=0,
-                entrypoint_memory=0,
-                entrypoint_resources=None,
-            )
+        mock_submit_ray_job.assert_called_once_with(
+            dashboard_url=None,
+            entrypoint="python script.py",
+            runtime_env={},
+            entrypoint_num_cpus=0,
+            entrypoint_num_gpus=0,
+            entrypoint_memory=0,
+            entrypoint_resources=None,
+        )
 
     @pytest.mark.parametrize(
         "job_status,expected_action",
@@ -307,11 +302,14 @@ def test_execute_without_wait(self, mock_hook, context):
             (JobStatus.PENDING, "defer"),
             (JobStatus.RUNNING, "defer"),
             (JobStatus.SUCCEEDED, None),
-            (JobStatus.FAILED, "raise"),
-            (JobStatus.STOPPED, "raise"),
         ],
     )
-    def test_execute_with_wait(self, mock_hook, context, job_status, expected_action):
+    @patch("ray_provider.operators.SubmitRayJob._setup_cluster")
+    @patch("ray_provider.operators.SubmitRayJob.hook.submit_ray_job", return_value="test_job_id")
+    @patch("ray_provider.operators.SubmitRayJob.hook")
+    def test_execute_with_wait(
+        self, mock_hook, mock_submit_ray_job, mock_setup_cluster, context, job_status, expected_action
+    ):
+        mock_hook.get_ray_job_status.return_value = job_status
+
         operator = SubmitRayJob(
             task_id="test_task",
             conn_id="test_conn",
@@ -320,20 +318,13 @@ def test_execute_with_wait(self, mock_hook, context, job_status, expected_action):
             wait_for_completion=True,
         )
 
-        mock_hook.submit_ray_job.return_value = "test_job_id"
-        mock_hook.get_ray_job_status.return_value = job_status
-
-        with patch.object(operator, "_setup_cluster"):
-            if expected_action == "defer":
"defer": - with patch.object(operator, "defer") as mock_defer: - operator.execute(context) - mock_defer.assert_called_once() - elif expected_action == "raise": - with pytest.raises(AirflowException): - operator.execute(context) - else: - result = operator.execute(context) - assert result == "test_job_id" + if expected_action == "defer": + with patch.object(operator, "defer") as mock_defer: + operator.execute(context) + mock_defer.assert_called_once() + else: + result = operator.execute(context) + assert result == "test_job_id" @pytest.mark.parametrize( "event_status,expected_action", @@ -344,19 +335,24 @@ def test_execute_with_wait(self, mock_hook, context, job_status, expected_action ("UNEXPECTED", "raise"), ], ) - def test_execute_complete(self, operator, event_status, expected_action): + @patch("ray_provider.operators.SubmitRayJob._delete_cluster") + def test_execute_complete(self, mock_delete_cluster, event_status, expected_action): + operator = SubmitRayJob( + task_id="test_task", + conn_id="test_conn", + entrypoint="python script.py", + runtime_env={}, + ) operator.job_id = "test_job_id" event = {"status": event_status, "message": "Test message"} - with patch.object(operator, "_delete_cluster") as mock_delete_cluster: - if expected_action == "raise": - with pytest.raises(AirflowException): - operator.execute_complete({}, event) - else: + if expected_action == "raise": + with pytest.raises(RayAirflowException): operator.execute_complete({}, event) + else: + operator.execute_complete({}, event) - # _delete_cluster should be called in all cases - mock_delete_cluster.assert_called_once() + mock_delete_cluster.assert_called_once() def test_template_fields(self): assert SubmitRayJob.template_fields == ( @@ -420,7 +416,8 @@ def test_delete_cluster_exception(self, mock_ray_hook): ("single_key", None, "single_key"), ], ) - def test_get_dashboard_url_xcom_variants(self, operator, context, xcom_task_key, expected_task, expected_key): + def test_get_dashboard_url_xcom_variants(self, context, xcom_task_key, expected_task, expected_key): + operator = SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={}) operator.xcom_task_key = xcom_task_key context["ti"].xcom_pull.return_value = "http://dashboard.url" @@ -432,7 +429,13 @@ def test_get_dashboard_url_xcom_variants(self, operator, context, xcom_task_key, else: context["ti"].xcom_pull.assert_called_once_with(task_ids=context["task"].task_id, key=expected_key) - def test_execute_job_unexpected_state(self, mock_hook, context): + @patch("ray_provider.operators.SubmitRayJob._setup_cluster") + @patch("ray_provider.operators.SubmitRayJob.hook.get_ray_job_status", return_value="UNEXPECTED_STATE") + @patch("ray_provider.operators.SubmitRayJob.hook.submit_ray_job", return_value="test_job_id") + @patch("ray_provider.operators.SubmitRayJob.hook") + def test_execute_job_unexpected_state( + self, mock_hook, mock_submit_ray_job, mock_get_ray_job_status, mock_setup_cluster, context + ): operator = SubmitRayJob( task_id="test_task", conn_id="test_conn", @@ -440,16 +443,16 @@ def test_execute_job_unexpected_state(self, mock_hook, context): runtime_env={}, wait_for_completion=True, ) - mock_hook.submit_ray_job.return_value = "test_job_id" - mock_hook.get_ray_job_status.return_value = "UNEXPECTED_STATE" - with patch.object(operator, "_setup_cluster"), pytest.raises(TaskDeferred) as exc_info: + with pytest.raises(TaskDeferred) as exc_info: operator.execute(context) assert isinstance(exc_info.value.trigger, RayJobTrigger) 
@pytest.mark.parametrize("dashboard_url", [None, "http://dashboard.url"]) - def test_execute_defer(self, mock_hook, context, dashboard_url): + @patch("ray_provider.operators.SubmitRayJob._setup_cluster") + @patch("ray_provider.operators.SubmitRayJob.hook") + def test_execute_defer(self, mock_hook, mock_setup_cluster, context, dashboard_url): operator = SubmitRayJob( task_id="test_task", conn_id="test_conn", @@ -465,9 +468,9 @@ def test_execute_defer(self, mock_hook, context, dashboard_url): mock_hook.submit_ray_job.return_value = "test_job_id" mock_hook.get_ray_job_status.return_value = JobStatus.PENDING - with patch.object(operator, "_setup_cluster"), patch.object( - operator, "_get_dashboard_url", return_value=dashboard_url - ), pytest.raises(TaskDeferred) as exc_info: + with patch.object(operator, "_get_dashboard_url", return_value=dashboard_url), pytest.raises( + TaskDeferred + ) as exc_info: operator.execute(context) trigger = exc_info.value.trigger @@ -480,21 +483,29 @@ def test_execute_defer(self, mock_hook, context, dashboard_url): assert trigger.poll_interval == 30 assert trigger.fetch_logs is True - def test_execute_complete_unexpected_status(self, operator): + @patch("ray_provider.operators.SubmitRayJob._delete_cluster") + def test_execute_complete_unexpected_status(self, mock_delete_cluster): + operator = SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={}) event = {"status": "UNEXPECTED", "message": "Unexpected status"} - with patch.object(operator, "_delete_cluster"), pytest.raises(AirflowException) as exc_info: + with pytest.raises(RayAirflowException) as exc_info: operator.execute_complete({}, event) - assert "Unexpected event status" in str(exc_info.value) + assert "Encountered unexpected state" in str(exc_info.value) - def test_execute_complete_cleanup_on_exception(self, operator): + @patch("ray_provider.operators.SubmitRayJob._delete_cluster") + def test_execute_complete_cleanup_on_exception(self, mock_delete_cluster): + operator = SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={}) event = {"status": JobStatus.FAILED, "message": "Job failed"} - with patch.object(operator, "_delete_cluster") as mock_delete_cluster, pytest.raises(AirflowException): + with pytest.raises(RayAirflowException): operator.execute_complete({}, event) mock_delete_cluster.assert_called_once() - def test_execute_exception_handling(self, mock_hook, context): + @patch("ray_provider.operators.SubmitRayJob._setup_cluster") + @patch("ray_provider.operators.SubmitRayJob._delete_cluster") + @patch("ray_provider.operators.SubmitRayJob.hook.submit_ray_job", side_effect=Exception("Job submission failed")) + @patch("ray_provider.operators.SubmitRayJob.hook") + def test_execute_exception_handling(self, mock_hook, mock_delete, mock_setup, context): operator = SubmitRayJob( task_id="test_task", conn_id="test_conn", @@ -503,17 +514,17 @@ def test_execute_exception_handling(self, mock_hook, context): ray_cluster_yaml="cluster.yaml", ) - mock_hook.submit_ray_job.side_effect = Exception("Job submission failed") - - with patch.object(operator, "_setup_cluster"), patch.object( - operator, "_delete_cluster" - ) as mock_delete_cluster, pytest.raises(AirflowException) as exc_info: + with pytest.raises(Exception) as exc_info: operator.execute(context) - assert "SubmitRayJob operator failed due to Job submission failed" in str(exc_info.value) - mock_delete_cluster.assert_called_once() + assert "Job submission failed" == 
 
-    def test_execute_cluster_setup_exception(self, mock_hook, context):
+    @patch("ray_provider.operators.SubmitRayJob._delete_cluster")
+    @patch(
+        "ray_provider.operators.SubmitRayJob.hook.setup_ray_cluster", side_effect=ApiException("Cluster setup failed")
+    )
+    @patch("ray_provider.operators.SubmitRayJob.hook")
+    def test_execute_cluster_setup_exception(self, mock_hook, mock_setup_cluster, mock_delete_cluster, context):
         operator = SubmitRayJob(
             task_id="test_task",
             conn_id="test_conn",
@@ -522,14 +533,13 @@ def test_execute_cluster_setup_exception(self, mock_hook, context):
             runtime_env={},
             ray_cluster_yaml="cluster.yaml",
         )
 
-        with patch.object(operator, "_setup_cluster", side_effect=Exception("Cluster setup failed")), patch.object(
-            operator, "_delete_cluster"
-        ) as mock_delete_cluster, pytest.raises(AirflowException) as exc_info:
+        with pytest.raises(ApiException) as exc_info:
             operator.execute(context)
 
-        assert "SubmitRayJob operator failed due to Cluster setup failed" in str(exc_info.value)
+        assert "Cluster setup failed" in str(exc_info.value)
         mock_delete_cluster.assert_called_once()
 
+    @patch("ray_provider.operators.SubmitRayJob.hook")
     def test_execute_with_wait_and_defer(self, mock_hook, context):
         operator = SubmitRayJob(
             task_id="test_task",
@@ -554,16 +564,13 @@ def test_execute_with_wait_and_defer(self, mock_hook, context):
         assert kwargs["method_name"] == "execute_complete"
         assert kwargs["timeout"].total_seconds() == 600
 
-    def test_execute_complete_with_cleanup(self, operator):
-        operator.job_id = "test_job_id"
-        event = {"status": JobStatus.FAILED, "message": "Job failed"}
-
-        with patch.object(operator, "_delete_cluster") as mock_delete_cluster, pytest.raises(AirflowException):
-            operator.execute_complete({}, event)
-
-        mock_delete_cluster.assert_called_once()
-
-    def test_execute_without_wait_no_cleanup(self, mock_hook, context):
+    @patch("ray_provider.operators.SubmitRayJob._delete_cluster")
+    @patch("ray_provider.operators.SubmitRayJob._setup_cluster")
+    @patch("ray_provider.operators.SubmitRayJob.hook.submit_ray_job", return_value="test_job_id")
+    @patch("ray_provider.operators.SubmitRayJob.hook")
+    def test_execute_without_wait_no_cleanup(
+        self, mock_hook, mock_submit, mock_setup_cluster, mock_delete_cluster, context
+    ):
         operator = SubmitRayJob(
             task_id="test_task",
             conn_id="test_conn",
@@ -572,22 +579,17 @@ def test_execute_without_wait_no_cleanup(self, mock_hook, context):
             runtime_env={},
             wait_for_completion=False,
         )
 
-        mock_hook.submit_ray_job.return_value = "test_job_id"
+        result = operator.execute(context)
+        assert result == "test_job_id"
 
-        with patch.object(operator, "_setup_cluster") as mock_setup_cluster, patch.object(
-            operator, "_delete_cluster"
-        ) as mock_delete_cluster:
-            result = operator.execute(context)
-
-            mock_setup_cluster.assert_called_once_with(context=context)
-            assert result == "test_job_id"
-            mock_hook.submit_ray_job.assert_called_once_with(
-                dashboard_url=None,
-                entrypoint="python script.py",
-                runtime_env={},
-                entrypoint_num_cpus=0,
-                entrypoint_num_gpus=0,
-                entrypoint_memory=0,
-                entrypoint_resources=None,
-            )
-            mock_delete_cluster.assert_not_called()
+        mock_setup_cluster.assert_called_once_with(context=context)
+        mock_submit.assert_called_once_with(
+            dashboard_url=None,
+            entrypoint="python script.py",
+            runtime_env={},
+            entrypoint_num_cpus=0,
+            entrypoint_num_gpus=0,
+            entrypoint_memory=0,
+            entrypoint_resources=None,
+        )
+        mock_delete_cluster.assert_not_called()
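
For reference, a minimal sketch of how the narrowed error handling surfaces to a DAG after this change: a Kubernetes `ApiException` raised during cluster setup is logged, any partially created cluster is torn down, and the exception is re-raised, while a FAILED, STOPPED, or unexpected terminal job state raises `RayAirflowException`. The DAG id, connection id, and cluster YAML path below are invented for illustration (and `schedule=None` assumes Airflow 2.4+); the operator arguments match those exercised in the tests above:

    from datetime import datetime

    from airflow import DAG

    from ray_provider.operators import SubmitRayJob

    with DAG(dag_id="ray_job_example", start_date=datetime(2024, 1, 1), schedule=None):
        # Fails with ApiException if the cluster cannot be created, and with
        # RayAirflowException if the job finishes in a non-SUCCEEDED state.
        submit = SubmitRayJob(
            task_id="submit_ray_job",
            conn_id="ray_k8s_conn",           # hypothetical connection id
            entrypoint="python script.py",
            runtime_env={},
            ray_cluster_yaml="cluster.yaml",  # hypothetical cluster spec path
            wait_for_completion=True,
            job_timeout_seconds=600,
        )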