Skip to content

Commit

Permalink
Refactor step 3
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Jun 5, 2024
1 parent 24e89d1 commit 5167f71
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 103 deletions.
100 changes: 36 additions & 64 deletions haystack_experimental/util/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from base64 import b64encode
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from urllib.parse import urlparse

import requests
import yaml
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from haystack_experimental.util.payload_extraction import create_function_payload_extractor
from haystack_experimental.util.schema_conversion import anthropic_converter, cohere_converter, openai_converter
Expand Down Expand Up @@ -105,63 +103,6 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]):
raise ValueError("HTTPAuthentication strategy received a non-HTTP security scheme.")


@dataclass
class HttpClientConfig:
"""Configuration for the HTTP client."""

timeout: int = 10
max_retries: int = 3
backoff_factor: float = 0.3
retry_on_status: set = field(default_factory=lambda: {500, 502, 503, 504})
default_headers: Dict[str, str] = field(default_factory=dict)


class HttpClient:
"""HTTP client for sending requests."""

def __init__(self, config: Optional[HttpClientConfig] = None):
self.config = config or HttpClientConfig()
self.session = requests.Session()
retries = Retry(
total=self.config.max_retries,
backoff_factor=self.config.backoff_factor,
status_forcelist=self.config.retry_on_status,
)
adapter = HTTPAdapter(max_retries=retries)
self.session.mount("http://", adapter)
self.session.mount("https://", adapter)
self.session.headers.update(self.config.default_headers)

def send_request(self, request: Dict[str, Any]) -> Any:
"""
Send an HTTP request using the provided request dictionary.
:param request: A dictionary containing the request details.
"""
url = request["url"]
headers = {**self.config.default_headers, **request.get("headers", {})}
try:
response = self.session.request(
request["method"],
url,
headers=headers,
params=request.get("params", {}),
json=request.get("json"),
auth=request.get("auth"),
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
logger.warning("HTTP error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"HTTP error occurred: {e}") from e
except requests.exceptions.RequestException as e:
logger.warning("Request error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"HTTP error occurred: {e}") from e
except Exception as e:
logger.warning("An error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"An error occurred: {e}") from e


class HttpClientError(Exception):
"""Exception raised for errors in the HTTP client."""

Expand Down Expand Up @@ -305,7 +246,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments
self,
openapi_spec: Union[str, Path, Dict[str, Any]],
credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None,
http_client: Optional[HttpClient] = None,
request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
llm_provider: Optional[str] = None,
): # noqa: PLR0913
if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec):
Expand All @@ -321,7 +262,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments
raise ValueError("Invalid OpenAPI specification format. Expected file path or dictionary.")

self.credentials = credentials
self.http_client = http_client or HttpClient(HttpClientConfig())
self.request_sender = request_sender
self.llm_provider = llm_provider or "openai"

def get_auth_config(self) -> AuthenticationStrategy:
Expand Down Expand Up @@ -395,7 +336,7 @@ class OpenAPIServiceClient:

def __init__(self, client_config: ClientConfiguration):
self.client_config = client_config
self.http_client = client_config.http_client
self.request_sender = client_config.request_sender or self._request_sender()

def invoke(self, function_payload: Any) -> Any:
"""
Expand All @@ -416,7 +357,38 @@ def invoke(self, function_payload: Any) -> Any:
operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name"))
request = self._build_request(operation, **fn_invocation_payload.get("arguments"))
self._apply_authentication(self.client_config.get_auth_config(), operation, request)
return self.http_client.send_request(request)
return self.request_sender(request)

def _request_sender(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
"""
Returns a callable that sends the request using the HTTP client.
"""
def send_request(request: Dict[str, Any]) -> Dict[str, Any]:
url = request["url"]
headers = {**request.get("headers", {})}
try:
response = requests.request(
request["method"],
url,
headers=headers,
params=request.get("params", {}),
json=request.get("json"),
auth=request.get("auth"),
timeout=10,
)
response.raise_for_status()
return response.json()
except requests.exceptions.HTTPError as e:
logger.warning("HTTP error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"HTTP error occurred: {e}") from e
except requests.exceptions.RequestException as e:
logger.warning("Request error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"HTTP error occurred: {e}") from e
except Exception as e:
logger.warning("An error occurred: %s while sending request to %s", e, url)
raise HttpClientError(f"An error occurred: {e}") from e

return send_request

def _build_request(self, operation: Operation, **kwargs) -> Any:
request = {
Expand Down
48 changes: 24 additions & 24 deletions haystack_experimental/util/payload_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@
from typing import Any, Callable, Dict, List, Optional, Union


def create_function_payload_extractor(arguments_field_name: str) -> Callable[[Any], Dict[str, Any]]:
"""
Extracts invocation payload from a given LLM completion containing function invocation.
"""
def _extract_function_invocation(payload: Any) -> Dict[str, Any]:
"""
Extract the function invocation details from the payload.
"""
fields_and_values = _search(payload, arguments_field_name)
if fields_and_values:
arguments = fields_and_values.get(arguments_field_name)
if not isinstance(arguments, (str, dict)):
raise ValueError(
f"Invalid {arguments_field_name} type {type(arguments)} for function call, expected str/dict"
)
return {
"name": fields_and_values.get("name"),
"arguments": json.loads(arguments) if isinstance(arguments, str) else arguments,
}
return {}

return _extract_function_invocation


def _get_dict_converter(obj: Any,
method_names: Optional[List[str]] = None) -> Union[Callable[[], Dict[str, Any]], None]:
method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1
Expand Down Expand Up @@ -41,27 +65,3 @@ def _search(payload: Any, arguments_field_name: str) -> Dict[str, Any]:
if result:
return result
return {}


def create_function_payload_extractor(arguments_field_name: str) -> Callable[[Any], Dict[str, Any]]:
"""
Extracts invocation payload from a given LLM completion containing function invocation.
"""
def _extract_function_invocation(payload: Any) -> Dict[str, Any]:
"""
Extract the function invocation details from the payload.
"""
fields_and_values = _search(payload, arguments_field_name)
if fields_and_values:
arguments = fields_and_values.get(arguments_field_name)
if not isinstance(arguments, (str, dict)):
raise ValueError(
f"Invalid {arguments_field_name} type {type(arguments)} for function call, expected str/dict"
)
return {
"name": fields_and_values.get("name"),
"arguments": json.loads(arguments) if isinstance(arguments, str) else arguments,
}
return {}

return _extract_function_invocation
6 changes: 3 additions & 3 deletions test/util/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient

from haystack_experimental.util.openapi import HttpClient, HttpClientError
from haystack_experimental.util.openapi import HttpClientError


@pytest.fixture()
def test_files_path():
return Path(__file__).parent.parent / "test_files"


class FastAPITestClient(HttpClient):
class FastAPITestClient:

def __init__(self, app: FastAPI):
self.app = app
Expand All @@ -31,7 +31,7 @@ def strip_host(self, url: str) -> str:
new_path += "?" + parsed_url.query
return new_path

def send_request(self, request: dict) -> dict:
def __call__(self, request: dict) -> dict:
# OAS spec will list a server URL, but FastAPI doesn't need it for local testing, in fact it will fail
# if the URL has a host. So we strip it here.
url = self.strip_host(request["url"])
Expand Down
6 changes: 3 additions & 3 deletions test/util/test_openapi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class TestOpenAPI:

def test_greet_mix_params_body(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml",
http_client=FastAPITestClient(create_greet_mix_params_body_app()))
request_sender=FastAPITestClient(create_greet_mix_params_body_app()))
client = OpenAPIServiceClient(config)
payload = {
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
Expand All @@ -88,7 +88,7 @@ def test_greet_mix_params_body(self, test_files_path):

def test_greet_params_only(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml",
http_client=FastAPITestClient(create_greet_params_only_app()))
request_sender=FastAPITestClient(create_greet_params_only_app()))
client = OpenAPIServiceClient(config)
payload = {
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
Expand All @@ -103,7 +103,7 @@ def test_greet_params_only(self, test_files_path):

def test_greet_request_body_only(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml",
http_client=FastAPITestClient(create_greet_request_body_only_app()))
request_sender=FastAPITestClient(create_greet_request_body_only_app()))
client = OpenAPIServiceClient(config)
payload = {
"id": "call_NJr1NBz2Th7iUWJpRIJZoJIA",
Expand Down
10 changes: 5 additions & 5 deletions test/util/test_openapi_client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class TestOpenAPIAuth:

def test_greet_api_key_auth(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml",
http_client=FastAPITestClient(create_greet_api_key_auth_app()),
request_sender=FastAPITestClient(create_greet_api_key_auth_app()),
credentials=ApiKeyAuthentication(API_KEY))
client = OpenAPIServiceClient(config)
payload = {
Expand All @@ -155,7 +155,7 @@ def test_greet_api_key_auth(self, test_files_path):

def test_greet_basic_auth(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml",
http_client=FastAPITestClient(create_greet_basic_auth_app()),
request_sender=FastAPITestClient(create_greet_basic_auth_app()),
credentials=HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD))
client = OpenAPIServiceClient(config)
payload = {
Expand All @@ -171,7 +171,7 @@ def test_greet_basic_auth(self, test_files_path):

def test_greet_api_key_query_auth(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml",
http_client=FastAPITestClient(create_greet_api_key_query_app()),
request_sender=FastAPITestClient(create_greet_api_key_query_app()),
credentials=ApiKeyAuthentication(API_KEY_QUERY))
client = OpenAPIServiceClient(config)
payload = {
Expand All @@ -188,7 +188,7 @@ def test_greet_api_key_query_auth(self, test_files_path):
def test_greet_api_key_cookie_auth(self, test_files_path):

config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml",
http_client=FastAPITestClient(create_greet_api_key_cookie_app()),
request_sender=FastAPITestClient(create_greet_api_key_cookie_app()),
credentials=ApiKeyAuthentication(API_KEY_COOKIE))

client = OpenAPIServiceClient(config)
Expand All @@ -205,7 +205,7 @@ def test_greet_api_key_cookie_auth(self, test_files_path):

def test_greet_bearer_auth(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml",
http_client=FastAPITestClient(create_greet_bearer_auth_app()),
request_sender=FastAPITestClient(create_greet_bearer_auth_app()),
credentials=HTTPAuthentication(token=BEARER_TOKEN))
client = OpenAPIServiceClient(config)
payload = {
Expand Down
2 changes: 1 addition & 1 deletion test/util/test_openapi_client_complex_request_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_create_order(self, spec_file_path, test_files_path):
path_element = "yaml" if spec_file_path.endswith(".yml") else "json"

config = ClientConfiguration(openapi_spec=test_files_path / path_element / spec_file_path,
http_client=FastAPITestClient(create_order_app()))
request_sender=FastAPITestClient(create_order_app()))

client = OpenAPIServiceClient(config)
order_json = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TestPaymentProcess:

def test_process_payment(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "json" / "complex_types_openapi_service.json",
http_client=FastAPITestClient(create_payment_app()))
request_sender=FastAPITestClient(create_payment_app()))
client = OpenAPIServiceClient(config)

payment_json = {
Expand Down
2 changes: 1 addition & 1 deletion test/util/test_openapi_client_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestEdgeCases:

def test_missing_operation_id(self, test_files_path):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_edge_cases.yml",
http_client=FastAPITestClient(None))
request_sender=FastAPITestClient(None))
client = OpenAPIServiceClient(config)

payload = {
Expand Down
2 changes: 1 addition & 1 deletion test/util/test_openapi_client_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TestErrorHandling:
@pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500])
def test_http_error_handling(self, test_files_path, status_code):
config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_error_handling.yml",
http_client=FastAPITestClient(create_error_handling_app()))
request_sender=FastAPITestClient(create_error_handling_app()))
client = OpenAPIServiceClient(config)
json_error = {"status_code": status_code}
payload = {
Expand Down

0 comments on commit 5167f71

Please sign in to comment.