Skip to content

Commit

Permalink
Support running Ray jobs indefinitely without timing out (#74)
Browse files Browse the repository at this point in the history
Customers have requested this feature. It should be used at their own risk.
---------

Co-authored-by: Tatiana Al-Chueyr <[email protected]>
  • Loading branch information
venkatajagannath and tatiana authored Sep 27, 2024
1 parent 6d5a191 commit a900d43
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 8 deletions.
4 changes: 2 additions & 2 deletions ray_provider/decorators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key", None)
self.config = config
Expand All @@ -74,7 +74,7 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
fetch_logs=self.fetch_logs,
wait_for_completion=self.wait_for_completion,
job_timeout_seconds=self.job_timeout_seconds,
job_timeout_seconds=job_timeout_seconds,
poll_interval=self.poll_interval,
xcom_task_key=self.xcom_task_key,
**kwargs,
Expand Down
6 changes: 3 additions & 3 deletions ray_provider/operators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class SubmitRayJob(BaseOperator):
:param gpu_device_plugin_yaml: URL or path to the GPU device plugin YAML file. Defaults to NVIDIA's plugin.
:param fetch_logs: Whether to fetch logs from the Ray job. Defaults to True.
:param wait_for_completion: Whether to wait for the job to complete before marking the task as finished. Defaults to True.
:param job_timeout_seconds: Maximum time to wait for job completion in seconds. Defaults to 600 seconds.
:param job_timeout_seconds: Maximum time to wait for job completion in seconds. Defaults to 600 seconds. Set to 0 if you want the job to run indefinitely without timeouts.
:param poll_interval: Interval between job status checks in seconds. Defaults to 60 seconds.
:param xcom_task_key: XCom key to retrieve the dashboard URL. Defaults to None.
"""
Expand Down Expand Up @@ -168,7 +168,7 @@ def __init__(
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.fetch_logs = fetch_logs
self.wait_for_completion = wait_for_completion
self.job_timeout_seconds = job_timeout_seconds
self.job_timeout_seconds = timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None
self.poll_interval = poll_interval
self.xcom_task_key = xcom_task_key
self.dashboard_url: str | None = None
Expand Down Expand Up @@ -294,7 +294,7 @@ def execute(self, context: Context) -> str:
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
timeout=self.job_timeout_seconds,
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
Expand Down
5 changes: 3 additions & 2 deletions tests/decorators/test_ray_decorators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -38,7 +39,7 @@ def dummy_callable():
assert operator.ray_resources == {"custom_resource": 1}
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 300
assert operator.job_timeout_seconds == timedelta(seconds=300)
assert operator.poll_interval == 30
assert operator.xcom_task_key == "ray_result"

Expand All @@ -59,7 +60,7 @@ def dummy_callable():
assert operator.ray_resources is None
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 600
assert operator.job_timeout_seconds == timedelta(seconds=600)
assert operator.poll_interval == 60
assert operator.xcom_task_key is None

Expand Down
25 changes: 24 additions & 1 deletion tests/operators/test_ray_operators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from unittest.mock import MagicMock, Mock, patch

import pytest
Expand Down Expand Up @@ -161,10 +162,32 @@ def test_init(self):
assert operator.gpu_device_plugin_yaml == "https://example.com/plugin.yml"
assert operator.fetch_logs == True
assert operator.wait_for_completion == True
assert operator.job_timeout_seconds == 1200
assert operator.job_timeout_seconds == timedelta(seconds=1200)
assert operator.poll_interval == 30
assert operator.xcom_task_key == "task.key"

def test_init_no_timeout(self):
    """Check that a ``job_timeout_seconds`` of 0 leaves the operator's timeout as ``None``."""
    # Collect the constructor arguments first so the value under test
    # (the zero timeout) is easy to spot among the other settings.
    init_kwargs = {
        "task_id": "test_task",
        "conn_id": "test_conn",
        "entrypoint": "python script.py",
        "runtime_env": {"pip": ["package1", "package2"]},
        "num_cpus": 2,
        "num_gpus": 1,
        "memory": 1000,
        "resources": {"custom_resource": 1},
        "ray_cluster_yaml": "cluster.yaml",
        "kuberay_version": "1.0.0",
        "update_if_exists": True,
        "gpu_device_plugin_yaml": "https://example.com/plugin.yml",
        "fetch_logs": True,
        "wait_for_completion": True,
        "job_timeout_seconds": 0,
        "poll_interval": 30,
        "xcom_task_key": "task.key",
    }

    operator = SubmitRayJob(**init_kwargs)

    assert operator.job_timeout_seconds is None

def test_on_kill(self, mock_hook):
operator = SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={})
operator.job_id = "test_job_id"
Expand Down

0 comments on commit a900d43

Please sign in to comment.