diff --git a/dev/tests/dags/test_dag_example.py b/dev/tests/dags/test_dag_example.py
index 689f1c5..9345ca7 100644
--- a/dev/tests/dags/test_dag_example.py
+++ b/dev/tests/dags/test_dag_example.py
@@ -53,19 +53,6 @@ def test_file_imports(rel_path, rv):
         raise Exception(f"{rel_path} failed to import with message \n {rv}")
 
 
-APPROVED_TAGS = {}
-
-
-@pytest.mark.parametrize("dag_id,dag,fileloc", get_dags(), ids=[x[2] for x in get_dags()])
-def test_dag_tags(dag_id, dag, fileloc):
-    """
-    test if a DAG is tagged and if those TAGs are in the approved list
-    """
-    assert dag.tags, f"{dag_id} in {fileloc} has no tags"
-    if APPROVED_TAGS:
-        assert not set(dag.tags) - APPROVED_TAGS
-
-
 @pytest.mark.parametrize("dag_id,dag, fileloc", get_dags(), ids=[x[2] for x in get_dags()])
 def test_dag_retries(dag_id, dag, fileloc):
     """
diff --git a/ray_provider/decorators.py b/ray_provider/decorators.py
index ed2f799..45079f4 100644
--- a/ray_provider/decorators.py
+++ b/ray_provider/decorators.py
@@ -62,10 +62,7 @@ def _load_config(self, config: dict[str, Any]) -> None:
         self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml")
         self.update_if_exists: bool = config.get("update_if_exists", False)
         self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
-        self.gpu_device_plugin_yaml: str = config.get(
-            "gpu_device_plugin_yaml",
-            "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
-        )
+        self.gpu_device_plugin_yaml: str = config.get("gpu_device_plugin_yaml", "")
         self.fetch_logs: bool = config.get("fetch_logs", True)
         self.wait_for_completion: bool = config.get("wait_for_completion", True)
         job_timeout_seconds = config.get("job_timeout_seconds", 600)
diff --git a/ray_provider/hooks.py b/ray_provider/hooks.py
index 38c9132..bf3868f 100644
--- a/ray_provider/hooks.py
+++ b/ray_provider/hooks.py
@@ -416,16 +416,17 @@ def _create_or_update_cluster(
 
     def _setup_gpu_driver(self, gpu_device_plugin_yaml: str) -> None:
         """
-        Set up the GPU device plugin if GPU is enabled. Defaults to NVIDIA's plugin
+        Set up the GPU device plugin if GPU is enabled.
 
         :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML.
         """
-        gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml)
-        gpu_driver_name = gpu_driver["metadata"]["name"]
+        if gpu_device_plugin_yaml:
+            gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml)
+            gpu_driver_name = gpu_driver["metadata"]["name"]
 
-        if not self.get_daemon_set(gpu_driver_name):
-            self.log.info("Creating DaemonSet for NVIDIA device plugin...")
-            self.create_daemon_set(gpu_driver_name, gpu_driver)
+            if not self.get_daemon_set(gpu_driver_name):
+                self.log.info("Creating DaemonSet for GPU driver...")
+                self.create_daemon_set(gpu_driver_name, gpu_driver)
 
     def _setup_load_balancer(self, name: str, namespace: str, context: Context) -> None:
         """
@@ -460,7 +461,7 @@ def setup_ray_cluster(
         :param context: The Airflow task context.
         :param ray_cluster_yaml: Path to the YAML file defining the Ray cluster.
         :param kuberay_version: Version of KubeRay to install.
-        :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML. Defaults to NVIDIA's plugin
+        :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML.
         :param update_if_exists: Whether to update the cluster if it already exists.
         :raises AirflowException: If there's an error setting up the Ray cluster.
         """
@@ -535,18 +536,20 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str)
         Execute the operator to delete the Ray cluster.
 
         :param ray_cluster_yaml: Path to the YAML file defining the Ray cluster.
-        :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML. Defaults to NVIDIA's plugin
+        :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML.
         :raises AirflowException: If there's an error deleting the Ray cluster.
         """
         try:
             self._validate_yaml_file(ray_cluster_yaml)
 
-            """Delete the NVIDIA GPU device plugin DaemonSet if it exists."""
-            gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml)
-            gpu_driver_name = gpu_driver["metadata"]["name"]
+            # Delete the GPU device plugin DaemonSet only when a plugin YAML was given.
+            # Never return early here: the Ray cluster below must still be deleted.
+            if gpu_device_plugin_yaml:
+                gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml)
+                gpu_driver_name = gpu_driver["metadata"]["name"]
 
-            if self.get_daemon_set(gpu_driver_name):
-                self.log.info("Deleting DaemonSet for NVIDIA device plugin...")
-                self.delete_daemon_set(gpu_driver_name)
+                if self.get_daemon_set(gpu_driver_name):
+                    self.log.info("Deleting DaemonSet for the GPU device plugin...")
+                    self.delete_daemon_set(gpu_driver_name)
 
             self.log.info("::group:: Delete Ray Cluster")
diff --git a/ray_provider/operators.py b/ray_provider/operators.py
index 3a0776e..04148f1 100644
--- a/ray_provider/operators.py
+++ b/ray_provider/operators.py
@@ -23,7 +23,7 @@ class SetupRayCluster(BaseOperator):
     :param conn_id: The connection ID for the Ray cluster.
     :param ray_cluster_yaml: Path to the YAML file defining the Ray cluster.
     :param kuberay_version: Version of KubeRay to install. Defaults to "1.0.0".
-    :param gpu_device_plugin_yaml: URL or path to the GPU device plugin YAML. Defaults to NVIDIA's plugin.
+    :param gpu_device_plugin_yaml: URL or path to the GPU device plugin YAML. Example value: https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml
     :param update_if_exists: Whether to update the cluster if it already exists. Defaults to False.
     """
 
@@ -32,7 +32,7 @@ def __init__(
         conn_id: str,
         ray_cluster_yaml: str,
         kuberay_version: str = "1.0.0",
-        gpu_device_plugin_yaml: str = "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
+        gpu_device_plugin_yaml: str = "",
         update_if_exists: bool = False,
         **kwargs: Any,
     ) -> None:
@@ -71,14 +71,14 @@ class DeleteRayCluster(BaseOperator):
 
     :param conn_id: The connection ID for the Ray cluster.
     :param ray_cluster_yaml: Path to the YAML file defining the Ray cluster.
-    :param gpu_device_plugin_yaml: URL or path to the GPU device plugin YAML. Defaults to NVIDIA's plugin.
+    :param gpu_device_plugin_yaml: URL or path to the GPU device plugin YAML. Example value: https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml
     """
 
     def __init__(
         self,
         conn_id: str,
         ray_cluster_yaml: str,
-        gpu_device_plugin_yaml: str = "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
+        gpu_device_plugin_yaml: str = "",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -119,7 +119,7 @@ class SubmitRayJob(BaseOperator):
     :param ray_cluster_yaml: Path to the Ray cluster YAML configuration file. If provided, the operator will set up and tear down the cluster.
     :param kuberay_version: Version of KubeRay to use when setting up the Ray cluster. Defaults to "1.0.0".
     :param update_if_exists: Whether to update the Ray cluster if it already exists. Defaults to True.
-    :param gpu_device_plugin_yaml: URL or path to the GPU device plugin YAML file. Defaults to NVIDIA's plugin.
+    :param gpu_device_plugin_yaml: URL or path to the GPU device plugin YAML file. Example value: https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml
     :param fetch_logs: Whether to fetch logs from the Ray job. Defaults to True.
     :param wait_for_completion: Whether to wait for the job to complete before marking the task as finished. Defaults to True.
     :param job_timeout_seconds: Maximum time to wait for job completion in seconds. Defaults to 600 seconds. Set to 0 if you want the job to run indefinitely without timeouts.
@@ -152,7 +152,7 @@ def __init__(
         ray_cluster_yaml: str | None = None,
         kuberay_version: str = "1.0.0",
         update_if_exists: bool = True,
-        gpu_device_plugin_yaml: str = "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
+        gpu_device_plugin_yaml: str = "",
         fetch_logs: bool = True,
         wait_for_completion: bool = True,
         job_timeout_seconds: int = 600,
diff --git a/tests/test_hooks.py b/tests/test_hooks.py
index cfa33b7..61c01bf 100644
--- a/tests/test_hooks.py
+++ b/tests/test_hooks.py
@@ -772,3 +772,36 @@ def test_delete_ray_cluster_exception(
         with pytest.raises(AirflowException) as exc_info:
             ray_hook.delete_ray_cluster(ray_cluster_yaml="test.yaml", gpu_device_plugin_yaml="gpu.yaml")
         assert "Failed to delete Ray cluster: Cluster deletion failed" in str(exc_info.value)
+
+    @patch("ray_provider.hooks.RayHook.create_daemon_set")
+    @patch("ray_provider.hooks.RayHook.get_daemon_set", return_value=True)
+    def test_setup_ray_cluster_with_config_existing_daemon(self, mock_get_daemon_set, mock_create_daemon_set, ray_hook):
+        gpu_device_plugin_yaml = (
+            "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml"
+        )
+        response = ray_hook._setup_gpu_driver(gpu_device_plugin_yaml)
+        assert response is None
+        mock_get_daemon_set.assert_called_once()
+        mock_create_daemon_set.assert_not_called()
+
+    @patch("ray_provider.hooks.RayHook.create_daemon_set")
+    @patch("ray_provider.hooks.RayHook.get_daemon_set", return_value=False)
+    def test_setup_ray_cluster_with_config_inexistent_daemon(
+        self, mock_get_daemon_set, mock_create_daemon_set, ray_hook
+    ):
+        gpu_device_plugin_yaml = (
+            "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml"
+        )
+        response = ray_hook._setup_gpu_driver(gpu_device_plugin_yaml)
+        assert response is None
+        mock_get_daemon_set.assert_called_once()
+        mock_create_daemon_set.assert_called_once()
+
+    @patch("ray_provider.hooks.RayHook.create_daemon_set")
+    @patch("ray_provider.hooks.RayHook.get_daemon_set")
+    def test_setup_ray_cluster_no_config(self, mock_get_daemon_set, mock_create_daemon_set, ray_hook):
+        gpu_device_plugin_yaml = ""
+        response = ray_hook._setup_gpu_driver(gpu_device_plugin_yaml)
+        assert response is None
+        mock_get_daemon_set.assert_not_called()
+        mock_create_daemon_set.assert_not_called()
diff --git a/tests/test_operators.py b/tests/test_operators.py
index a351f3b..faa89e3 100644
--- a/tests/test_operators.py
+++ b/tests/test_operators.py
@@ -35,10 +35,7 @@ def test_init_default_values(self):
             ray_cluster_yaml="cluster.yaml",
         )
         assert operator.kuberay_version == "1.0.0"
-        assert (
-            operator.gpu_device_plugin_yaml
-            == "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml"
-        )
+        assert not operator.gpu_device_plugin_yaml
         assert operator.update_if_exists is False
 
     @patch("ray_provider.operators.RayHook")
@@ -87,10 +84,7 @@ def test_init_default_gpu_plugin(self):
             conn_id="test_conn",
             ray_cluster_yaml="cluster.yaml",
         )
-        assert (
-            operator.gpu_device_plugin_yaml
-            == "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml"
-        )
+        assert not operator.gpu_device_plugin_yaml
 
     @patch("ray_provider.operators.RayHook")
     def test_hook_property(self, mock_ray_hook):