From 5167f7142ad8b4383fa0772384105c7cc8e57e2c Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 15:06:51 +0200 Subject: [PATCH] Refactor step 3 --- haystack_experimental/util/openapi.py | 100 +++++++----------- .../util/payload_extraction.py | 48 ++++----- test/util/conftest.py | 6 +- test/util/test_openapi_client.py | 6 +- test/util/test_openapi_client_auth.py | 10 +- ...est_openapi_client_complex_request_body.py | 2 +- ...enapi_client_complex_request_body_mixed.py | 2 +- test/util/test_openapi_client_edge_cases.py | 2 +- .../test_openapi_client_error_handling.py | 2 +- 9 files changed, 75 insertions(+), 103 deletions(-) diff --git a/haystack_experimental/util/openapi.py b/haystack_experimental/util/openapi.py index 9d7dc420..fef95de7 100644 --- a/haystack_experimental/util/openapi.py +++ b/haystack_experimental/util/openapi.py @@ -8,13 +8,11 @@ from base64 import b64encode from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Union from urllib.parse import urlparse import requests import yaml -from requests.adapters import HTTPAdapter -from urllib3 import Retry from haystack_experimental.util.payload_extraction import create_function_payload_extractor from haystack_experimental.util.schema_conversion import anthropic_converter, cohere_converter, openai_converter @@ -105,63 +103,6 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): raise ValueError("HTTPAuthentication strategy received a non-HTTP security scheme.") -@dataclass -class HttpClientConfig: - """Configuration for the HTTP client.""" - - timeout: int = 10 - max_retries: int = 3 - backoff_factor: float = 0.3 - retry_on_status: set = field(default_factory=lambda: {500, 502, 503, 504}) - default_headers: Dict[str, str] = field(default_factory=dict) - - -class HttpClient: - """HTTP client for sending requests.""" - - def __init__(self, config: Optional[HttpClientConfig] = None): - self.config = config or HttpClientConfig() - self.session = requests.Session() - retries = Retry( - total=self.config.max_retries, - backoff_factor=self.config.backoff_factor, - status_forcelist=self.config.retry_on_status, - ) - adapter = HTTPAdapter(max_retries=retries) - self.session.mount("http://", adapter) - self.session.mount("https://", adapter) - self.session.headers.update(self.config.default_headers) - - def send_request(self, request: Dict[str, Any]) -> Any: - """ - Send an HTTP request using the provided request dictionary. - - :param request: A dictionary containing the request details. - """ - url = request["url"] - headers = {**self.config.default_headers, **request.get("headers", {})} - try: - response = self.session.request( - request["method"], - url, - headers=headers, - params=request.get("params", {}), - json=request.get("json"), - auth=request.get("auth"), - ) - response.raise_for_status() - return response.json() - except requests.exceptions.HTTPError as e: - logger.warning("HTTP error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except requests.exceptions.RequestException as e: - logger.warning("Request error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except Exception as e: - logger.warning("An error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"An error occurred: {e}") from e - - class HttpClientError(Exception): """Exception raised for errors in the HTTP client.""" @@ -305,7 +246,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments self, openapi_spec: Union[str, Path, Dict[str, Any]], credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None, - http_client: Optional[HttpClient] = None, + request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, llm_provider: Optional[str] = None, ): # noqa: PLR0913 if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): @@ -321,7 +262,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments raise ValueError("Invalid OpenAPI specification format. Expected file path or dictionary.") self.credentials = credentials - self.http_client = http_client or HttpClient(HttpClientConfig()) + self.request_sender = request_sender self.llm_provider = llm_provider or "openai" def get_auth_config(self) -> AuthenticationStrategy: @@ -395,7 +336,7 @@ class OpenAPIServiceClient: def __init__(self, client_config: ClientConfiguration): self.client_config = client_config - self.http_client = client_config.http_client + self.request_sender = client_config.request_sender or self._request_sender() def invoke(self, function_payload: Any) -> Any: """ @@ -416,7 +357,38 @@ def invoke(self, function_payload: Any) -> Any: operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) request = self._build_request(operation, **fn_invocation_payload.get("arguments")) self._apply_authentication(self.client_config.get_auth_config(), operation, request) - return self.http_client.send_request(request) + return self.request_sender(request) + + def _request_sender(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: + """ + Returns a callable that sends the request using the HTTP client. + """ + def send_request(request: Dict[str, Any]) -> Dict[str, Any]: + url = request["url"] + headers = {**request.get("headers", {})} + try: + response = requests.request( + request["method"], + url, + headers=headers, + params=request.get("params", {}), + json=request.get("json"), + auth=request.get("auth"), + timeout=10, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + logger.warning("HTTP error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except requests.exceptions.RequestException as e: + logger.warning("Request error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except Exception as e: + logger.warning("An error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"An error occurred: {e}") from e + + return send_request def _build_request(self, operation: Operation, **kwargs) -> Any: request = { diff --git a/haystack_experimental/util/payload_extraction.py b/haystack_experimental/util/payload_extraction.py index 202d11c1..b98969c1 100644 --- a/haystack_experimental/util/payload_extraction.py +++ b/haystack_experimental/util/payload_extraction.py @@ -3,6 +3,30 @@ from typing import Any, Callable, Dict, List, Optional, Union +def create_function_payload_extractor(arguments_field_name: str) -> Callable[[Any], Dict[str, Any]]: + """ + Extracts invocation payload from a given LLM completion containing function invocation. + """ + def _extract_function_invocation(payload: Any) -> Dict[str, Any]: + """ + Extract the function invocation details from the payload. + """ + fields_and_values = _search(payload, arguments_field_name) + if fields_and_values: + arguments = fields_and_values.get(arguments_field_name) + if not isinstance(arguments, (str, dict)): + raise ValueError( + f"Invalid {arguments_field_name} type {type(arguments)} for function call, expected str/dict" + ) + return { + "name": fields_and_values.get("name"), + "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, + } + return {} + + return _extract_function_invocation + + def _get_dict_converter(obj: Any, method_names: Optional[List[str]] = None) -> Union[Callable[[], Dict[str, Any]], None]: method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1 @@ -41,27 +65,3 @@ def _search(payload: Any, arguments_field_name: str) -> Dict[str, Any]: if result: return result return {} - - -def create_function_payload_extractor(arguments_field_name: str) -> Callable[[Any], Dict[str, Any]]: - """ - Extracts invocation payload from a given LLM completion containing function invocation. - """ - def _extract_function_invocation(payload: Any) -> Dict[str, Any]: - """ - Extract the function invocation details from the payload. - """ - fields_and_values = _search(payload, arguments_field_name) - if fields_and_values: - arguments = fields_and_values.get(arguments_field_name) - if not isinstance(arguments, (str, dict)): - raise ValueError( - f"Invalid {arguments_field_name} type {type(arguments)} for function call, expected str/dict" - ) - return { - "name": fields_and_values.get("name"), - "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, - } - return {} - - return _extract_function_invocation diff --git a/test/util/conftest.py b/test/util/conftest.py index 39caa850..05bd940a 100644 --- a/test/util/conftest.py +++ b/test/util/conftest.py @@ -10,7 +10,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from haystack_experimental.util.openapi import HttpClient, HttpClientError +from haystack_experimental.util.openapi import HttpClientError @pytest.fixture() @@ -18,7 +18,7 @@ def test_files_path(): return Path(__file__).parent.parent / "test_files" -class FastAPITestClient(HttpClient): +class FastAPITestClient: def __init__(self, app: FastAPI): self.app = app @@ -31,7 +31,7 @@ def strip_host(self, url: str) -> str: new_path += "?" + parsed_url.query return new_path - def send_request(self, request: dict) -> dict: + def __call__(self, request: dict) -> dict: # OAS spec will list a server URL, but FastAPI doesn't need it for local testing, in fact it will fail # if the URL has a host. So we strip it here. url = self.strip_host(request["url"]) diff --git a/test/util/test_openapi_client.py b/test/util/test_openapi_client.py index 3769bd0c..66a74056 100644 --- a/test/util/test_openapi_client.py +++ b/test/util/test_openapi_client.py @@ -73,7 +73,7 @@ class TestOpenAPI: def test_greet_mix_params_body(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_mix_params_body_app())) + request_sender=FastAPITestClient(create_greet_mix_params_body_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -88,7 +88,7 @@ def test_greet_mix_params_body(self, test_files_path): def test_greet_params_only(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_params_only_app())) + request_sender=FastAPITestClient(create_greet_params_only_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -103,7 +103,7 @@ def test_greet_params_only(self, test_files_path): def test_greet_request_body_only(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_request_body_only_app())) + request_sender=FastAPITestClient(create_greet_request_body_only_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", diff --git a/test/util/test_openapi_client_auth.py b/test/util/test_openapi_client_auth.py index 6797c2aa..014edde5 100644 --- a/test/util/test_openapi_client_auth.py +++ b/test/util/test_openapi_client_auth.py @@ -139,7 +139,7 @@ class TestOpenAPIAuth: def test_greet_api_key_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_api_key_auth_app()), + request_sender=FastAPITestClient(create_greet_api_key_auth_app()), credentials=ApiKeyAuthentication(API_KEY)) client = OpenAPIServiceClient(config) payload = { @@ -155,7 +155,7 @@ def test_greet_api_key_auth(self, test_files_path): def test_greet_basic_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_basic_auth_app()), + request_sender=FastAPITestClient(create_greet_basic_auth_app()), credentials=HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD)) client = OpenAPIServiceClient(config) payload = { @@ -171,7 +171,7 @@ def test_greet_basic_auth(self, test_files_path): def test_greet_api_key_query_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_api_key_query_app()), + request_sender=FastAPITestClient(create_greet_api_key_query_app()), credentials=ApiKeyAuthentication(API_KEY_QUERY)) client = OpenAPIServiceClient(config) payload = { @@ -188,7 +188,7 @@ def test_greet_api_key_query_auth(self, test_files_path): def test_greet_api_key_cookie_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_api_key_cookie_app()), + request_sender=FastAPITestClient(create_greet_api_key_cookie_app()), credentials=ApiKeyAuthentication(API_KEY_COOKIE)) client = OpenAPIServiceClient(config) @@ -205,7 +205,7 @@ def test_greet_api_key_cookie_auth(self, test_files_path): def test_greet_bearer_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_bearer_auth_app()), + request_sender=FastAPITestClient(create_greet_bearer_auth_app()), credentials=HTTPAuthentication(token=BEARER_TOKEN)) client = OpenAPIServiceClient(config) payload = { diff --git a/test/util/test_openapi_client_complex_request_body.py b/test/util/test_openapi_client_complex_request_body.py index 79549177..ccfb63df 100644 --- a/test/util/test_openapi_client_complex_request_body.py +++ b/test/util/test_openapi_client_complex_request_body.py @@ -59,7 +59,7 @@ def test_create_order(self, spec_file_path, test_files_path): path_element = "yaml" if spec_file_path.endswith(".yml") else "json" config = ClientConfiguration(openapi_spec=test_files_path / path_element / spec_file_path, - http_client=FastAPITestClient(create_order_app())) + request_sender=FastAPITestClient(create_order_app())) client = OpenAPIServiceClient(config) order_json = { diff --git a/test/util/test_openapi_client_complex_request_body_mixed.py b/test/util/test_openapi_client_complex_request_body_mixed.py index b769bbbf..4dc532fc 100644 --- a/test/util/test_openapi_client_complex_request_body_mixed.py +++ b/test/util/test_openapi_client_complex_request_body_mixed.py @@ -57,7 +57,7 @@ class TestPaymentProcess: def test_process_payment(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "json" / "complex_types_openapi_service.json", - http_client=FastAPITestClient(create_payment_app())) + request_sender=FastAPITestClient(create_payment_app())) client = OpenAPIServiceClient(config) payment_json = { diff --git a/test/util/test_openapi_client_edge_cases.py b/test/util/test_openapi_client_edge_cases.py index e1c36920..888580d9 100644 --- a/test/util/test_openapi_client_edge_cases.py +++ b/test/util/test_openapi_client_edge_cases.py @@ -13,7 +13,7 @@ class TestEdgeCases: def test_missing_operation_id(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_edge_cases.yml", - http_client=FastAPITestClient(None)) + request_sender=FastAPITestClient(None)) client = OpenAPIServiceClient(config) payload = { diff --git a/test/util/test_openapi_client_error_handling.py b/test/util/test_openapi_client_error_handling.py index 53e9478a..d826e33b 100644 --- a/test/util/test_openapi_client_error_handling.py +++ b/test/util/test_openapi_client_error_handling.py @@ -27,7 +27,7 @@ class TestErrorHandling: @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500]) def test_http_error_handling(self, test_files_path, status_code): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_error_handling.yml", - http_client=FastAPITestClient(create_error_handling_app())) + request_sender=FastAPITestClient(create_error_handling_app())) client = OpenAPIServiceClient(config) json_error = {"status_code": status_code} payload = {