From 71fec4e661ae4858dfe0f3797124c68b87aa13b2 Mon Sep 17 00:00:00 2001 From: "Jiao, Hsu" <60685621+jieyao-MilestoneHub@users.noreply.github.com> Date: Tue, 3 Dec 2024 18:44:17 +0800 Subject: [PATCH] [Providers/HTTP] Add adapter parameter to HttpHook to allow custom requests adapters (#44302) * feat(http-hook): add adapter parameter to HttpHook and enhance get_conn - Added `adapter` parameter to `HttpHook` to allow custom HTTP adapters. - Modified `get_conn` to support mounting custom adapters or using TCPKeepAliveAdapter by default. - Added comprehensive tests to validate the functionality of the `adapter` parameter and its integration with `get_conn`. - Ensured all new tests pass and maintain compatibility with existing functionality. * fix(http_hook): Update docstring and remove redundant TCPKeepAliveAdapter - Added missing `adapter` parameter description to the HttpHook class docstring. - Removed redundant instantiation of `TCPKeepAliveAdapter` in the `run` method since it's already instantiated in `get_conn`. * fix(http_hook): improve get_conn session setup and TCP adapter logic - Ensured proper mounting of TCP Keep-Alive adapter when enabled. - Improved handling of connection extras for cleaner session configuration. * feat(http): update get_conn logic and corresponding tests (#44302) Aligned the `get_conn` method with the adjustments specified in #44302, including refined handling of headers. Optimized and updated test cases to ensure compatibility and maintain robust test coverage. * refactor(http_hook): simplify HttpHook by reverting BaseAdapter to HTTPAdapter - Changed the `adapter` parameter to accept only `HTTPAdapter` instead of `BaseAdapter`. - Strengthened `_set_base_url` validation to ensure base_url is constructed with stricter conditions. - Adjusted `_mount_adapters` to improve maintainability. * refactor(http_hook): simplify HttpHook by reverting BaseAdapter to HTTPAdapter - Changed the `adapter` parameter to accept only `HTTPAdapter` instead of `BaseAdapter`. - Strengthened `_set_base_url` validation to ensure base_url is constructed with stricter conditions. - Adjusted `_mount_adapters` to improve maintainability. * Merge: new main * refactor: improve function naming and add type annotations - Changed the function prefix from `_set` to `_configure_session_from` to enhance readability and better reflect its purpose. - Added static type annotations for input parameters and return values. - Included comments to document the design rationale following coding standards. - Improved error message: replaced generic text with detailed and actionable messages. * fix: simplify the change of session - Added a variable `session` after the change of session member * fix: Adjust response format. * fix: simplify the logic * fix(hook): ensure default HTTPAdapter in HttpHook init The `adapter` parameter in `HttpHook` was previously required to be explicitly set to an instance of `HTTPAdapter`. This commit modifies the `__init__` method to assign a default `HTTPAdapter` when no adapter is provided. Changes: - Removed type checks for `adapter`, as default initialization guarantees correctness. - Improved code readability and reduced potential runtime errors. No functional changes beyond defaulting `adapter` to `HTTPAdapter`. * feat(http_hook): add support for custom adapter in initialization Refactored `HttpHook` to support a custom `HTTPAdapter` through the `adapter` parameter. If no adapter is provided, it defaults to `TCPKeepAliveAdapter` when `tcp_keep_alive=True`. Test: Added `test_custom_adapter` to verify correct adapter mounting. * fix: CI image checks / Static checks - Adjust the length of each line of code. * fix: Adjust indent style - modify `assert instance` by PEP8 * fix: ruff error about `from requests.adapters import HTTPAdapter` --------- Co-authored-by: jiao --- .../src/airflow/providers/http/hooks/http.py | 126 +++++++++++------- providers/tests/http/hooks/test_http.py | 16 ++- 2 files changed, 95 insertions(+), 47 deletions(-) diff --git a/providers/src/airflow/providers/http/hooks/http.py b/providers/src/airflow/providers/http/hooks/http.py index 05b432626b8e3..a179739275e1a 100644 --- a/providers/src/airflow/providers/http/hooks/http.py +++ b/providers/src/airflow/providers/http/hooks/http.py @@ -19,6 +19,7 @@ import asyncio from typing import TYPE_CHECKING, Any, Callable +from urllib.parse import urlparse import aiohttp import requests @@ -34,6 +35,7 @@ if TYPE_CHECKING: from aiohttp.client_reqrep import ClientResponse + from requests.adapters import HTTPAdapter from airflow.models import Connection @@ -54,6 +56,7 @@ class HttpHook(BaseHook): API url i.e https://www.google.com/ and optional authentication credentials. Default headers can also be specified in the Extra field in json format. :param auth_type: The auth type for the service + :param adapter: An optional instance of `requests.adapters.HTTPAdapter` to mount for the session. :param tcp_keep_alive: Enable TCP Keep Alive for the connection. :param tcp_keep_alive_idle: The TCP Keep Alive Idle parameter (corresponds to ``socket.TCP_KEEPIDLE``). :param tcp_keep_alive_count: The TCP Keep Alive count parameter (corresponds to ``socket.TCP_KEEPCNT``) @@ -76,6 +79,7 @@ def __init__( tcp_keep_alive_idle: int = 120, tcp_keep_alive_count: int = 20, tcp_keep_alive_interval: int = 30, + adapter: HTTPAdapter | None = None, ) -> None: super().__init__() self.http_conn_id = http_conn_id @@ -83,10 +87,17 @@ def __init__( self.base_url: str = "" self._retry_obj: Callable[..., Any] self._auth_type: Any = auth_type - self.tcp_keep_alive = tcp_keep_alive - self.keep_alive_idle = tcp_keep_alive_idle - self.keep_alive_count = tcp_keep_alive_count - self.keep_alive_interval = tcp_keep_alive_interval + + # If no adapter is provided, use TCPKeepAliveAdapter (default behavior) + self.adapter = adapter + if tcp_keep_alive and adapter is None: + self.keep_alive_adapter = TCPKeepAliveAdapter( + idle=tcp_keep_alive_idle, + count=tcp_keep_alive_count, + interval=tcp_keep_alive_interval, + ) + else: + self.keep_alive_adapter = None @property def auth_type(self): @@ -102,47 +113,76 @@ def get_conn(self, headers: dict[Any, Any] | None = None) -> requests.Session: """ Create a Requests HTTP session. - :param headers: additional headers to be passed through as a dictionary + :param headers: Additional headers to be passed through as a dictionary. + :return: A configured requests.Session object. """ session = requests.Session() - - if self.http_conn_id: - conn = self.get_connection(self.http_conn_id) - - if conn.host and "://" in conn.host: - self.base_url = conn.host - else: - # schema defaults to HTTP - schema = conn.schema if conn.schema else "http" - host = conn.host if conn.host else "" - self.base_url = f"{schema}://{host}" - - if conn.port: - self.base_url += f":{conn.port}" - if conn.login: - session.auth = self.auth_type(conn.login, conn.password) - elif self._auth_type: - session.auth = self.auth_type() - if conn.extra: - extra = conn.extra_dejson - extra.pop( - "timeout", None - ) # ignore this as timeout is only accepted in request method of Session - extra.pop("allow_redirects", None) # ignore this as only max_redirects is accepted in Session - session.proxies = extra.pop("proxies", extra.pop("proxy", {})) - session.stream = extra.pop("stream", False) - session.verify = extra.pop("verify", extra.pop("verify_ssl", True)) - session.cert = extra.pop("cert", None) - session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT) - session.trust_env = extra.pop("trust_env", True) - - try: - session.headers.update(extra) - except TypeError: - self.log.warning("Connection to %s has invalid extra field.", conn.host) + connection = self.get_connection(self.http_conn_id) + self._set_base_url(connection) + session = self._configure_session_from_auth(session, connection) + if connection.extra: + session = self._configure_session_from_extra(session, connection) + session = self._configure_session_from_mount_adapters(session) if headers: session.headers.update(headers) + return session + + def _set_base_url(self, connection: Connection) -> None: + host = connection.host or "" + schema = connection.schema or "http" + # RFC 3986 (https://www.rfc-editor.org/rfc/rfc3986.html#page-16) + if "://" in host: + self.base_url = host + else: + self.base_url = f"{schema}://{host}" if host else f"{schema}://" + if connection.port: + self.base_url = f"{self.base_url}:{connection.port}" + parsed = urlparse(self.base_url) + if not parsed.scheme: + raise ValueError(f"Invalid base URL: Missing scheme in {self.base_url}") + + def _configure_session_from_auth( + self, session: requests.Session, connection: Connection + ) -> requests.Session: + session.auth = self._extract_auth(connection) + return session + + def _extract_auth(self, connection: Connection) -> Any | None: + if connection.login: + return self.auth_type(connection.login, connection.password) + elif self._auth_type: + return self.auth_type() + return None + + def _configure_session_from_extra( + self, session: requests.Session, connection: Connection + ) -> requests.Session: + extra = connection.extra_dejson + extra.pop("timeout", None) + extra.pop("allow_redirects", None) + session.proxies = extra.pop("proxies", extra.pop("proxy", {})) + session.stream = extra.pop("stream", False) + session.verify = extra.pop("verify", extra.pop("verify_ssl", True)) + session.cert = extra.pop("cert", None) + session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT) + session.trust_env = extra.pop("trust_env", True) + try: + session.headers.update(extra) + except TypeError: + self.log.warning("Connection to %s has invalid extra field.", connection.host) + return session + def _configure_session_from_mount_adapters(self, session: requests.Session) -> requests.Session: + scheme = urlparse(self.base_url).scheme + if not scheme: + raise ValueError( + f"Cannot mount adapters: {self.base_url} does not include a valid scheme (http or https)." + ) + if self.adapter: + session.mount(f"{scheme}://", self.adapter) + elif self.keep_alive_adapter: + session.mount("http://", self.keep_alive_adapter) + session.mount("https://", self.keep_alive_adapter) return session def run( @@ -171,11 +211,6 @@ def run( url = self.url_from_endpoint(endpoint) - if self.tcp_keep_alive: - keep_alive_adapter = TCPKeepAliveAdapter( - idle=self.keep_alive_idle, count=self.keep_alive_count, interval=self.keep_alive_interval - ) - session.mount(url, keep_alive_adapter) if self.method == "GET": # GET uses params req = requests.Request(self.method, url, params=data, headers=headers, **request_kwargs) @@ -467,5 +502,4 @@ def _retryable_error_async(self, exception: ClientResponseError) -> bool: if exception.status == 413: # don't retry for payload Too Large return False - return exception.status >= 500 diff --git a/providers/tests/http/hooks/test_http.py b/providers/tests/http/hooks/test_http.py index e09fd2d034e60..bd381a7155bbd 100644 --- a/providers/tests/http/hooks/test_http.py +++ b/providers/tests/http/hooks/test_http.py @@ -29,7 +29,7 @@ import requests import tenacity from aioresponses import aioresponses -from requests.adapters import Response +from requests.adapters import HTTPAdapter, Response from requests.auth import AuthBase, HTTPBasicAuth from requests.models import DEFAULT_REDIRECT_LIMIT @@ -536,6 +536,20 @@ def test_url_from_endpoint(self, base_url: str, endpoint: str, expected_url: str hook.base_url = base_url assert hook.url_from_endpoint(endpoint) == expected_url + def test_custom_adapter(self): + with mock.patch( + "airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection_with_port + ): + custom_adapter = HTTPAdapter() + hook = HttpHook(method="GET", adapter=custom_adapter) + session = hook.get_conn() + assert isinstance( + session.adapters["http://"], type(custom_adapter) + ), "Custom HTTP adapter not correctly mounted" + assert isinstance( + session.adapters["https://"], type(custom_adapter) + ), "Custom HTTPS adapter not correctly mounted" + class TestHttpAsyncHook: @pytest.mark.asyncio