diff --git a/dev/dags/ray_dynamic_config.py b/dev/dags/ray_dynamic_config.py new file mode 100644 index 0000000..046b60b --- /dev/null +++ b/dev/dags/ray_dynamic_config.py @@ -0,0 +1,197 @@ +""" +This example illustrates three DAGs: one parent DAG and two downstream (child) DAGs. + +The parent DAG (ray_dynamic_config_parent) uses TriggerDagRunOperator to trigger the other two: +* ray_dynamic_config_child_1 +* ray_dynamic_config_child_2 + +Each downstream DAG retrieves the context data (raycluster_name and raycluster_k8s_yml_filename, when provided) from dag_run.conf, which is passed by the parent DAG. + +The print_context tasks in the downstream DAGs output the received context to the logs. +""" + +import re +from pathlib import Path + +import yaml +from airflow import DAG +from airflow.decorators import task +from airflow.operators.empty import EmptyOperator +from airflow.operators.python import PythonOperator +from airflow.operators.trigger_dagrun import TriggerDagRunOperator +from airflow.utils.dates import days_ago +from jinja2 import Template + +from ray_provider.decorators import ray + +CONN_ID = "ray_conn" +RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" +FOLDER_PATH = Path(__file__).parent / "ray_scripts" +RAY_TASK_CONFIG = { + "conn_id": CONN_ID, + "runtime_env": {"working_dir": str(FOLDER_PATH), "pip": ["numpy"]}, + "num_cpus": 1, + "num_gpus": 0, + "memory": 0, + "poll_interval": 5, + "ray_cluster_yaml": str(RAY_SPEC), + "xcom_task_key": "dashboard", +} + + +def slugify(value): + """ + Replace invalid characters with hyphens and make lowercase. 
+ """ + return re.sub(r"[^\w\-\.]", "-", value).lower() + + +def create_config_from_context(context, **kwargs): + default_name = "{{ dag.dag_id }}-{{ dag_run.id }}" + + raycluster_name_template = context.get("dag_run").conf.get("raycluster_name", default_name) + raycluster_name = Template(raycluster_name_template).render(context).replace("_", "-") + raycluster_name = slugify(raycluster_name) + + raycluster_k8s_yml_filename_template = context.get("dag_run").conf.get( + "raycluster_k8s_yml_filename", default_name + ".yml" + ) + raycluster_k8s_yml_filename = Template(raycluster_k8s_yml_filename_template).render(context).replace("_", "-") + raycluster_k8s_yml_filename = slugify(raycluster_k8s_yml_filename) + + with open(RAY_SPEC) as file: + data = yaml.safe_load(file) + data["metadata"]["name"] = raycluster_name + + NEW_RAY_K8S_SPEC = Path(__file__).parent / "scripts" / raycluster_k8s_yml_filename + with open(NEW_RAY_K8S_SPEC, "w") as file: + yaml.safe_dump(data, file, default_flow_style=False) + + config = dict(RAY_TASK_CONFIG) + config["ray_cluster_yaml"] = str(NEW_RAY_K8S_SPEC) + return config + + +def print_context(**kwargs): + # Retrieve `conf` passed from the parent DAG + print(kwargs) + cluster_name = kwargs.get("dag_run").conf.get("raycluster_name", "No ray cluster name provided") + raycluster_k8s_yml_filename = kwargs.get("dag_run").conf.get( + "raycluster_k8s_yml_filename", "No ray cluster YML filename provided" + ) + print(f"Received cluster name: {cluster_name}") + print(f"Received cluster K8s YML filename: {raycluster_k8s_yml_filename}") + + +# Downstream 1 +with DAG( + dag_id="ray_dynamic_config_child_1", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + + print_context_task = PythonOperator( + task_id="print_context", + python_callable=print_context, + ) + print_context_task + + @task + def generate_data(): + return [1, 2, 3] + + @ray.task(config=create_config_from_context) + def process_data_with_ray(data): + import numpy 
as np + import ray + + @ray.remote + def cubic(x): + return x**3 + + ray.init() + data = np.array(data) + futures = [cubic.remote(x) for x in data] + results = ray.get(futures) + mean = np.mean(results) + print(f"Mean of this population is {mean}") + return mean + + data = generate_data() + process_data_with_ray(data) + + +# Downstream 2 +with DAG( + dag_id="ray_dynamic_config_child_2", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + + print_context_task = PythonOperator( + task_id="print_context", + python_callable=print_context, + ) + + @task + def generate_data(): + return [1, 2, 3] + + @ray.task(config=create_config_from_context) + def process_data_with_ray(data): + import numpy as np + import ray + + @ray.remote + def square(x): + return x**2 + + ray.init() + data = np.array(data) + futures = [square.remote(x) for x in data] + results = ray.get(futures) + mean = np.mean(results) + print(f"Mean of this population is {mean}") + return mean + + data = generate_data() + process_data_with_ray(data) + + +# Upstream +with DAG( + dag_id="ray_dynamic_config_parent", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + empty_task = EmptyOperator(task_id="empty_task") + + trigger_dag_1 = TriggerDagRunOperator( + task_id="trigger_downstream_dag_1", + trigger_dag_id="ray_dynamic_config_child_1", + conf={ + "raycluster_name": "first-{{ dag_run.id }}", + "raycluster_k8s_yml_filename": "first-{{ dag_run.id }}.yaml", + }, + ) + + trigger_dag_2 = TriggerDagRunOperator( + task_id="trigger_downstream_dag_2", + trigger_dag_id="ray_dynamic_config_child_2", + conf={}, + ) + + # Illustrates that by default two DAG runs of the same DAG will be using different Ray clusters + # Disabled because in the local dev MacOS we're only managing to spin up two Ray Cluster services concurrently + # trigger_dag_3 = TriggerDagRunOperator( + # task_id="trigger_downstream_dag_3", + # trigger_dag_id="ray_dynamic_config_child_2", + # 
conf={}, + # ) + + empty_task >> trigger_dag_1 + trigger_dag_1 >> trigger_dag_2 + # trigger_dag_1 >> trigger_dag_3 diff --git a/ray_provider/decorators.py b/ray_provider/decorators.py index bcb150d..ed2f799 100644 --- a/ray_provider/decorators.py +++ b/ray_provider/decorators.py @@ -3,15 +3,16 @@ import inspect import os import re -import shutil +import tempfile import textwrap -from tempfile import mkdtemp +from datetime import timedelta +from pathlib import Path from typing import Any, Callable from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory -from airflow.exceptions import AirflowException from airflow.utils.context import Context +from ray_provider.exceptions import RayAirflowException from ray_provider.operators import SubmitRayJob @@ -28,10 +29,27 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob): """ custom_operator_name = "@task.ray" + _config: dict[str, Any] | Callable[..., dict[str, Any]] = dict() template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs") - def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: + def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None: + self._config = config + self.kwargs = kwargs + super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs) + + def _build_config(self, context: Context) -> dict[str, Any]: + if callable(self._config): + config_params = inspect.signature(self._config).parameters + config_kwargs = {k: v for k, v in self.kwargs.items() if k in config_params and k != "context"} + if "context" in config_params: + config_kwargs["context"] = context + config = self._config(**config_kwargs) + assert isinstance(config, dict) + return config + return self._config + + def _load_config(self, config: dict[str, Any]) -> None: self.conn_id: str = config.get("conn_id", "") self.is_decorated_function = False if "entrypoint" in config else True self.entrypoint: str = 
config.get("entrypoint", "python script.py") @@ -39,9 +57,9 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: self.num_cpus: int | float = config.get("num_cpus", 1) self.num_gpus: int | float = config.get("num_gpus", 0) - self.memory: int | float = config.get("memory", None) - self.ray_resources: dict[str, Any] | None = config.get("resources", None) - self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None) + self.memory: int | float = config.get("memory", 1) + self.ray_resources: dict[str, Any] | None = config.get("resources") + self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml") self.update_if_exists: bool = config.get("update_if_exists", False) self.kuberay_version: str = config.get("kuberay_version", "1.0.0") self.gpu_device_plugin_yaml: str = config.get( @@ -50,35 +68,19 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: ) self.fetch_logs: bool = config.get("fetch_logs", True) self.wait_for_completion: bool = config.get("wait_for_completion", True) - job_timeout_seconds: int = config.get("job_timeout_seconds", 600) + job_timeout_seconds = config.get("job_timeout_seconds", 600) + self.job_timeout_seconds: timedelta | None = ( + timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None + ) self.poll_interval: int = config.get("poll_interval", 60) - self.xcom_task_key: str | None = config.get("xcom_task_key", None) + self.xcom_task_key: str | None = config.get("xcom_task_key") + self.config = config if not isinstance(self.num_cpus, (int, float)): - raise TypeError("num_cpus should be an integer or float value") + raise RayAirflowException("num_cpus should be an integer or float value") if not isinstance(self.num_gpus, (int, float)): - raise TypeError("num_gpus should be an integer or float value") - - super().__init__( - conn_id=self.conn_id, - entrypoint=self.entrypoint, - runtime_env=self.runtime_env, - num_cpus=self.num_cpus, - num_gpus=self.num_gpus, - memory=self.memory, 
- resources=self.ray_resources, - ray_cluster_yaml=self.ray_cluster_yaml, - update_if_exists=self.update_if_exists, - kuberay_version=self.kuberay_version, - gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, - fetch_logs=self.fetch_logs, - wait_for_completion=self.wait_for_completion, - job_timeout_seconds=job_timeout_seconds, - poll_interval=self.poll_interval, - xcom_task_key=self.xcom_task_key, - **kwargs, - ) + raise RayAirflowException("num_gpus should be an integer or float value") def execute(self, context: Context) -> Any: """ @@ -88,21 +90,21 @@ def execute(self, context: Context) -> Any: :return: The result of the Ray job execution. :raises AirflowException: If job submission fails. """ - tmp_dir = None - try: + config = self._build_config(context) + self.log.info(f"Using the following config {config}") + self._load_config(config) + + with tempfile.TemporaryDirectory(prefix="ray_") as tmpdirname: + temp_dir = Path(tmpdirname) + if self.is_decorated_function: self.log.info( f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}" ) - # Create a temporary directory that won't be immediately deleted - temp_dir = mkdtemp(prefix="ray_") - script_filename = os.path.join(temp_dir, "script.py") # Get the Python source code and extract just the function body full_source = inspect.getsource(self.python_callable) function_body = self._extract_function_body(full_source) - if not function_body: - raise ValueError("Failed to retrieve Python source code") # Prepare the function call args_str = ", ".join(repr(arg) for arg in self.op_args) @@ -110,6 +112,7 @@ def execute(self, context: Context) -> Any: call_str = f"{self.python_callable.__name__}({args_str}, {kwargs_str})" # Write the script with function definition and call + script_filename = os.path.join(temp_dir, "script.py") with open(script_filename, "w") as file: file.write(function_body) file.write(f"\n\n# Execute the function\n{call_str}\n") @@ -122,21 +125,27 @@ def 
execute(self, context: Context) -> Any: result = super().execute(context) return result - except Exception as e: - self.log.error(f"Failed during execution with error: {e}") - raise AirflowException("Job submission failed") from e - finally: - if tmp_dir and os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) def _extract_function_body(self, source: str) -> str: """Extract the function, excluding only the ray.task decorator.""" + self.log.info(r"Ray pipeline intended to be executed: \n %s", source) + if "@ray.task" not in source: + raise RayAirflowException("Unable to parse this body. Expects the `@ray.task` decorator.") lines = source.split("\n") + # TODO: Review the current approach, that is quite hacky. + # It feels a mistake to have a user-facing module named the same as the official ray SDK. + # In particular, the decorator is working in a very artificial way, where ray means two different things + # at the scope of the task definition (Astro Ray Provider decorator) and inside the decorated method (Ray SDK) # Find the line where the ray.task decorator is + # Additionally, if users imported the ray decorator as "from ray_provider.decorators import ray as ray_decorator + # The following will stop working. ray_task_line = next((i for i, line in enumerate(lines) if re.match(r"^\s*@ray\.task", line.strip())), -1) # Include everything except the ray.task decorator line body = "\n".join(lines[:ray_task_line] + lines[ray_task_line + 1 :]) + + if not body: + raise RayAirflowException("Failed to extract Ray pipeline code decorated with @ray.task") # Dedent the body return textwrap.dedent(body) @@ -146,6 +155,7 @@ class ray: def task( python_callable: Callable[..., Any] | None = None, multiple_outputs: bool | None = None, + config: dict[str, Any] | Callable[[], dict[str, Any]] | None = None, **kwargs: Any, ) -> TaskDecorator: """ @@ -153,12 +163,15 @@ def task( :param python_callable: The callable function to decorate. 
:param multiple_outputs: If True, will return multiple outputs. + :param config: A dictionary of configuration or a callable that returns a dictionary. :param kwargs: Additional keyword arguments. :return: The decorated task. """ + config = config or {} return task_decorator_factory( python_callable=python_callable, multiple_outputs=multiple_outputs, decorated_operator_class=_RayDecoratedOperator, + config=config, **kwargs, ) diff --git a/tests/test_decorators.py b/tests/test_decorators.py index a6e6b15..70808a3 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest -from airflow.exceptions import AirflowException from airflow.utils.context import Context from ray_provider.decorators import _RayDecoratedOperator, ray +from ray_provider.exceptions import RayAirflowException class TestRayDecoratedOperator: @@ -29,6 +29,7 @@ def dummy_callable(): pass operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + operator._load_config(config) assert operator.conn_id == "ray_default" assert operator.entrypoint == "python my_script.py" @@ -50,13 +51,14 @@ def dummy_callable(): pass operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + operator._load_config(config) assert operator.conn_id == "" assert operator.entrypoint == "python script.py" assert operator.runtime_env == {} assert operator.num_cpus == 1 assert operator.num_gpus == 0 - assert operator.memory is None + assert operator.memory == 1 assert operator.ray_resources is None assert operator.fetch_logs == True assert operator.wait_for_completion == True @@ -64,6 +66,18 @@ def dummy_callable(): assert operator.poll_interval == 60 assert operator.xcom_task_key is None + def test_callable_config(self): + def dummy_callable(): + pass + + callable_config = lambda context: {"ray_cluster_yaml": "different.yml"} + + operator = 
_RayDecoratedOperator(task_id="test_task", config=callable_config, python_callable=dummy_callable) + new_config = operator._build_config(context={}) + operator._load_config(new_config) + + assert operator.ray_cluster_yaml == "different.yml" + def test_invalid_config_raises_exception(self): config = { "num_cpus": "invalid_number", @@ -72,13 +86,16 @@ def test_invalid_config_raises_exception(self): def dummy_callable(): pass - with pytest.raises(TypeError): - _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + with pytest.raises(RayAirflowException): + operator._load_config(config) config["num_cpus"] = 1 config["num_gpus"] = "invalid_number" - with pytest.raises(TypeError): - _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + with pytest.raises(RayAirflowException): + operator._load_config(config) @patch.object(_RayDecoratedOperator, "_extract_function_body") @patch("ray_provider.decorators.SubmitRayJob.execute") @@ -130,7 +147,7 @@ def dummy_callable(): operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) mock_super_execute.side_effect = Exception("Ray job failed") - with pytest.raises(AirflowException): + with pytest.raises(Exception): operator.execute(context) def test_extract_function_body(self): @@ -155,6 +172,37 @@ def dummy_callable(): """ ) + def test_extract_function_body_invalid_body(self): + config = {} + + @ray.task() + def dummy_callable(): + return "dummy" + + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + with pytest.raises(RayAirflowException) as exc_info: + operator._extract_function_body( + """@ray_decorator.task() + def dummy_callable(): + return "dummy" + """ + ) + assert str(exc_info.value) == "Unable to parse this body. Expects the `@ray.task` decorator." 
+ + def test_extract_function_body_empty_body(self): + config = {} + + @ray.task() + def dummy_callable(): + return "dummy" + + operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) + + with pytest.raises(RayAirflowException) as exc_info: + operator._extract_function_body("""@ray.task()""") + assert str(exc_info.value) == "Failed to extract Ray pipeline code decorated with @ray.task" + class TestRayTaskDecorator: def test_ray_task_decorator(self):