diff --git a/.astro-registry.yaml b/.astro-registry.yaml index 0618afc..53624c4 100644 --- a/.astro-registry.yaml +++ b/.astro-registry.yaml @@ -4,15 +4,15 @@ display-name: Ray docs-url: https://github.com/astronomer/astro-provider-ray/blob/main/README.md hooks: - - module: ray_provider.hooks.ray.RayHook + - module: ray_provider.hooks.RayHook decorators: - - module: ray_provider.decorators.ray.ray + - module: ray_provider.decorators.ray operators: - - module: ray_provider.operators.ray.SetupRayCluster - - module: ray_provider.operators.ray.SubmitRayJob - - module: ray_provider.operators.ray.DeleteRayCluster + - module: ray_provider.operators.SetupRayCluster + - module: ray_provider.operators.SubmitRayJob + - module: ray_provider.operators.DeleteRayCluster triggers: - - module: ray_provider.triggers.ray.RayJobTrigger + - module: ray_provider.triggers.RayJobTrigger diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8e0e244..e88b890 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,28 @@ CHANGELOG ========= +0.3.0 (2024-11-29) +------------------ + +**Breaking changes** + +In order to improve the development and troubleshooting DAGs created with this provider, we introduced breaking changes +to the folder structure. It was flattened and the import paths to existing decorators, hooks, operators and trigger +changed, as documented in the table below: + ++-----------+---------------------------------------------+-----------------------------------------+ +| Type | Previous import path | Current import path | ++===========+=============================================+=========================================+ +| Decorator | ray_provider.decorators.ray.ray | ray_provider.decorators.ray | +| Hook | ray_provider.hooks.ray.RayHook | ray_provider.hooks.RayHook | +| Operator | ray_provider.operators.ray.DeleteRayCluster | ray_provider.operators.DeleteRayCluster | +| Operator | ray_provider.operators.ray.SetupRayCluster | ray_provider.operators.SetupRayCluster | +| Operator | ray_provider.operators.ray.SubmitRayJob | ray_provider.operators.SubmitRayJob | +| Trigger | ray_provider.triggers.ray.RayJobTrigger | ray_provider.triggers.RayJobTrigger | ++-----------+---------------------------------------------+-----------------------------------------+ + + + 0.2.1 (2024-09-04) ------------------ diff --git a/dev/dags/ray_single_operator.py b/dev/dags/ray_single_operator.py index 6057515..cbca222 100644 --- a/dev/dags/ray_single_operator.py +++ b/dev/dags/ray_single_operator.py @@ -3,7 +3,7 @@ from airflow import DAG -from ray_provider.operators.ray import SubmitRayJob +from ray_provider.operators import SubmitRayJob CONN_ID = "ray_conn" RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" diff --git a/dev/dags/ray_taskflow_example.py b/dev/dags/ray_taskflow_example.py index 5878cf0..9ccef6e 100644 --- a/dev/dags/ray_taskflow_example.py +++ b/dev/dags/ray_taskflow_example.py @@ -3,7 +3,7 @@ from airflow.decorators import dag, task -from ray_provider.decorators.ray import ray +from ray_provider.decorators import ray CONN_ID = "ray_conn" RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" diff --git a/dev/dags/ray_taskflow_example_existing_cluster.py b/dev/dags/ray_taskflow_example_existing_cluster.py index 9160f50..c5c8a40 100644 --- a/dev/dags/ray_taskflow_example_existing_cluster.py +++ b/dev/dags/ray_taskflow_example_existing_cluster.py @@ -3,7 +3,7 @@ from airflow.decorators import dag, task -from ray_provider.decorators.ray import ray +from ray_provider.decorators import ray CONN_ID = "ray_job" FOLDER_PATH = Path(__file__).parent / "ray_scripts" diff --git a/dev/dags/setup-teardown.py b/dev/dags/setup-teardown.py index c2ac712..48451c5 100644 --- a/dev/dags/setup-teardown.py +++ b/dev/dags/setup-teardown.py @@ -3,7 +3,7 @@ from airflow import DAG -from ray_provider.operators.ray import DeleteRayCluster, SetupRayCluster, SubmitRayJob +from ray_provider.operators import DeleteRayCluster, SetupRayCluster, SubmitRayJob CONN_ID = "ray_conn" RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" diff --git a/docs/api/ray_provider.decorators.rst b/docs/api/ray_provider.decorators.rst index e2f5347..8b36137 100644 --- a/docs/api/ray_provider.decorators.rst +++ b/docs/api/ray_provider.decorators.rst @@ -2,7 +2,7 @@ Decorators ---------- -.. automodule:: ray_provider.decorators.ray +.. automodule:: ray_provider.decorators :members: :undoc-members: :show-inheritance: diff --git a/docs/api/ray_provider.hooks.rst b/docs/api/ray_provider.hooks.rst index 6cf932c..2f9b39a 100644 --- a/docs/api/ray_provider.hooks.rst +++ b/docs/api/ray_provider.hooks.rst @@ -1,7 +1,7 @@ Hook ----- -.. automodule:: ray_provider.hooks.ray +.. automodule:: ray_provider.hooks :members: :undoc-members: :show-inheritance: diff --git a/docs/api/ray_provider.operators.rst b/docs/api/ray_provider.operators.rst index af1bfa6..fac793e 100644 --- a/docs/api/ray_provider.operators.rst +++ b/docs/api/ray_provider.operators.rst @@ -1,7 +1,7 @@ Operators --------- -.. automodule:: ray_provider.operators.ray +.. automodule:: ray_provider.operators :members: :undoc-members: :show-inheritance: diff --git a/docs/api/ray_provider.triggers.rst b/docs/api/ray_provider.triggers.rst index 4b71046..7edcea6 100644 --- a/docs/api/ray_provider.triggers.rst +++ b/docs/api/ray_provider.triggers.rst @@ -1,7 +1,7 @@ Trigger -------- -.. automodule:: ray_provider.triggers.ray +.. automodule:: ray_provider.triggers :members: :undoc-members: :show-inheritance: diff --git a/ray_provider/__init__.py b/ray_provider/__init__.py index d9657e7..04fb8c2 100644 --- a/ray_provider/__init__.py +++ b/ray_provider/__init__.py @@ -11,6 +11,6 @@ def get_provider_info() -> dict[str, Any]: "package-name": "astro-provider-ray", # Required "name": "Ray", # Required "description": "An integration between airflow and ray", # Required - "connection-types": [{"connection-type": "ray", "hook-class-name": "ray_provider.hooks.ray.RayHook"}], + "connection-types": [{"connection-type": "ray", "hook-class-name": "ray_provider.hooks.RayHook"}], "versions": [__version__], # Required } diff --git a/ray_provider/decorators/ray.py b/ray_provider/decorators.py similarity index 99% rename from ray_provider/decorators/ray.py rename to ray_provider/decorators.py index 82ca474..bcb150d 100644 --- a/ray_provider/decorators/ray.py +++ b/ray_provider/decorators.py @@ -12,7 +12,7 @@ from airflow.exceptions import AirflowException from airflow.utils.context import Context -from ray_provider.operators.ray import SubmitRayJob +from ray_provider.operators import SubmitRayJob class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob): diff --git a/ray_provider/decorators/__init__.py b/ray_provider/decorators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ray_provider/hooks/ray.py b/ray_provider/hooks.py similarity index 100% rename from ray_provider/hooks/ray.py rename to ray_provider/hooks.py diff --git a/ray_provider/hooks/__init__.py b/ray_provider/hooks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ray_provider/operators/ray.py b/ray_provider/operators.py similarity index 99% rename from ray_provider/operators/ray.py rename to ray_provider/operators.py index 02b6a73..b6a38d6 100644 --- a/ray_provider/operators/ray.py +++ b/ray_provider/operators.py @@ -10,8 +10,8 @@ from airflow.utils.context import Context from ray.job_submission import JobStatus -from ray_provider.hooks.ray import RayHook -from ray_provider.triggers.ray import RayJobTrigger +from ray_provider.hooks import RayHook +from ray_provider.triggers import RayJobTrigger class SetupRayCluster(BaseOperator): diff --git a/ray_provider/operators/__init__.py b/ray_provider/operators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ray_provider/triggers/ray.py b/ray_provider/triggers.py similarity index 98% rename from ray_provider/triggers/ray.py rename to ray_provider/triggers.py index 745c74f..3252c99 100644 --- a/ray_provider/triggers/ray.py +++ b/ray_provider/triggers.py @@ -7,7 +7,7 @@ from airflow.triggers.base import BaseTrigger, TriggerEvent from ray.job_submission import JobStatus -from ray_provider.hooks.ray import RayHook +from ray_provider.hooks import RayHook class RayJobTrigger(BaseTrigger): @@ -51,7 +51,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: :return: A tuple containing the fully qualified class name and a dictionary of its parameters. """ return ( - "ray_provider.triggers.ray.RayJobTrigger", + "ray_provider.triggers.RayJobTrigger", { "job_id": self.job_id, "conn_id": self.conn_id, diff --git a/ray_provider/triggers/__init__.py b/ray_provider/triggers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/decorators/__init__.py b/tests/decorators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/hooks/__init__.py b/tests/hooks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/decorators/test_ray_decorators.py b/tests/test_decorators.py similarity index 95% rename from tests/decorators/test_ray_decorators.py rename to tests/test_decorators.py index 109f748..a6e6b15 100644 --- a/tests/decorators/test_ray_decorators.py +++ b/tests/test_decorators.py @@ -5,7 +5,7 @@ from airflow.exceptions import AirflowException from airflow.utils.context import Context -from ray_provider.decorators.ray import _RayDecoratedOperator, ray +from ray_provider.decorators import _RayDecoratedOperator, ray class TestRayDecoratedOperator: @@ -81,7 +81,7 @@ def dummy_callable(): _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) @patch.object(_RayDecoratedOperator, "_extract_function_body") - @patch("ray_provider.decorators.ray.SubmitRayJob.execute") + @patch("ray_provider.decorators.SubmitRayJob.execute") def test_execute_decorated_function(self, mock_super_execute, mock_extract_function_body): config = { "runtime_env": {"pip": ["ray"]}, @@ -101,7 +101,7 @@ def dummy_callable(): assert operator.entrypoint == "python script.py" assert "working_dir" in operator.runtime_env - @patch("ray_provider.decorators.ray.SubmitRayJob.execute") + @patch("ray_provider.decorators.SubmitRayJob.execute") def test_execute_with_entrypoint(self, mock_super_execute): config = { "entrypoint": "python my_script.py", @@ -119,7 +119,7 @@ def dummy_callable(): assert result == "success" assert operator.entrypoint == "python my_script.py" - @patch("ray_provider.decorators.ray.SubmitRayJob.execute") + @patch("ray_provider.decorators.SubmitRayJob.execute") def test_execute_failure(self, mock_super_execute): config = {} diff --git a/tests/hooks/test_ray_hooks.py b/tests/test_hooks.py similarity index 79% rename from tests/hooks/test_ray_hooks.py rename to tests/test_hooks.py index 95c787d..cfa33b7 100644 --- a/tests/hooks/test_ray_hooks.py +++ b/tests/test_hooks.py @@ -8,19 +8,19 @@ from kubernetes.client.exceptions import ApiException from ray.job_submission import JobStatus -from ray_provider.hooks.ray import RayHook +from ray_provider.hooks import RayHook class TestRayHook: @pytest.fixture def ray_hook(self): - with patch("ray_provider.hooks.ray.KubernetesHook.get_connection") as mock_get_connection: + with patch("ray_provider.hooks.KubernetesHook.get_connection") as mock_get_connection: mock_connection = Mock() mock_connection.extra_dejson = {"kube_config_path": None, "kube_config": None, "cluster_context": None} mock_get_connection.return_value = mock_connection - with patch("ray_provider.hooks.ray.KubernetesHook.__init__", return_value=None): + with patch("ray_provider.hooks.KubernetesHook.__init__", return_value=None): hook = RayHook(conn_id="test_conn") # Manually set the necessary attributes hook.namespace = "default" @@ -40,7 +40,7 @@ def test_get_connection_form_widgets(self): assert "kube_config_path" in widgets assert "namespace" in widgets - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.JobSubmissionClient") def test_ray_client(self, mock_job_client, ray_hook): mock_job_client.return_value = MagicMock() client = ray_hook.ray_client() @@ -54,16 +54,16 @@ def test_ray_client(self, mock_job_client, ray_hook): verify=ray_hook.verify, ) - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.JobSubmissionClient") def test_submit_ray_job(self, mock_job_client, ray_hook): mock_client_instance = mock_job_client.return_value mock_client_instance.submit_job.return_value = "test_job_id" job_id = ray_hook.submit_ray_job(dashboard_url="http://example.com", entrypoint="test_entry") assert job_id == "test_job_id" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.config.load_kube_config") def test_setup_kubeconfig_path(self, mock_load_kube_config, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -74,9 +74,9 @@ def test_setup_kubeconfig_path(self, mock_load_kube_config, mock_kubernetes_init assert hook.kubeconfig == "/tmp/fake_kubeconfig" mock_load_kube_config.assert_called_once_with(config_file="/tmp/fake_kubeconfig", context="test_context") - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.config.load_kube_config") @patch("tempfile.NamedTemporaryFile") def test_setup_kubeconfig_content( self, mock_tempfile, mock_load_kube_config, mock_kubernetes_init, mock_get_connection @@ -95,8 +95,8 @@ def test_setup_kubeconfig_content( mock_tempfile.return_value.__enter__.return_value.write.assert_called_once_with(kubeconfig_content.encode()) mock_load_kube_config.assert_called_once_with(config_file="/tmp/fake_kubeconfig", context="test_context") - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") def test_setup_kubeconfig_invalid_config(self, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -110,8 +110,8 @@ def test_setup_kubeconfig_invalid_config(self, mock_kubernetes_init, mock_get_co "kube_config are mutually exclusive. You can only use one option at a time." ) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.JobSubmissionClient") def test_delete_ray_job(self, mock_job_client, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) mock_client_instance = mock_job_client.return_value @@ -120,8 +120,8 @@ def test_delete_ray_job(self, mock_job_client, mock_get_connection): result = hook.delete_ray_job("http://example.com", job_id="test_job_id") assert result == "deleted" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.JobSubmissionClient") def test_get_ray_job_status(self, mock_job_client, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) mock_client_instance = mock_job_client.return_value @@ -130,8 +130,8 @@ def test_get_ray_job_status(self, mock_job_client, mock_get_connection): status = hook.get_ray_job_status("http://example.com", "test_job_id") assert status == JobStatus.SUCCEEDED - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.JobSubmissionClient") def test_get_ray_job_logs(self, mock_job_client, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) mock_client_instance = mock_job_client.return_value @@ -154,8 +154,8 @@ def test_get_ray_job_logs(self, mock_job_client, mock_get_connection): ) mock_client_instance.get_job_logs.assert_called_once_with(job_id=job_id) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.requests.get") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.requests.get") @patch("builtins.open", new_callable=mock_open, read_data="key: value\n") def test_load_yaml_content(self, mock_open, mock_requests, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -203,8 +203,8 @@ def test_validate_yaml_file_not_exists(self, mock_isfile, ray_hook): assert "The specified YAML file does not exist" in str(exc_info.value) mock_isfile.assert_called_once_with("non_existent_file.yaml") - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.socket.socket") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.socket.socket") def test_is_port_open(self, mock_socket, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) mock_socket_instance = mock_socket.return_value @@ -215,7 +215,7 @@ def test_is_port_open(self, mock_socket, mock_get_connection): result = hook._is_port_open("localhost", 8080) assert result is True - @patch("ray_provider.hooks.ray.RayHook.core_v1_client") + @patch("ray_provider.hooks.RayHook.core_v1_client") def test_get_service_success(self, mock_core_v1_client, ray_hook): mock_service = Mock(spec=client.V1Service) mock_core_v1_client.read_namespaced_service.return_value = mock_service @@ -225,7 +225,7 @@ def test_get_service_success(self, mock_core_v1_client, ray_hook): assert service == mock_service mock_core_v1_client.read_namespaced_service.assert_called_once_with("test-service", "default") - @patch("ray_provider.hooks.ray.RayHook.core_v1_client") + @patch("ray_provider.hooks.RayHook.core_v1_client") def test_get_service_not_found(self, mock_core_v1_client, ray_hook): mock_core_v1_client.read_namespaced_service.side_effect = client.exceptions.ApiException(status=404) @@ -285,8 +285,8 @@ def test_get_load_balancer_details_no_ip_or_hostname(self, ray_hook): assert lb_details is None - @patch("ray_provider.hooks.ray.RayHook.log") - @patch("ray_provider.hooks.ray.subprocess.run") + @patch("ray_provider.hooks.RayHook.log") + @patch("ray_provider.hooks.subprocess.run") def test_run_bash_command_exception(self, mock_subprocess_run, mock_log, ray_hook): mock_subprocess_run.side_effect = subprocess.CalledProcessError( returncode=1, cmd="test command", output="test output", stderr="test error" @@ -313,9 +313,9 @@ def test_run_bash_command_exception(self, mock_subprocess_run, mock_log, ray_hoo env=ray_hook._run_bash_command.__globals__["os"].environ.copy(), ) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.subprocess.run") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.subprocess.run") def test_install_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -327,9 +327,9 @@ def test_install_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_ini assert "install output" in stdout assert stderr == "" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.subprocess.run") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.subprocess.run") def test_uninstall_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -341,9 +341,9 @@ def test_uninstall_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_i assert "uninstall output" in stdout assert stderr == "" - @patch("ray_provider.hooks.ray.RayHook._get_service") - @patch("ray_provider.hooks.ray.RayHook._get_load_balancer_details") - @patch("ray_provider.hooks.ray.RayHook._check_load_balancer_readiness") + @patch("ray_provider.hooks.RayHook._get_service") + @patch("ray_provider.hooks.RayHook._get_load_balancer_details") + @patch("ray_provider.hooks.RayHook._check_load_balancer_readiness") def test_wait_for_load_balancer_success( self, mock_check_readiness, mock_get_lb_details, mock_get_service, ray_hook ): @@ -370,9 +370,9 @@ def test_wait_for_load_balancer_success( mock_get_lb_details.assert_called_once_with(mock_service) mock_check_readiness.assert_called_once() - @patch("ray_provider.hooks.ray.RayHook._get_service") - @patch("ray_provider.hooks.ray.RayHook._get_load_balancer_details") - @patch("ray_provider.hooks.ray.RayHook._is_port_open") + @patch("ray_provider.hooks.RayHook._get_service") + @patch("ray_provider.hooks.RayHook._get_load_balancer_details") + @patch("ray_provider.hooks.RayHook._is_port_open") def test_wait_for_load_balancer_timeout(self, mock_is_port_open, mock_get_lb_details, mock_get_service, ray_hook): mock_service = Mock(spec=client.V1Service) mock_get_service.return_value = mock_service @@ -390,7 +390,7 @@ def test_wait_for_load_balancer_timeout(self, mock_is_port_open, mock_get_lb_det assert "LoadBalancer did not become ready after 2 attempts" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook._get_service") + @patch("ray_provider.hooks.RayHook._get_service") def test_wait_for_load_balancer_service_not_found(self, mock_get_service, ray_hook): mock_get_service.side_effect = AirflowException("Service test-service not found") @@ -399,7 +399,7 @@ def test_wait_for_load_balancer_service_not_found(self, mock_get_service, ray_ho assert "LoadBalancer did not become ready after 1 attempts" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook._is_port_open") + @patch("ray_provider.hooks.RayHook._is_port_open") def test_check_load_balancer_readiness_ip(self, mock_is_port_open, ray_hook): mock_is_port_open.return_value = True lb_details = {"ip": "192.168.1.1", "hostname": None, "ports": [{"name": "http", "port": 80}]} @@ -409,7 +409,7 @@ def test_check_load_balancer_readiness_ip(self, mock_is_port_open, ray_hook): assert result == "192.168.1.1" mock_is_port_open.assert_called_once_with("192.168.1.1", 80) - @patch("ray_provider.hooks.ray.RayHook._is_port_open") + @patch("ray_provider.hooks.RayHook._is_port_open") def test_check_load_balancer_readiness_hostname(self, mock_is_port_open, ray_hook): mock_is_port_open.side_effect = [False, True] lb_details = { @@ -424,7 +424,7 @@ def test_check_load_balancer_readiness_hostname(self, mock_is_port_open, ray_hoo mock_is_port_open.assert_any_call("192.168.1.1", 80) mock_is_port_open.assert_any_call("example.com", 80) - @patch("ray_provider.hooks.ray.RayHook._is_port_open") + @patch("ray_provider.hooks.RayHook._is_port_open") def test_check_load_balancer_readiness_not_ready(self, mock_is_port_open, ray_hook): mock_is_port_open.return_value = False lb_details = {"ip": "192.168.1.1", "hostname": "example.com", "ports": [{"name": "http", "port": 80}]} @@ -435,10 +435,10 @@ def test_check_load_balancer_readiness_not_ready(self, mock_is_port_open, ray_ho mock_is_port_open.assert_any_call("192.168.1.1", 80) mock_is_port_open.assert_any_call("example.com", 80) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.read_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.read_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_get_daemon_set( self, mock_load_kube_config, mock_read_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -454,10 +454,10 @@ def test_get_daemon_set( assert daemon_set.metadata.name == "test-daemonset" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.read_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.read_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_get_daemon_set_not_found( self, mock_load_kube_config, mock_read_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -470,10 +470,10 @@ def test_get_daemon_set_not_found( assert daemon_set is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_create_daemon_set( self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -490,10 +490,10 @@ def test_create_daemon_set( assert daemon_set.metadata.name == "test-daemonset" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_create_daemon_set_no_body( self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -505,10 +505,10 @@ def test_create_daemon_set_no_body( assert daemon_set is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_create_daemon_set_exception( self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -522,10 +522,10 @@ def test_create_daemon_set_exception( assert daemon_set is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_delete_daemon_set( self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -538,10 +538,10 @@ def test_delete_daemon_set( assert response.status == "Success" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_delete_daemon_set_not_found( self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -554,10 +554,10 @@ def test_delete_daemon_set_not_found( assert response is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_delete_daemon_set_exception( self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -570,8 +570,8 @@ def test_delete_daemon_set_exception( assert response is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") @patch("os.path.isfile") def test_validate_yaml_file_not_found(self, mock_is_file, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None @@ -584,8 +584,8 @@ def test_validate_yaml_file_not_found(self, mock_is_file, mock_kubernetes_init, assert "The specified YAML file does not exist" in str(exc_info.value) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") @patch("os.path.isfile") def test_validate_yaml_file_invalid_extension(self, mock_is_file, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None @@ -598,13 +598,13 @@ def test_validate_yaml_file_invalid_extension(self, mock_is_file, mock_kubernete assert "The specified YAML file must have a .yaml or .yml extension" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.ray.RayHook.install_kuberay_operator") - @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.create_custom_object") - @patch("ray_provider.hooks.ray.RayHook._setup_gpu_driver") - @patch("ray_provider.hooks.ray.RayHook._setup_load_balancer") + @patch("ray_provider.hooks.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.RayHook.install_kuberay_operator") + @patch("ray_provider.hooks.RayHook.load_yaml_content") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.create_custom_object") + @patch("ray_provider.hooks.RayHook._setup_gpu_driver") + @patch("ray_provider.hooks.RayHook._setup_load_balancer") def test_setup_ray_cluster_success( self, mock_setup_load_balancer, @@ -639,13 +639,13 @@ def test_setup_ray_cluster_success( mock_setup_gpu_driver.assert_called_once_with(gpu_device_plugin_yaml="gpu.yaml") mock_setup_load_balancer.assert_called_once() - @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.ray.RayHook.uninstall_kuberay_operator") - @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.delete_custom_object") - @patch("ray_provider.hooks.ray.RayHook.get_daemon_set") - @patch("ray_provider.hooks.ray.RayHook.delete_daemon_set") + @patch("ray_provider.hooks.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.RayHook.uninstall_kuberay_operator") + @patch("ray_provider.hooks.RayHook.load_yaml_content") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.delete_custom_object") + @patch("ray_provider.hooks.RayHook.get_daemon_set") + @patch("ray_provider.hooks.RayHook.delete_daemon_set") def test_delete_ray_cluster_success( self, mock_delete_daemon_set, @@ -677,15 +677,15 @@ def test_delete_ray_cluster_success( mock_delete_custom_object.assert_called_once() mock_uninstall_kuberay_operator.assert_called_once() - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.JobSubmissionClient") def test_ray_client_exception(self, mock_job_client, ray_hook): mock_job_client.side_effect = Exception("Connection failed") with pytest.raises(AirflowException) as exc_info: ray_hook.ray_client() assert str(exc_info.value) == "Failed to create Ray JobSubmissionClient: Connection failed" - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.create_custom_object") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.create_custom_object") def test_create_or_update_cluster_exception(self, mock_create, mock_get, ray_hook): mock_get.side_effect = client.exceptions.ApiException(status=500, reason="Internal Server Error") with pytest.raises(AirflowException) as exc_info: @@ -700,8 +700,8 @@ def test_create_or_update_cluster_exception(self, mock_create, mock_get, ray_hoo ) assert "Error accessing Ray cluster 'test-cluster'" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.custom_object_client") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.custom_object_client") def test_create_or_update_cluster_update(self, mock_client, mock_get, ray_hook): mock_get.return_value = {"metadata": {"name": "test-cluster"}} ray_hook._create_or_update_cluster( @@ -722,12 +722,12 @@ def test_create_or_update_cluster_update(self, mock_client, mock_get, ray_hook): body={"spec": {"some": "config"}}, ) - @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.ray.RayHook.install_kuberay_operator") - @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") - @patch("ray_provider.hooks.ray.RayHook._create_or_update_cluster") - @patch("ray_provider.hooks.ray.RayHook._setup_gpu_driver") - @patch("ray_provider.hooks.ray.RayHook._setup_load_balancer") + @patch("ray_provider.hooks.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.RayHook.install_kuberay_operator") + @patch("ray_provider.hooks.RayHook.load_yaml_content") + @patch("ray_provider.hooks.RayHook._create_or_update_cluster") + @patch("ray_provider.hooks.RayHook._setup_gpu_driver") + @patch("ray_provider.hooks.RayHook._setup_load_balancer") def test_setup_ray_cluster_exception( self, mock_setup_lb, @@ -750,13 +750,13 @@ def test_setup_ray_cluster_exception( ) assert "Failed to set up Ray cluster: Cluster creation failed" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.delete_custom_object") - @patch("ray_provider.hooks.ray.RayHook.get_daemon_set") - @patch("ray_provider.hooks.ray.RayHook.delete_daemon_set") - @patch("ray_provider.hooks.ray.RayHook.uninstall_kuberay_operator") + @patch("ray_provider.hooks.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.RayHook.load_yaml_content") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.delete_custom_object") + @patch("ray_provider.hooks.RayHook.get_daemon_set") + @patch("ray_provider.hooks.RayHook.delete_daemon_set") + @patch("ray_provider.hooks.RayHook.uninstall_kuberay_operator") def test_delete_ray_cluster_exception( self, mock_uninstall_operator, diff --git a/tests/operators/test_ray_operators.py b/tests/test_operators.py similarity index 96% rename from tests/operators/test_ray_operators.py rename to tests/test_operators.py index 11d3d0c..72df6b5 100644 --- a/tests/operators/test_ray_operators.py +++ b/tests/test_operators.py @@ -5,14 +5,14 @@ from airflow.exceptions import AirflowException, TaskDeferred from ray.job_submission import JobStatus -from ray_provider.operators.ray import DeleteRayCluster, SetupRayCluster, SubmitRayJob -from ray_provider.triggers.ray import RayJobTrigger +from ray_provider.operators import DeleteRayCluster, SetupRayCluster, SubmitRayJob +from ray_provider.triggers import RayJobTrigger class TestSetupRayCluster: @pytest.fixture def mock_hook(self): - with patch("ray_provider.operators.ray.RayHook") as mock: + with patch("ray_provider.operators.RayHook") as mock: yield mock.return_value @pytest.fixture @@ -48,7 +48,7 @@ def test_init_default_values(self): assert operator.update_if_exists is False def test_hook_property(self, operator): - with patch("ray_provider.operators.ray.RayHook") as mock_ray_hook: + with patch("ray_provider.operators.RayHook") as mock_ray_hook: hook = operator.hook mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id) assert hook == mock_ray_hook.return_value @@ -68,7 +68,7 @@ def test_execute(self, operator, mock_hook): class TestDeleteRayCluster: @pytest.fixture def mock_hook(self): - with patch("ray_provider.operators.ray.RayHook") as mock: + with patch("ray_provider.operators.RayHook") as mock: yield mock.return_value @pytest.fixture @@ -98,7 +98,7 @@ def test_init_default_gpu_plugin(self): ) def test_hook_property(self, operator): - with patch("ray_provider.operators.ray.RayHook") as mock_ray_hook: + with patch("ray_provider.operators.RayHook") as mock_ray_hook: hook = operator.hook mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id) assert hook == mock_ray_hook.return_value @@ -113,7 +113,7 @@ class TestSubmitRayJob: @pytest.fixture def mock_hook(self): - with patch("ray_provider.operators.ray.RayHook") as mock: + with patch("ray_provider.operators.RayHook") as mock: yield mock.return_value @pytest.fixture @@ -228,7 +228,7 @@ def test_get_dashboard_url_without_xcom(self, context): assert result is None - @patch("ray_provider.operators.ray.RayHook") + @patch("ray_provider.operators.RayHook") def test_setup_cluster(self, mock_ray_hook, context): operator = SubmitRayJob( task_id="test_task", @@ -254,7 +254,7 @@ def test_setup_cluster(self, mock_ray_hook, context): update_if_exists=True, ) - @patch("ray_provider.operators.ray.RayHook") + @patch("ray_provider.operators.RayHook") def test_delete_cluster(self, mock_ray_hook): operator = SubmitRayJob( task_id="test_task", @@ -371,7 +371,7 @@ def test_template_fields(self): "job_timeout_seconds", ) - @patch("ray_provider.operators.ray.RayHook") + @patch("ray_provider.operators.RayHook") def test_setup_cluster_exception(self, mock_ray_hook, context): operator = SubmitRayJob( task_id="test_task", @@ -392,7 +392,7 @@ def test_setup_cluster_exception(self, mock_ray_hook, context): assert str(exc_info.value) == "Cluster setup failed" mock_hook.setup_ray_cluster.assert_called_once() - @patch("ray_provider.operators.ray.RayHook") + @patch("ray_provider.operators.RayHook") def test_delete_cluster_exception(self, mock_ray_hook): operator = SubmitRayJob( task_id="test_task", diff --git a/tests/triggers/test_ray_triggers.py b/tests/test_triggers.py similarity index 85% rename from tests/triggers/test_ray_triggers.py rename to tests/test_triggers.py index f82b521..f97611e 100644 --- a/tests/triggers/test_ray_triggers.py +++ b/tests/test_triggers.py @@ -5,7 +5,7 @@ from airflow.triggers.base import TriggerEvent from ray.job_submission import JobStatus -from ray_provider.triggers.ray import RayJobTrigger +from ray_provider.triggers import RayJobTrigger class TestRayJobTrigger: @@ -22,8 +22,8 @@ def trigger(self): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_no_job_id(self, mock_hook, mock_is_terminal): mock_is_terminal.return_value = True mock_hook.get_ray_job_status.return_value = JobStatus.FAILED @@ -42,8 +42,8 @@ async def test_run_no_job_id(self, mock_hook, mock_is_terminal): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED @@ -66,8 +66,8 @@ async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.STOPPED @@ -84,8 +84,8 @@ async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.FAILED @@ -102,9 +102,9 @@ async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") - @patch("ray_provider.triggers.ray.RayJobTrigger._stream_logs") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._stream_logs") async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is_terminal, trigger): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED @@ -123,7 +123,7 @@ async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_stream_logs(self, mock_hook, trigger): # Create a mock async iterator async def mock_async_iterator(): @@ -133,7 +133,7 @@ async def mock_async_iterator(): # Set up the mock to return an async iterator mock_hook.get_ray_tail_logs.return_value = mock_async_iterator() - with patch("ray_provider.triggers.ray.RayJobTrigger.log") as mock_log: + with patch("ray_provider.triggers.RayJobTrigger.log") as mock_log: await trigger._stream_logs() mock_log.info.assert_any_call("::group::test_job_id logs") @@ -144,7 +144,7 @@ async def mock_async_iterator(): def test_serialize(self, trigger): serialized = trigger.serialize() assert serialized == ( - "ray_provider.triggers.ray.RayJobTrigger", + "ray_provider.triggers.RayJobTrigger", { "job_id": "test_job_id", "conn_id": "test_conn", @@ -157,7 +157,7 @@ def test_serialize(self, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_is_terminal_state(self, mock_hook, trigger): mock_hook.get_ray_job_status.side_effect = [ JobStatus.PENDING, @@ -212,7 +212,7 @@ async def test_cleanup_with_exception(self, mock_log_error, mock_hook, trigger): @pytest.mark.asyncio @patch("asyncio.sleep", new_callable=AsyncMock) - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger): mock_is_terminal.side_effect = [False, False, True] @@ -222,9 +222,9 @@ async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger): mock_sleep.assert_called_with(1) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") - @patch("ray_provider.triggers.ray.RayJobTrigger.cleanup") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger.cleanup") async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_is_terminal, trigger): mock_is_terminal.side_effect = Exception("Test exception") diff --git a/tests/triggers/__init__.py b/tests/triggers/__init__.py deleted file mode 100644 index e69de29..0000000