Stop catching generic Exception in operators (#100)
By catching `Exception`, we risk hitting an unexpected exception that the program can't recover from or, worse, swallowing an important exception without properly logging it, which is a massive headache when trying to debug programs that are failing in weird ways.

If this change surfaces any exceptions that should be caught, we'll have the opportunity to identify which exceptions to capture and handle them gracefully.

This was identified during the development of #81.
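
To illustrate the pattern (a simplified sketch of the approach taken in `_setup_cluster` below, not a verbatim copy of the diff): catch only the specific exception we know how to react to, clean up, and re-raise; let everything else propagate with its original traceback.

```python
from kubernetes.client.exceptions import ApiException


def _setup_cluster(self, context) -> None:
    # Sketch only: `self.hook` and `self._delete_cluster` are the operator's own members.
    if not self.ray_cluster_yaml:
        self.log.info("Skipping setting up a Ray cluster because no `ray_cluster_yaml` was given.")
        return
    try:
        self.hook.setup_ray_cluster(
            context=context,
            ray_cluster_yaml=self.ray_cluster_yaml,
            kuberay_version=self.kuberay_version,
            gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
            update_if_exists=self.update_if_exists,
        )
    except ApiException:
        # The one failure we can meaningfully handle: log it with the traceback,
        # tear down anything partially created, then re-raise instead of swallowing it.
        self.log.error("Unable to set up the Ray cluster", exc_info=True)
        self._delete_cluster()
        raise
    # Any other exception type propagates unchanged, so nothing important is hidden.
```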
tatiana authored Nov 29, 2024
1 parent 6e69368 commit fadce49
Showing 3 changed files with 244 additions and 225 deletions.
2 changes: 2 additions & 0 deletions ray_provider/exceptions.py
@@ -0,0 +1,2 @@
class RayAirflowException(Exception):
    pass
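
For context, a minimal sketch of how this exception is intended to be used (the `summarize_job` helper is hypothetical; the operators below raise `RayAirflowException` in the same way):

```python
from ray_provider.exceptions import RayAirflowException


def summarize_job(job_id: str, status: str, message: str) -> str:
    # Hypothetical helper: raise the provider-specific exception for bad terminal states.
    if status == "SUCCEEDED":
        return f"Job {job_id} completed successfully"
    raise RayAirflowException(f"Job {job_id} {status.lower()}: {message}")


try:
    summarize_job("raysubmit_123", "FAILED", "entrypoint exited with code 1")
except RayAirflowException as err:
    # Callers can target provider failures precisely instead of a blanket `except Exception`.
    print(f"Ray provider error: {err}")
```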
163 changes: 89 additions & 74 deletions ray_provider/operators.py
@@ -4,12 +4,14 @@
from functools import cached_property
from typing import Any

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
from airflow.utils.context import Context
from kubernetes.client.exceptions import ApiException
from ray.job_submission import JobStatus

from ray_provider.constants import TERMINAL_JOB_STATUSES
from ray_provider.exceptions import RayAirflowException
from ray_provider.hooks import RayHook
from ray_provider.triggers import RayJobTrigger

@@ -41,7 +43,7 @@ def __init__(
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.update_if_exists = update_if_exists

@cached_property
@property
def hook(self) -> RayHook:
"""Lazily initialize and return the RayHook."""
return RayHook(conn_id=self.conn_id)
@@ -52,13 +54,15 @@ def execute(self, context: Context) -> None:
:param context: The context in which the operator is being executed.
"""
self.log.info(f"Trying to setup the ray cluster defined in {self.ray_cluster_yaml}")
self.hook.setup_ray_cluster(
context=context,
ray_cluster_yaml=self.ray_cluster_yaml,
kuberay_version=self.kuberay_version,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
update_if_exists=self.update_if_exists,
)
self.log.info("Finished setting up the ray cluster.")


class DeleteRayCluster(BaseOperator):
@@ -82,7 +86,7 @@ def __init__(
self.ray_cluster_yaml = ray_cluster_yaml
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml

@cached_property
@property
def hook(self) -> PodOperatorHookProtocol:
"""Lazily initialize and return the RayHook."""
return RayHook(conn_id=self.conn_id)
@@ -93,7 +97,9 @@ def execute(self, context: Context) -> None:
:param context: The context in which the operator is being executed.
"""
self.log.info(f"Trying to delete the ray cluster defined in {self.ray_cluster_yaml}")
self.hook.delete_ray_cluster(self.ray_cluster_yaml, self.gpu_device_plugin_yaml)
self.log.info("Finished deleting the ray cluster.")


class SubmitRayJob(BaseOperator):
@@ -173,7 +179,6 @@ def __init__(
self.xcom_task_key = xcom_task_key
self.dashboard_url: str | None = None
self.job_id = ""
self.terminal_states = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}

def on_kill(self) -> None:
"""
@@ -226,28 +231,36 @@ def _setup_cluster(self, context: Context) -> None:
Set up the Ray cluster if a cluster YAML is provided.
:param context: The context in which the task is being executed.
:raises Exception: If there's an error during cluster setup.
"""
if self.ray_cluster_yaml:
self.hook.setup_ray_cluster(
context=context,
ray_cluster_yaml=self.ray_cluster_yaml,
kuberay_version=self.kuberay_version,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
update_if_exists=self.update_if_exists,
)
try:
self.hook.setup_ray_cluster(
context=context,
ray_cluster_yaml=self.ray_cluster_yaml,
kuberay_version=self.kuberay_version,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
update_if_exists=self.update_if_exists,
)
except ApiException as e:
self.log.info(f"Unable to setup the Ray cluster using {self.ray_cluster_yaml}")
self.log.error("Exception details:", exc_info=True)
self.log.info("Trying to delete any parts of the RayCluster that may have been spun up...")
self._delete_cluster()
raise e
else:
self.log.info(f"Skipping setting up a Ray cluster because no `ray_cluster_yaml` was given.")

def _delete_cluster(self) -> None:
"""
Delete the Ray cluster if a cluster YAML is provided.
:raises Exception: If there's an error during cluster deletion.
"""
if self.ray_cluster_yaml:
self.hook.delete_ray_cluster(
ray_cluster_yaml=self.ray_cluster_yaml,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
)
else:
self.log.info(f"Skipping deleting the Ray cluster because no `ray_cluster_yaml` was given.")

def execute(self, context: Context) -> str:
"""
@@ -258,58 +271,51 @@ def execute(self, context: Context) -> str:
:param context: The context in which the task is being executed.
:return: The job ID of the submitted Ray job.
:raises AirflowException: If the job fails, is cancelled, or reaches an unexpected state.
"""

try:
self._setup_cluster(context=context)

self.dashboard_url = self._get_dashboard_url(context)

self.job_id = self.hook.submit_ray_job(
dashboard_url=self.dashboard_url,
entrypoint=self.entrypoint,
runtime_env=self.runtime_env,
entrypoint_num_cpus=self.num_cpus,
entrypoint_num_gpus=self.num_gpus,
entrypoint_memory=self.memory,
entrypoint_resources=self.ray_resources,
)
self.log.info(f"Ray job submitted with id: {self.job_id}")

if self.wait_for_completion:
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_states:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
ray_cluster_yaml=self.ray_cluster_yaml,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=self.job_timeout_seconds,
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(
f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`"
)
return self.job_id
except Exception as e:
self._delete_cluster()
raise AirflowException(f"SubmitRayJob operator failed due to {e}. Cleaning up resources...")
self.log.info("::group:: (SubmitJob 1/5) Setup Cluster")
self._setup_cluster(context=context)
self.log.info("::endgroup::")

self.log.info("::group:: (SubmitJob 2/5) Identify Dashboard URL")
self.dashboard_url = self._get_dashboard_url(context)
self.log.info("::endgroup::")

self.log.info("::group:: (SubmitJob 3/5) Submit job")
self.log.info(f"Ray job with id {self.job_id} submitted")
self.job_id = self.hook.submit_ray_job(
dashboard_url=self.dashboard_url,
entrypoint=self.entrypoint,
runtime_env=self.runtime_env,
entrypoint_num_cpus=self.num_cpus,
entrypoint_num_gpus=self.num_gpus,
entrypoint_memory=self.memory,
entrypoint_resources=self.ray_resources,
)
self.log.info("::endgroup::")

self.log.info("::group:: (SubmitJob 4/5) Wait for completion")
if self.wait_for_completion:
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in TERMINAL_JOB_STATUSES:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
ray_cluster_yaml=self.ray_cluster_yaml,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=self.job_timeout_seconds,
)

return self.job_id

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
@@ -320,15 +326,24 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
:param context: The context in which the task is being executed.
:param event: The event containing the job execution result.
:raises AirflowException: If the job execution fails, is cancelled, or reaches an unexpected state.
:raises RayAirflowException: If the job execution fails, is cancelled, or reaches an unexpected state.
"""
try:
if event["status"] in [JobStatus.STOPPED, JobStatus.FAILED]:
self.log.info(f"Ray job {self.job_id} execution not completed successfully...")
raise AirflowException(f"Job {self.job_id} {event['status'].lower()}: {event['message']}")
elif event["status"] == JobStatus.SUCCEEDED:
self.log.info(f"Ray job {self.job_id} execution succeeded.")
else:
raise AirflowException(f"Unexpected event status for job {self.job_id}: {event['status']}")
finally:
self._delete_cluster()

self.log.info("::endgroup::")
self.log.info("::group:: (SubmitJob 5/5) Execution completed")

self._delete_cluster()

job_status = event["status"]
if job_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
return
else:
self.log.info(f"Ray job {self.job_id} execution not completed successfully...")
if job_status in (JobStatus.FAILED, JobStatus.STOPPED):
msg = f"Job {self.job_id} {job_status.lower()}: {event['message']}"
else:
msg = f"Encountered unexpected state `{job_status}` for job_id `{self.job_id}`"

self.log.info("::endgroup::")

raise RayAirflowException(msg)
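
As a rough idea of how the new behaviour could be exercised (hypothetical test code, not part of this diff; the `SubmitRayJob` constructor arguments shown here are assumptions based on the operator's parameters):

```python
from unittest import mock

import pytest
from ray.job_submission import JobStatus

from ray_provider.exceptions import RayAirflowException
from ray_provider.operators import SubmitRayJob


def test_execute_complete_raises_for_failed_job():
    operator = SubmitRayJob(
        task_id="submit_job",
        conn_id="ray_conn",
        entrypoint="python script.py",
        runtime_env={},
    )
    # Avoid touching a real cluster; execute_complete always calls _delete_cluster().
    operator._delete_cluster = mock.MagicMock()
    event = {"status": JobStatus.FAILED, "message": "entrypoint exited with code 1"}

    # A failed terminal state should now surface as the provider-specific exception.
    with pytest.raises(RayAirflowException, match="exited with code 1"):
        operator.execute_complete(context={}, event=event)
    operator._delete_cluster.assert_called_once()
```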