Skip to content

Commit

Permalink
Enhance HexRunProjectOperator, update dependencies and config
Browse files Browse the repository at this point in the history
- Remove deprecated "apply_defaults" decorator from HexRunProjectOperator
- Implement retry mechanism for API polling in HexHook
- Decouple polling logic from run_and_poll method
- Add max_poll_retries and poll_retry_delay parameters
- Update flake8 pre-commit hook to use GitHub repository
- Update pre-commit hooks to latest versions
  • Loading branch information
jacobcbeaudin committed Sep 8, 2024
1 parent d1b2c8c commit fc65f1e
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 42 deletions.
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
repos:
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 24.8.0
hooks:
- id: black
args: ["--target-version=py38", "--line-length=88"]

- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.13.2
hooks:
- id: isort
args: ["--profile=black"]

- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
- repo: https://github.com/pycqa/flake8
rev: 7.1.1
hooks:
- id: flake8

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.6.0
hooks:
- id: check-merge-conflict
- id: check-toml
Expand All @@ -28,14 +28,14 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.931
rev: v1.11.2
hooks:
- id: mypy
exclude: ^tests/
additional_dependencies: [ types-requests ]

- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
rev: v2.3.0
hooks:
- id: codespell
name: Run codespell to check for common misspellings in files
Expand Down
12 changes: 11 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@
-

### Fixed
-

## [0.1.10] - 2024-09-09

### Added
- Enhanced retry mechanism for polling project status
- New `max_poll_retries` and `poll_retry_delay` parameters for `HexRunProjectOperator`
- New `run_status_with_retries` method in `HexHook`
- New `poll_project_status` method in `HexHook` with improved error handling

### Changed
- Improved error handling for API calls and status checks


## [0.1.9] - 2023-05-16
Expand Down
2 changes: 1 addition & 1 deletion VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.9
0.1.10
88 changes: 68 additions & 20 deletions airflow_provider_hex/hooks/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from importlib_metadata import PackageNotFoundError, version
from requests.exceptions import RequestException
from tenacity import retry, stop_after_attempt, wait_fixed

from airflow_provider_hex.types import NotificationDetails, RunResponse, StatusResponse

Expand Down Expand Up @@ -151,52 +153,74 @@ def run_project(
),
)

def run_status(self, project_id, run_id) -> StatusResponse:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
def run_status(self, project_id: str, run_id: str) -> StatusResponse:
endpoint = f"api/v1/project/{project_id}/run/{run_id}"
method = "GET"
try:
response = self.run(method=method, endpoint=endpoint, data=None)
return cast(StatusResponse, response)
except RequestException as e:
self.log.error(f"API call failed: {str(e)}")
raise

return cast(
StatusResponse, self.run(method=method, endpoint=endpoint, data=None)
)

def cancel_run(self, project_id, run_id) -> str:
def cancel_run(self, project_id: str, run_id: str) -> str:
    """Cancel a project run by sending DELETE to its run endpoint.

    :param project_id: identifier of the Hex project that owns the run
    :param run_id: identifier of the run to cancel
    :return: the ``run_id`` that was passed in, unchanged
    """
    self.run(
        method="DELETE",
        endpoint=f"api/v1/project/{project_id}/run/{run_id}",
    )
    return run_id

def run_and_poll(
def run_status_with_retries(
    self, project_id: str, run_id: str, max_retries: int = 3, retry_delay: int = 1
) -> StatusResponse:
    """Fetch a run's status, retrying failed calls with a caller-chosen policy.

    ``run_status`` already carries its own fixed ``@retry`` policy (3 attempts,
    1 second apart). Wrapping that decorated method in a second ``@retry``
    multiplies the two budgets, so ``max_retries=3`` could issue up to nine
    requests. ``retry_with`` instead re-wraps the undecorated function with the
    requested policy, so at most ``max_retries`` attempts are made.

    :param project_id: identifier of the Hex project that owns the run
    :param run_id: identifier of the run whose status is requested
    :param max_retries: maximum number of attempts before giving up
    :param retry_delay: fixed wait, in seconds, between attempts
    :return: the status payload returned by ``run_status``
    :raises tenacity.RetryError: when every attempt raised
    """
    # ``retry_with`` wraps the plain function rather than the bound method,
    # so the hook instance has to be passed explicitly as the first argument.
    fetch_status = self.run_status.retry_with(
        stop=stop_after_attempt(max_retries), wait=wait_fixed(retry_delay)
    )
    return fetch_status(self, project_id, run_id)

def poll_project_status(
self,
project_id: str,
inputs: Optional[dict],
update_cache: bool = False,
run_id: str,
poll_interval: int = 3,
poll_timeout: int = 600,
kill_on_timeout: bool = True,
notifications: List[NotificationDetails] = [],
):
run_response = self.run_project(project_id, inputs, update_cache, notifications)
run_id = run_response["runId"]

max_poll_retries: int = 3,
poll_retry_delay: int = 5,
) -> StatusResponse:
poll_start = datetime.datetime.now()
while True:
run_status = self.run_status(project_id, run_id)
try:
run_status = self.run_status_with_retries(
project_id, run_id, max_poll_retries, poll_retry_delay
)
except Exception as e:
self.log.error(
f"Failed to get run status after {max_poll_retries} "
f"attempts: {str(e)}"
)
if kill_on_timeout:
self.cancel_run(project_id, run_id)
raise AirflowException(
"Failed to get run status for project "
f"{project_id} with run: {run_id}"
)

project_status = run_status["status"]

self.log.info(
f"Polling Hex Project {project_id}. Status: {project_status}."
)
if project_status not in VALID_STATUSES:
raise AirflowException(f"Unhandled status: {project_status}")

if project_status == COMPLETE:
break
return run_status

if project_status in TERMINAL_STATUSES:
raise AirflowException(
f"Project Run failed with status {project_status}. "
f"See Run URL for more info {run_response['runUrl']}"
f"See Run URL for more info {run_status['runUrl']}"
)

if (
Expand All @@ -217,4 +241,28 @@ def run_and_poll(
)

time.sleep(poll_interval)
return run_status

def run_and_poll(
    self,
    project_id: str,
    inputs: Optional[dict],
    update_cache: bool = False,
    poll_interval: int = 3,
    poll_timeout: int = 600,
    kill_on_timeout: bool = True,
    notifications: Optional[List[NotificationDetails]] = None,
    max_poll_retries: int = 3,
    poll_retry_delay: int = 5,
) -> StatusResponse:
    """Start a project run and block until polling reports a final outcome.

    Triggers the run through ``run_project`` and hands the resulting run id
    to ``poll_project_status``, whose return value is passed back unchanged.

    :param project_id: identifier of the Hex project to run
    :param inputs: input parameters forwarded to ``run_project``; may be None
    :param update_cache: forwarded to ``run_project``
    :param poll_interval: seconds to sleep between status polls
    :param poll_timeout: polling time budget forwarded to
        ``poll_project_status`` (presumably seconds -- confirm there)
    :param kill_on_timeout: forwarded to ``poll_project_status``, which uses
        it to decide whether to cancel the run when it gives up
    :param notifications: notification settings forwarded to ``run_project``;
        ``None`` (the default) is treated as an empty list
    :param max_poll_retries: attempts allowed for each status request
    :param poll_retry_delay: seconds to wait between status-request retries
    :return: the status payload returned by ``poll_project_status``
    :raises AirflowException: propagated from ``poll_project_status``
    """
    # A literal ``[]`` default would be one list object shared by every call;
    # build a fresh list per call instead.
    if notifications is None:
        notifications = []

    run_response = self.run_project(project_id, inputs, update_cache, notifications)

    # Keyword arguments keep the forwarding correct even if the parameter
    # order of ``poll_project_status`` changes.
    return self.poll_project_status(
        project_id=project_id,
        run_id=run_response["runId"],
        poll_interval=poll_interval,
        poll_timeout=poll_timeout,
        kill_on_timeout=kill_on_timeout,
        max_poll_retries=max_poll_retries,
        poll_retry_delay=poll_retry_delay,
    )
8 changes: 6 additions & 2 deletions airflow_provider_hex/operators/hex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from airflow.models import BaseOperator
from airflow.models.dag import Context
from airflow.utils.decorators import apply_defaults

from airflow_provider_hex.hooks.hex import HexHook
from airflow_provider_hex.types import NotificationDetails
Expand Down Expand Up @@ -41,7 +40,6 @@ class HexRunProjectOperator(BaseOperator):
template_fields = ["project_id", "input_parameters"]
ui_color = "#F5C0C0"

@apply_defaults
def __init__(
self,
project_id: str,
Expand All @@ -53,6 +51,8 @@ def __init__(
input_parameters: Optional[Dict[str, Any]] = None,
update_cache: bool = False,
notifications: List[NotificationDetails] = [],
max_poll_retries: int = 3,
poll_retry_delay: int = 5, # seconds to wait between status-poll retries
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -65,6 +65,8 @@ def __init__(
self.input_parameters = input_parameters
self.update_cache = update_cache
self.notifications = notifications
self.max_poll_retries = max_poll_retries
self.poll_retry_delay = poll_retry_delay

def execute(self, context: Context) -> Any:
hook = HexHook(self.hex_conn_id)
Expand All @@ -79,6 +81,8 @@ def execute(self, context: Context) -> Any:
poll_timeout=self.timeout,
kill_on_timeout=self.kill_on_timeout,
notifications=self.notifications,
max_poll_retries=self.max_poll_retries,
poll_retry_delay=self.poll_retry_delay,
)
self.log.info("Hex Project completed successfully")

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ install_requires =
apache-airflow>=2.2.0
requests>=2
importlib-metadata>=4.8.1
typing-extensions>=3.10.0.2
typing-extensions>=4
zip_safe = false

[options.extras_require]
Expand Down
14 changes: 13 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
import os

import pendulum
import pytest
from airflow import DAG
from airflow.utils.db import initdb

from airflow_provider_hex.operators.hex import HexRunProjectOperator

Expand All @@ -20,18 +22,26 @@ def sample_conn(mocker):
)


@pytest.fixture(scope="session", autouse=True)
def init_airflow_db():
    """Session-wide, auto-applied fixture: enable unit-test mode, then call ``initdb()``.

    The environment variable is set before ``initdb()`` runs.
    """
    os.environ.update({"AIRFLOW__CORE__UNIT_TEST_MODE": "True"})
    initdb()


@pytest.fixture()
def dag():
with DAG(
dag_id=TEST_DAG_ID,
schedule_interval="@daily",
schedule="@daily",
start_date=DATA_INTERVAL_START,
) as dag:
HexRunProjectOperator(
task_id=TEST_TASK_ID,
hex_conn_id="hex_conn",
project_id="ABC-123",
input_parameters={"input_date": "{{ ds }}"},
max_poll_retries=3,
poll_retry_delay=1,
)
return dag

Expand All @@ -48,5 +58,7 @@ def fake_dag():
hex_conn_id="hex_conn",
project_id="ABC-123",
input_parameters={"input_date": "{{ ds }}"},
max_poll_retries=3,
poll_retry_delay=1,
)
return dag
Loading

0 comments on commit fc65f1e

Please sign in to comment.