From fc65f1eba057acf0e139ded54d71fc00da6eaf4f Mon Sep 17 00:00:00 2001 From: jacobcbeaudin Date: Sat, 7 Sep 2024 07:34:23 -0700 Subject: [PATCH] Enhance HexRunProjectOperator, update dependencies and config - Remove deprecated "apply_default" decorator from HexRunProjectOperator - Implement retry mechanism for API polling in HexHook - Decouple polling logic from run_and_poll method - Add max_poll_retries and poll_retry_delay parameters - Update flake8 pre-commit hook to use GitHub repository - Update pre-commit hooks to latest versions --- .pre-commit-config.yaml | 14 ++-- CHANGELOG.md | 12 ++- VERSION.txt | 2 +- airflow_provider_hex/hooks/hex.py | 88 ++++++++++++++++----- airflow_provider_hex/operators/hex.py | 8 +- setup.cfg | 2 +- tests/conftest.py | 14 +++- tests/hooks/test_hex_hook.py | 106 ++++++++++++++++++++++++-- tests/operators/test_hex_operator.py | 11 ++- 9 files changed, 215 insertions(+), 42 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e080290..6f6989a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,23 @@ repos: - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 24.8.0 hooks: - id: black args: ["--target-version=py38", "--line-length=88"] - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.13.2 hooks: - id: isort args: ["--profile=black"] - - repo: https://gitlab.com/pycqa/flake8 - rev: 3.9.2 + - repo: https://github.com/pycqa/flake8 + rev: 7.1.1 hooks: - id: flake8 - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v4.6.0 hooks: - id: check-merge-conflict - id: check-toml @@ -28,14 +28,14 @@ repos: - id: trailing-whitespace - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v1.11.2 hooks: - id: mypy exclude: ^tests/ additional_dependencies: [ types-requests ] - repo: https://github.com/codespell-project/codespell - rev: v2.1.0 + rev: v2.3.0 hooks: - id: codespell name: Run codespell to check for common misspellings in files diff --git a/CHANGELOG.md b/CHANGELOG.md index 53db658..cbf59bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,17 @@ - ### Fixed -- + +## [0.1.10] - 2024-09-09 + +### Added +- Enhanced retry mechanism for polling project status +- New `max_poll_retries` and `poll_retry_delay` parameters for `HexRunProjectOperator` +- New `run_status_with_retries` method in `HexHook` +- New `poll_project_status` method in `HexHook` with improved error handling + +### Changed +- Improved error handling for API calls and status checks ## [0.1.9] - 2023-05-16 diff --git a/VERSION.txt b/VERSION.txt index 1a03094..9767cc9 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -0.1.9 +0.1.10 diff --git a/airflow_provider_hex/hooks/hex.py b/airflow_provider_hex/hooks/hex.py index 6ecaf87..8ca3008 100644 --- a/airflow_provider_hex/hooks/hex.py +++ b/airflow_provider_hex/hooks/hex.py @@ -7,6 +7,8 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from importlib_metadata import PackageNotFoundError, version +from requests.exceptions import RequestException +from tenacity import retry, stop_after_attempt, wait_fixed from airflow_provider_hex.types import NotificationDetails, RunResponse, StatusResponse @@ -151,52 +153,74 @@ def run_project( ), ) - def run_status(self, project_id, run_id) -> StatusResponse: + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) + def run_status(self, project_id: str, run_id: str) -> StatusResponse: endpoint = f"api/v1/project/{project_id}/run/{run_id}" method = "GET" + try: + response = self.run(method=method, endpoint=endpoint, data=None) + return cast(StatusResponse, response) + except RequestException as e: + self.log.error(f"API call failed: {str(e)}") + raise - return cast( - StatusResponse, self.run(method=method, endpoint=endpoint, data=None) - ) - - def cancel_run(self, project_id, run_id) -> str: + def cancel_run(self, project_id: str, run_id: str) -> str: endpoint = f"api/v1/project/{project_id}/run/{run_id}" method = "DELETE" self.run(method=method, endpoint=endpoint) return run_id - def run_and_poll( + def run_status_with_retries( + self, project_id: str, run_id: str, max_retries: int = 3, retry_delay: int = 1 + ) -> StatusResponse: + @retry(stop=stop_after_attempt(max_retries), wait=wait_fixed(retry_delay)) + def _run_status(): + return self.run_status(project_id, run_id) + + return _run_status() + + def poll_project_status( self, project_id: str, - inputs: Optional[dict], - update_cache: bool = False, + run_id: str, poll_interval: int = 3, poll_timeout: int = 600, kill_on_timeout: bool = True, - notifications: List[NotificationDetails] = [], - ): - run_response = self.run_project(project_id, inputs, update_cache, notifications) - run_id = run_response["runId"] - + max_poll_retries: int = 3, + poll_retry_delay: int = 5, + ) -> StatusResponse: poll_start = datetime.datetime.now() while True: - run_status = self.run_status(project_id, run_id) + try: + run_status = self.run_status_with_retries( + project_id, run_id, max_poll_retries, poll_retry_delay + ) + except Exception as e: + self.log.error( + f"Failed to get run status after {max_poll_retries} " + f"attempts: {str(e)}" + ) + if kill_on_timeout: + self.cancel_run(project_id, run_id) + raise AirflowException( + "Failed to get run status for project " + f"{project_id} with run: {run_id}" + ) + project_status = run_status["status"] self.log.info( f"Polling Hex Project {project_id}. Status: {project_status}." ) - if project_status not in VALID_STATUSES: - raise AirflowException(f"Unhandled status: {project_status}") if project_status == COMPLETE: - break + return run_status if project_status in TERMINAL_STATUSES: raise AirflowException( f"Project Run failed with status {project_status}. " - f"See Run URL for more info {run_response['runUrl']}" + f"See Run URL for more info {run_status['runUrl']}" ) if ( @@ -217,4 +241,28 @@ def run_and_poll( ) time.sleep(poll_interval) - return run_status + + def run_and_poll( + self, + project_id: str, + inputs: Optional[dict], + update_cache: bool = False, + poll_interval: int = 3, + poll_timeout: int = 600, + kill_on_timeout: bool = True, + notifications: List[NotificationDetails] = [], + max_poll_retries: int = 3, + poll_retry_delay: int = 5, + ): + run_response = self.run_project(project_id, inputs, update_cache, notifications) + run_id = run_response["runId"] + + return self.poll_project_status( + project_id, + run_id, + poll_interval, + poll_timeout, + kill_on_timeout, + max_poll_retries, + poll_retry_delay, + ) diff --git a/airflow_provider_hex/operators/hex.py b/airflow_provider_hex/operators/hex.py index c0e439e..e14cca5 100644 --- a/airflow_provider_hex/operators/hex.py +++ b/airflow_provider_hex/operators/hex.py @@ -2,7 +2,6 @@ from airflow.models import BaseOperator from airflow.models.dag import Context -from airflow.utils.decorators import apply_defaults from airflow_provider_hex.hooks.hex import HexHook from airflow_provider_hex.types import NotificationDetails @@ -41,7 +40,6 @@ class HexRunProjectOperator(BaseOperator): template_fields = ["project_id", "input_parameters"] ui_color = "#F5C0C0" - @apply_defaults def __init__( self, project_id: str, @@ -53,6 +51,8 @@ def __init__( input_parameters: Optional[Dict[str, Any]] = None, update_cache: bool = False, notifications: List[NotificationDetails] = [], + max_poll_retries: int = 3, + poll_retry_delay: int = 5, # Change this to 5 **kwargs, ) -> None: super().__init__(**kwargs) @@ -65,6 +65,8 @@ def __init__( self.input_parameters = input_parameters self.update_cache = update_cache self.notifications = notifications + self.max_poll_retries = max_poll_retries + self.poll_retry_delay = poll_retry_delay def execute(self, context: Context) -> Any: hook = HexHook(self.hex_conn_id) @@ -79,6 +81,8 @@ def execute(self, context: Context) -> Any: poll_timeout=self.timeout, kill_on_timeout=self.kill_on_timeout, notifications=self.notifications, + max_poll_retries=self.max_poll_retries, + poll_retry_delay=self.poll_retry_delay, ) self.log.info("Hex Project completed successfully") diff --git a/setup.cfg b/setup.cfg index ea06af0..9ca07b1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ install_requires = apache-airflow>=2.2.0 requests>=2 importlib-metadata>=4.8.1 - typing-extensions>=3.10.0.2 + typing-extensions>=4 zip_safe = false [options.extras_require] diff --git a/tests/conftest.py b/tests/conftest.py index fcc8b13..507372d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,10 @@ import datetime +import os import pendulum import pytest from airflow import DAG +from airflow.utils.db import initdb from airflow_provider_hex.operators.hex import HexRunProjectOperator @@ -20,11 +22,17 @@ def sample_conn(mocker): ) +@pytest.fixture(scope="session", autouse=True) +def init_airflow_db(): + os.environ["AIRFLOW__CORE__UNIT_TEST_MODE"] = "True" + initdb() + + @pytest.fixture() def dag(): with DAG( dag_id=TEST_DAG_ID, - schedule_interval="@daily", + schedule="@daily", start_date=DATA_INTERVAL_START, ) as dag: HexRunProjectOperator( @@ -32,6 +40,8 @@ def dag(): hex_conn_id="hex_conn", project_id="ABC-123", input_parameters={"input_date": "{{ ds }}"}, + max_poll_retries=3, + poll_retry_delay=1, ) return dag @@ -48,5 +58,7 @@ def fake_dag(): hex_conn_id="hex_conn", project_id="ABC-123", input_parameters={"input_date": "{{ ds }}"}, + max_poll_retries=3, + poll_retry_delay=1, ) return dag diff --git a/tests/hooks/test_hex_hook.py b/tests/hooks/test_hex_hook.py index d4af454..fa0c0fe 100644 --- a/tests/hooks/test_hex_hook.py +++ b/tests/hooks/test_hex_hook.py @@ -1,13 +1,12 @@ import logging import pytest -from airflow import AirflowException +from airflow.exceptions import AirflowException from airflow_provider_hex.hooks.hex import HexHook log = logging.getLogger(__name__) - mock_run = { "projectId": "abc-123", "runId": "1", @@ -118,12 +117,15 @@ def test_run_poll_pending_and_error(self, requests_mock): requests_mock.post( "https://www.httpbin.org/api/v1/project/abc-123/run", headers={"Content-Type": "application/json"}, - json=mock_run, + json={"projectId": "abc-123", "runId": "1"}, ) mock_status = {"projectId": "abc-123", "status": "PENDING"} - - mock_status_2 = {"projectId": "abc-123", "status": "UNABLE_TO_ALLOCATE_KERNEL"} + mock_status_2 = { + "projectId": "abc-123", + "status": "UNABLE_TO_ALLOCATE_KERNEL", + "runUrl": "https://example.com/run/1", + } header = {"Content-Type": "application/json"} requests_mock.register_uri( @@ -138,4 +140,96 @@ def test_run_poll_pending_and_error(self, requests_mock): hook = HexHook(hex_conn_id="hex_conn") with pytest.raises(AirflowException, match=r"Project Run failed with status.*"): - hook.run_and_poll("abc-123", inputs=None, poll_interval=1) + hook.run_and_poll( + "abc-123", + inputs=None, + poll_interval=1, + max_poll_retries=3, + poll_retry_delay=1, + ) + + # Check if the status endpoint was called multiple times + assert requests_mock.call_count == 3 # 1 POST + 2 GET requests + + def test_run_status_with_retries(self, requests_mock): + mock_status_error = {"projectId": "abc-123", "error": "Internal Server Error"} + mock_status_success = {"projectId": "abc-123", "status": "RUNNING"} + + header = {"Content-Type": "application/json"} + requests_mock.register_uri( + "GET", + "https://www.httpbin.org/api/v1/project/abc-123/run/1", + [ + {"headers": header, "json": mock_status_error, "status_code": 500}, + {"headers": header, "json": mock_status_error, "status_code": 500}, + {"headers": header, "json": mock_status_success, "status_code": 200}, + ], + ) + + hook = HexHook(hex_conn_id="hex_conn") + response = hook.run_status_with_retries( + "abc-123", "1", max_retries=3, retry_delay=1 + ) + + assert response["status"] == "RUNNING" + assert requests_mock.call_count == 3 + + def test_poll_project_status(self, requests_mock): + mock_status_pending = {"projectId": "abc-123", "status": "PENDING"} + mock_status_running = {"projectId": "abc-123", "status": "RUNNING"} + mock_status_completed = { + "projectId": "abc-123", + "status": "COMPLETED", + "runUrl": "https://example.com/run/1", + } + + header = {"Content-Type": "application/json"} + requests_mock.register_uri( + "GET", + "https://www.httpbin.org/api/v1/project/abc-123/run/1", + [ + {"headers": header, "json": mock_status_pending}, + {"headers": header, "json": mock_status_running}, + {"headers": header, "json": mock_status_completed}, + ], + ) + + hook = HexHook(hex_conn_id="hex_conn") + response = hook.poll_project_status( + "abc-123", + "1", + poll_interval=1, + poll_timeout=10, + max_poll_retries=3, + poll_retry_delay=1, + ) + + assert response["status"] == "COMPLETED" + assert requests_mock.call_count == 3 + + def test_poll_project_status_error(self, requests_mock): + requests_mock.get( + "https://www.httpbin.org/api/v1/project/abc-123/run/1", + [{"status_code": 500}] * 9, # 3 retries * 3 attempts + ) + + requests_mock.delete( + "https://www.httpbin.org/api/v1/project/abc-123/run/1", status_code=200 + ) + + hook = HexHook(hex_conn_id="hex_conn") + + with pytest.raises( + AirflowException, match="Failed to get run status for project" + ): + hook.poll_project_status( + "abc-123", + "1", + poll_interval=1, + poll_timeout=10, + kill_on_timeout=True, + max_poll_retries=3, + poll_retry_delay=1, + ) + + assert requests_mock.call_count == 10 # 9 GET requests + 1 DELETE request diff --git a/tests/operators/test_hex_operator.py b/tests/operators/test_hex_operator.py index 6b4b74d..faf4dfa 100644 --- a/tests/operators/test_hex_operator.py +++ b/tests/operators/test_hex_operator.py @@ -17,12 +17,17 @@ def test_my_custom_operator_execute_no_trigger(dag, requests_mock): json={"projectId": "ABC-123", "runId": "1"}, ) - mock_status = {"projectId": "abc-123", "status": "COMPLETED"} + mock_status = { + "projectId": "ABC-123", + "status": "COMPLETED", + "runUrl": "https://example.com/run/1", + } requests_mock.get( - "https://www.httpbin.org/api/v1/project/abc-123/run/1", + "https://www.httpbin.org/api/v1/project/ABC-123/run/1", headers={"Content-Type": "application/json"}, json=mock_status, ) + dagrun = dag.create_dagrun( state=DagRunState.RUNNING, execution_date=DATA_INTERVAL_START, @@ -33,8 +38,8 @@ def test_my_custom_operator_execute_no_trigger(dag, requests_mock): ti = dagrun.get_task_instance(task_id=TEST_TASK_ID) ti.task = dag.get_task(task_id=TEST_TASK_ID) ti.run(ignore_ti_state=True) + assert ti.state == TaskInstanceState.SUCCESS json = requests_mock.request_history[0].json() assert json["inputParams"]["input_date"][0:4] == str(DATA_INTERVAL_START.year) - print(json) assert json["updateCache"] is False