From 472b2de5313bccdc9cba0ead422ec80f115d5777 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 28 May 2024 09:36:28 +0200 Subject: [PATCH 01/40] Initial openapi impl --- .../components/connectors/__init__.py | 7 + .../components/connectors/openapi.py | 125 +++ .../components/converters/__init__.py | 7 + .../components/converters/openapi.py | 83 ++ haystack_experimental/util/__init__.py | 3 + haystack_experimental/util/openapi.py | 968 ++++++++++++++++++ pyproject.toml | 5 +- test/components/connectors/__init__.py | 3 + test/components/connectors/test_openapi.py | 114 +++ test/components/converters/__init__.py | 3 + test/components/converters/test_openapi.py | 246 +++++ .../json/complex_types_openai_spec.json | 64 ++ .../json/complex_types_openapi_service.json | 103 ++ .../json/openapi_order_service.json | 109 ++ .../json/serperdev_openapi_spec.json | 62 ++ test/test_files/yaml/github_compare.yml | 438 ++++++++ test/test_files/yaml/openapi_edge_cases.yml | 13 + .../yaml/openapi_error_handling.yml | 24 + .../yaml/openapi_greeting_service.yml | 272 +++++ .../test_files/yaml/openapi_order_service.yml | 75 ++ test/test_files/yaml/serper.yml | 39 + test/util/__init__.py | 3 + test/util/conftest.py | 59 ++ test/util/test_openapi_client.py | 129 +++ test/util/test_openapi_client_auth.py | 238 +++++ ...est_openapi_client_complex_request_body.py | 87 ++ ...enapi_client_complex_request_body_mixed.py | 90 ++ test/util/test_openapi_client_edge_cases.py | 38 + .../test_openapi_client_error_handling.py | 46 + test/util/test_openapi_client_live.py | 73 ++ .../test_openapi_client_live_anthropic.py | 69 ++ test/util/test_openapi_client_live_cohere.py | 73 ++ test/util/test_openapi_client_live_openai.py | 99 ++ test/util/test_openapi_cohere_conversion.py | 79 ++ test/util/test_openapi_openai_conversion.py | 104 ++ test/util/test_openapi_spec.py | 127 +++ 36 files changed, 4076 insertions(+), 1 deletion(-) create mode 100644 haystack_experimental/components/connectors/__init__.py create mode 100644 haystack_experimental/components/connectors/openapi.py create mode 100644 haystack_experimental/components/converters/__init__.py create mode 100644 haystack_experimental/components/converters/openapi.py create mode 100644 haystack_experimental/util/__init__.py create mode 100644 haystack_experimental/util/openapi.py create mode 100644 test/components/connectors/__init__.py create mode 100644 test/components/connectors/test_openapi.py create mode 100644 test/components/converters/__init__.py create mode 100644 test/components/converters/test_openapi.py create mode 100644 test/test_files/json/complex_types_openai_spec.json create mode 100644 test/test_files/json/complex_types_openapi_service.json create mode 100644 test/test_files/json/openapi_order_service.json create mode 100644 test/test_files/json/serperdev_openapi_spec.json create mode 100644 test/test_files/yaml/github_compare.yml create mode 100644 test/test_files/yaml/openapi_edge_cases.yml create mode 100644 test/test_files/yaml/openapi_error_handling.yml create mode 100644 test/test_files/yaml/openapi_greeting_service.yml create mode 100644 test/test_files/yaml/openapi_order_service.yml create mode 100644 test/test_files/yaml/serper.yml create mode 100644 test/util/__init__.py create mode 100644 test/util/conftest.py create mode 100644 test/util/test_openapi_client.py create mode 100644 test/util/test_openapi_client_auth.py create mode 100644 test/util/test_openapi_client_complex_request_body.py create mode 100644 test/util/test_openapi_client_complex_request_body_mixed.py create mode 100644 test/util/test_openapi_client_edge_cases.py create mode 100644 test/util/test_openapi_client_error_handling.py create mode 100644 test/util/test_openapi_client_live.py create mode 100644 test/util/test_openapi_client_live_anthropic.py create mode 100644 test/util/test_openapi_client_live_cohere.py create mode 100644 test/util/test_openapi_client_live_openai.py create mode 100644 test/util/test_openapi_cohere_conversion.py create mode 100644 test/util/test_openapi_openai_conversion.py create mode 100644 test/util/test_openapi_spec.py diff --git a/haystack_experimental/components/connectors/__init__.py b/haystack_experimental/components/connectors/__init__.py new file mode 100644 index 00000000..9c3d11ec --- /dev/null +++ b/haystack_experimental/components/connectors/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.components.connectors.openapi import OpenAPIServiceConnector + +__all__ = ["OpenAPIServiceConnector"] diff --git a/haystack_experimental/components/connectors/openapi.py b/haystack_experimental/components/connectors/openapi.py new file mode 100644 index 00000000..a9afce5c --- /dev/null +++ b/haystack_experimental/components/connectors/openapi.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Any, Dict, List, Optional, Union + +from haystack import component, logging +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, validate_provider + +logger = logging.getLogger(__name__) + + +@component +class OpenAPIServiceConnector: + """ + The `OpenAPIServiceConnector` component connects the Haystack framework to OpenAPI services. + + It integrates with `ChatMessage` dataclass, where the payload in messages is used to determine the method to be + called and the parameters to be passed. The response from the service is returned as a `ChatMessage`. + + Function calling payloads from OpenAI, Anthropic, and Cohere LLMs are supported. + + Before using this component, users usually resolve function calling function definitions with a help of + `OpenAPIServiceToFunctions` component. + + The example below demonstrates how to use the `OpenAPIServiceConnector` to invoke a method on a + https://serper.dev/ service specified via OpenAPI specification. + + Note, however, that `OpenAPIServiceConnector` is usually not meant to be used directly, but rather as part of a + pipeline that includes the `OpenAPIServiceToFunctions` component and an `OpenAIChatGenerator` component using LLM + with the function calling capabilities. In the example below we use the function calling payload directly, but in a + real-world scenario, the function calling payload would usually be generated by the `OpenAIChatGenerator` + component. + + Usage example: + + ```python + import json + import requests + + from haystack_experimental.components.connectors import OpenAPIServiceConnector + from haystack.dataclasses import ChatMessage + + + fc_payload = [{'function': {'arguments': '{"q": "Why was Sam Altman ousted from OpenAI?"}', 'name': 'search'}, + 'id': 'call_PmEBYvZ7mGrQP5PUASA5m9wO', 'type': 'function'}] + + serper_token = + serperdev_openapi_spec = json.loads(requests.get("https://bit.ly/serper_dev_spec").text) + service_connector = OpenAPIServiceConnector() + result = service_connector.run(messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], + service_openapi_spec=serperdev_openapi_spec, service_credentials=serper_token) + print(result) + + >> {'service_response': [ChatMessage(content='{"searchParameters": {"q": "Why was Sam Altman ousted from OpenAI?", + >> "type": "search", "engine": "google"}, "answerBox": {"snippet": "Concerns over AI safety and OpenAI\'s role + >> in protecting were at the center of Altman\'s brief ouster from the company."... + ``` + + """ + + def __init__(self, provider: Optional[str] = None): + """ + Initializes the OpenAPIServiceConnector instance. + """ + self.llm_provider = validate_provider(provider or "openai") + + @component.output_types(service_response=Dict[str, Any]) + def run( + self, + messages: List[ChatMessage], + service_openapi_spec: Dict[str, Any], + service_credentials: Optional[Union[dict, str]] = None, + ) -> Dict[str, List[ChatMessage]]: + """ + Processes a list of chat messages to invoke a method on an OpenAPI service. + + It parses the last message in the list, expecting it to contain an OpenAI function calling descriptor + (name & parameters) in JSON format. + + :param messages: A list of `ChatMessage` objects containing the messages to be processed. The last message + should contain the function invocation payload in OpenAI function calling format. See the example in the class + docstring for the expected format. + :param service_openapi_spec: The OpenAPI JSON specification object of the service to be invoked. + :param service_credentials: The credentials to be used for authentication with the service. + Currently, only the http and apiKey OpenAPI security schemes are supported. + + :return: A dictionary with the following keys: + - `service_response`: a list of `ChatMessage` objects, each containing the response from the service. The + response is in JSON format, and the `content` attribute of the `ChatMessage` + contains the JSON string. + + :raises ValueError: If the last message is not from the assistant or if it does not contain the correct + payload to invoke a method on the service. + """ + + last_message = messages[-1] + if not last_message.is_from(ChatRole.ASSISTANT): + raise ValueError(f"{last_message} is not from the assistant.") + if not last_message.content: + raise ValueError("Function calling message content is empty.") + + builder = ClientConfigurationBuilder() + config_openapi = ( + builder.with_openapi_spec(service_openapi_spec) + .with_credentials(service_credentials or {}) + .with_provider(self.llm_provider) + .build() + ) + logger.debug(f"Invoking service {config_openapi.get_openapi_spec().get_name()} with {last_message.content}") + openapi_service = OpenAPIServiceClient(config_openapi) + try: + payload = ( + json.loads(last_message.content) if isinstance(last_message.content, str) else last_message.content + ) + service_response = openapi_service.invoke(payload) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error(f"Error invoking OpenAPI endpoint. Error: {e}") + service_response = {"error": str(e)} + response_messages = [ChatMessage.from_user(json.dumps(service_response))] + + return {"service_response": response_messages} diff --git a/haystack_experimental/components/converters/__init__.py b/haystack_experimental/components/converters/__init__.py new file mode 100644 index 00000000..6fd23f7d --- /dev/null +++ b/haystack_experimental/components/converters/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.components.converters.openapi import OpenAPIServiceToFunctions + +__all__ = ["OpenAPIServiceToFunctions"] diff --git a/haystack_experimental/components/converters/openapi.py b/haystack_experimental/components/converters/openapi.py new file mode 100644 index 00000000..17d0eb77 --- /dev/null +++ b/haystack_experimental/components/converters/openapi.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from haystack import component, logging +from haystack.dataclasses.byte_stream import ByteStream + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, validate_provider + +logger = logging.getLogger(__name__) + + +@component +class OpenAPIServiceToFunctions: + """ + Converts OpenAPI service schemas to a format suitable for OpenAI, Anthropic, or Cohere function calling. + + The definition must respect OpenAPI specification 3.0.0 or higher. + It can be specified in JSON or YAML format. + Each function must have: + - unique operationId + - description + - requestBody and/or parameters + - schema for the requestBody and/or parameters + For more details on OpenAPI specification see the + [official documentation](https://github.com/OAI/OpenAPI-Specification). + + Usage example: + ```python + from haystack_experimental.components.converters import OpenAPIServiceToFunctions + + converter = OpenAPIServiceToFunctions() + result = converter.run(sources=["path/to/openapi_definition.yaml"]) + assert result["functions"] + ``` + """ + + MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 + + def __init__(self, provider: Optional[str] = None): + """ + Create an OpenAPIServiceToFunctions component. + + :param provider: The LLM provider to use, defaults to "openai". + """ + self.llm_provider = validate_provider(provider or "openai") + + @component.output_types(functions=List[Dict[str, Any]], openapi_specs=List[Dict[str, Any]]) + def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, Any]: + """ + Converts OpenAPI definitions into LLM specific function calling format. + + :param sources: + File paths or ByteStream objects of OpenAPI definitions (in JSON or YAML format). + + :returns: + A dictionary with the following keys: + - functions: Function definitions in JSON object format + - openapi_specs: OpenAPI specs in JSON/YAML object format with resolved references + + :raises RuntimeError: + If the OpenAPI definitions cannot be downloaded or processed. + :raises ValueError: + If the source type is not recognized or no functions are found in the OpenAPI definitions. + """ + all_extracted_fc_definitions: List[Dict[str, Any]] = [] + all_openapi_specs = [] + + builder = ClientConfigurationBuilder() + for source in sources: + source = source.to_string() if isinstance(source, ByteStream) else source + # to get tools definitions all we need is the openapi spec + config_openapi = builder.with_openapi_spec(source).with_provider(self.llm_provider).build() + + all_extracted_fc_definitions.extend(config_openapi.get_tools_definitions()) + all_openapi_specs.append(config_openapi.get_openapi_spec().to_dict(resolve_references=True)) + if not all_extracted_fc_definitions: + logger.warning("No OpenAI function definitions extracted from the provided OpenAPI specification sources.") + + return {"functions": all_extracted_fc_definitions, "openapi_specs": all_openapi_specs} diff --git a/haystack_experimental/util/__init__.py b/haystack_experimental/util/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/haystack_experimental/util/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/haystack_experimental/util/openapi.py b/haystack_experimental/util/openapi.py new file mode 100644 index 00000000..552892ca --- /dev/null +++ b/haystack_experimental/util/openapi.py @@ -0,0 +1,968 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import json +import logging +import os +from base64 import b64encode +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Union, runtime_checkable +from urllib.parse import urlparse + +import jsonref +import requests +import yaml +from requests.adapters import HTTPAdapter +from urllib3 import Retry + +VALID_HTTP_METHODS = ["get", "put", "post", "delete", "options", "head", "patch", "trace"] + +MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 + +logger = logging.getLogger(__name__) + + +def validate_provider(provider: str) -> str: + """ + Check if the selected provider is supported. + + :param provider: The selected provider to validate. + :return: The validated provider. + :raises ValueError: If the selected provider is not supported. + """ + available_providers = ["openai", "anthropic", "cohere"] + if provider not in available_providers: + raise ValueError(f"LLM provider {provider} is not supported. Available providers: {available_providers}") + return provider + + +@runtime_checkable +class AuthenticationStrategy(Protocol): + """ + Represents an authentication strategy that can be applied to an HTTP request. + """ + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + + +class PassThroughAuthentication(AuthenticationStrategy): + """No-op authentication strategy that does nothing.""" + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + No-op authentication strategy that does nothing. + """ + + +@dataclass +class ApiKeyAuthentication(AuthenticationStrategy): + """ API key authentication strategy.""" + api_key: Optional[str] = None + + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the API key authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + if security_scheme["in"] == "header": + request.setdefault("headers", {})[security_scheme["name"]] = self.api_key + elif security_scheme["in"] == "query": + request.setdefault("params", {})[security_scheme["name"]] = self.api_key + elif security_scheme["in"] == "cookie": + request.setdefault("cookies", {})[security_scheme["name"]] = self.api_key + else: + raise ValueError( + f"Unsupported apiKey authentication location: {security_scheme['in']}, " + f"must be one of 'header', 'query', or 'cookie'" + ) + + +@dataclass +class HTTPAuthentication(AuthenticationStrategy): + """ HTTP authentication strategy.""" + username: Optional[str] = None + password: Optional[str] = None + token: Optional[str] = None + + def __post_init__(self): + if not self.token and (not self.username or not self.password): + raise ValueError( + "For HTTP Basic Auth, both username and password must be provided. " + "For Bearer Auth, a token must be provided." + ) + + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the HTTP authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + if security_scheme["type"] == "http": + if security_scheme["scheme"].lower() == "basic": + if not self.username or not self.password: + raise ValueError("Username and password must be provided for Basic Auth.") + credentials = f"{self.username}:{self.password}" + encoded_credentials = b64encode(credentials.encode("utf-8")).decode("utf-8") + request.setdefault("headers", {})["Authorization"] = f"Basic {encoded_credentials}" + elif security_scheme["scheme"].lower() == "bearer": + if not self.token: + raise ValueError("Token must be provided for Bearer Auth.") + request.setdefault("headers", {})["Authorization"] = f"Bearer {self.token}" + else: + raise ValueError(f"Unsupported HTTP authentication scheme: {security_scheme['scheme']}") + else: + raise ValueError("HTTPAuthentication strategy received a non-HTTP security scheme.") + + +@dataclass +class HttpClientConfig: + """ Configuration for the HTTP client. """ + timeout: int = 10 + max_retries: int = 3 + backoff_factor: float = 0.3 + retry_on_status: set = field(default_factory=lambda: {500, 502, 503, 504}) + default_headers: Dict[str, str] = field(default_factory=dict) + + +class HttpClient: + """ HTTP client for sending requests. """ + def __init__(self, config: Optional[HttpClientConfig] = None): + self.config = config or HttpClientConfig() + self.session = requests.Session() + self._initialize_session() + + def _initialize_session(self) -> None: + retries = Retry( + total=self.config.max_retries, + backoff_factor=self.config.backoff_factor, + status_forcelist=self.config.retry_on_status, + ) + adapter = HTTPAdapter(max_retries=retries) + self.session.mount("http://", adapter) + self.session.mount("https://", adapter) + self.session.headers.update(self.config.default_headers) + + def send_request(self, request: Dict[str, Any]) -> Any: + """ + Send an HTTP request using the provided request dictionary. + + :param request: A dictionary containing the request details. + """ + url = request["url"] + method = request["method"] + headers = {**self.config.default_headers, **request.get("headers", {})} + params = request.get("params", {}) + json_data = request.get("json") + auth = request.get("auth") + try: + response = self.session.request(method, url, headers=headers, params=params, json=json_data, auth=auth) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + logger.warning("HTTP error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except requests.exceptions.RequestException as e: + logger.warning("Request error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except Exception as e: + logger.warning("An error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"An error occurred: {e}") from e + + +class HttpClientError(Exception): + """Exception raised for errors in the HTTP client.""" + + +class Operation: + """ Represents an operation in an OpenAPI specification.""" + def __init__(self, path: str, method: str, operation_dict: Dict[str, Any], spec_dict: Dict[str, Any]): + if method.lower() not in VALID_HTTP_METHODS: + raise ValueError(f"Invalid HTTP method: {method}") + self.path = path + self.method = method.lower() + self.operation_dict = operation_dict + self.spec_dict = spec_dict + + def get_parameters(self, location: Optional[Literal["header", "query", "path"]] = None) -> List[Dict[str, Any]]: + """ + Get the parameters for the operation. + + :param location: The location of the parameters to retrieve. If None, all parameters are returned. + """ + parameters = self.operation_dict.get("parameters", []) + path_item = self.spec_dict.get("paths", {}).get(self.path, {}) + parameters.extend(path_item.get("parameters", [])) + if location: + return [param for param in parameters if param["in"] == location] + return parameters + + def get_request_body(self) -> Dict[str, Any]: + """ + Get the request body for the operation. + """ + return self.operation_dict.get("requestBody", {}) + + def get_responses(self) -> Dict[str, Any]: + """ + Get the responses for the operation. + """ + return self.operation_dict.get("responses", {}) + + def get_security_requirements(self) -> List[Dict[str, List[str]]]: + """ + Get the security requirements for the operation. + """ + security_requirements = self.operation_dict.get("security", []) + if not security_requirements: + security_requirements = self.spec_dict.get("security", []) + return security_requirements + + def get_server_url(self) -> str: + """ + Get the server URL for the operation. + """ + servers = self.operation_dict.get("servers", []) + if not servers: + servers = self.spec_dict.get("servers", []) + if servers: + return servers[0].get("url", "") + return "" + + def get_field(self, key: str, default: Any = None) -> Any: + """ + Get a field from the operation dictionary. + """ + return self.operation_dict.get(key, default) + + +class OpenAPISpecification: + """ Represents an OpenAPI specification.""" + def __init__(self, spec_dict: Dict[str, Any]): + if not isinstance(spec_dict, Dict): + raise ValueError(f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}") + # just a crude sanity check, by no means a full validation + if "openapi" not in spec_dict or "paths" not in spec_dict or "servers" not in spec_dict: + raise ValueError( + "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.", spec_dict + ) + self.spec_dict = spec_dict + + @classmethod + def from_dict(cls, spec_dict: Dict[str, Any]) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a dictionary. + """ + parser = cls(spec_dict) + return parser + + @classmethod + def from_str(cls, content: str) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a string. + """ + try: + loaded_spec = json.loads(content) + except json.JSONDecodeError: + try: + loaded_spec = yaml.safe_load(content) + except yaml.YAMLError as e: + raise ValueError("Content cannot be decoded as JSON or YAML: " + str(e)) from e + return cls(loaded_spec) + + @classmethod + def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a file. + """ + with open(spec_file, encoding="utf-8") as file: + content = file.read() + return cls.from_str(content) + + @classmethod + def from_url(cls, url: str) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a URL. + """ + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + content = response.text + except requests.RequestException as e: + raise ConnectionError(f"Failed to fetch the specification from URL: {url}. {e!s}") from e + return cls.from_str(content) + + def get_name(self) -> str: + """ + Get the title of the OpenAPI specification. + """ + return self.spec_dict.get("info", {}).get("title", "") + + def get_paths(self) -> Dict[str, Dict[str, Any]]: + """ + Get the paths from the OpenAPI specification. + """ + return self.spec_dict.get("paths", {}) + + def get_operation(self, path: str, method: Optional[str] = None) -> Operation: + """ + Retrieve an operation from the OpenAPI specification. + """ + path_item = self.get_paths().get(path, {}) + return self.get_operation_item(path, path_item, method) + + def find_operation_by_path_substring(self, path_partial: str, method: Optional[str] = None) -> Operation: + """ + Find an operation by a substring of the path. + """ + for path, path_item in self.get_paths().items(): + if path_partial in path: + return self.get_operation_item(path, path_item, method) + raise ValueError(f"No operation found with path containing {path_partial}") + + def find_operation_by_id(self, op_id: str, method: Optional[str] = None) -> Operation: + """ + Find an operation by operationId. + """ + for path, path_item in self.get_paths().items(): + op: Operation = self.get_operation_item(path, path_item, method) + if op_id in op.get_field("operationId", ""): + return self.get_operation_item(path, path_item, method) + raise ValueError(f"No operation found with operationId {op_id}") + + def get_operation_item(self, path: str, path_item: Dict[str, Any], method: Optional[str] = None) -> Operation: + """ + Get an operation item from the OpenAPI specification. + + :param path: The path of the operation. + :param path_item: The path item from the OpenAPI specification. + :param method: The HTTP method of the operation. + """ + if method: + operation_dict = path_item.get(method.lower(), {}) + if not operation_dict: + raise ValueError(f"No operation found for method {method} at path {path}") + return Operation(path, method.lower(), operation_dict, self.spec_dict) + if len(path_item) == 1: + method, operation_dict = next(iter(path_item.items())) + return Operation(path, method, operation_dict, self.spec_dict) + if len(path_item) > 1: + raise ValueError(f"Multiple operations found at path {path}, method parameter is required.") + raise ValueError(f"No operations found at path {path}.") + + def get_operations(self) -> List[Operation]: + """ + Get all operations from the OpenAPI specification. + """ + operations = [] + for path, path_item in self.get_paths().items(): + for method, operation_dict in path_item.items(): + if method.lower() in VALID_HTTP_METHODS: + operations.append(Operation(path, method, operation_dict, self.spec_dict)) + return operations + + def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: + """ + Get the security schemes from the OpenAPI specification. + """ + components = self.spec_dict.get("components", {}) + return components.get("securitySchemes", {}) + + def to_dict(self, *, resolve_references: Optional[bool] = False) -> Dict[str, Any]: + """ + Converts the OpenAPI specification to a dictionary format. + + Optionally resolves all $ref references within the spec, returning a fully resolved specification + dictionary if `resolve_references` is set to True. + + :param resolve_references: If True, resolve references in the specification. + :return: A dictionary representation of the OpenAPI specification, optionally fully resolved. + """ + return jsonref.replace_refs(self.spec_dict, proxies=False) if resolve_references else self.spec_dict + + +class ClientConfiguration: + """ Configuration for the OpenAPI client. """ + + def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments + self, + openapi_spec: Union[str, Path, Dict[str, Any]], + credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None, + http_client: Optional[HttpClient] = None, + http_client_config: Optional[HttpClientConfig] = None, + llm_provider: Optional[str] = None, + ): # noqa: PLR0913 + if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): + self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) + elif isinstance(openapi_spec, dict): + self.openapi_spec = OpenAPISpecification.from_dict(openapi_spec) + elif isinstance(openapi_spec, str): + if self.is_valid_http_url(openapi_spec): + self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) + else: + self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) + else: + raise ValueError("Invalid OpenAPI specification format. Expected file path or dictionary.") + + self.credentials = credentials + self.http_client = http_client or HttpClient(http_client_config) + self.http_client_config = http_client_config or HttpClientConfig() + self.llm_provider = llm_provider or "openai" + + def get_openapi_spec(self) -> OpenAPISpecification: + """ + Get the OpenAPI specification. + """ + return self.openapi_spec + + def get_http_client(self) -> HttpClient: + """ + Get the HTTP client. + """ + return self.http_client + + def get_http_client_config(self) -> HttpClientConfig: + """ + Get the HTTP client configuration. + """ + return self.http_client_config + + def get_auth_config(self) -> AuthenticationStrategy: + """ + Get the authentication configuration. + """ + if not self.credentials: + return PassThroughAuthentication() + if isinstance(self.credentials, AuthenticationStrategy): + return self.credentials + security_schemes = self.openapi_spec.get_security_schemes() + if isinstance(self.credentials, str): + return self._create_authentication_from_string(self.credentials, security_schemes) + if isinstance(self.credentials, dict): + return self._create_authentication_from_dict(self.credentials) + raise ValueError(f"Unsupported credentials type: {type(self.credentials)}") + + def get_tools_definitions(self) -> List[Dict[str, Any]]: + """ + Get the tools definitions used as tools LLM parameter. + """ + provider_to_converter = {"anthropic": anthropic_converter, "cohere": cohere_converter} + converter = provider_to_converter.get(self.llm_provider, openai_converter) + return converter(self.openapi_spec) + + def get_payload_extractor(self): + """ + Get the payload extractor for the LLM provider. + """ + provider_to_arguments_field_name = {"anthropic": "input", "cohere": "parameters"} # add more providers here + # default to OpenAI "arguments" + arguments_field_name = provider_to_arguments_field_name.get(self.llm_provider, "arguments") + return LLMFunctionPayloadExtractor(arguments_field_name=arguments_field_name) + + def _create_authentication_from_string( + self, credentials: str, security_schemes: Dict[str, Any] + ) -> AuthenticationStrategy: + for scheme in security_schemes.values(): + if scheme["type"] == "apiKey": + return ApiKeyAuthentication(api_key=credentials) + if scheme["type"] == "http": + return HTTPAuthentication(token=credentials) + if scheme["type"] == "oauth2": + raise NotImplementedError("OAuth2 authentication is not yet supported.") + raise ValueError(f"Unable to create authentication from provided credentials: {credentials}") + + def _create_authentication_from_dict(self, credentials: Dict[str, Any]) -> AuthenticationStrategy: + if "username" in credentials and "password" in credentials: + return HTTPAuthentication(username=credentials["username"], password=credentials["password"]) + if "api_key" in credentials: + return ApiKeyAuthentication(api_key=credentials["api_key"]) + if "token" in credentials: + return HTTPAuthentication(token=credentials["token"]) + if "access_token" in credentials: + raise NotImplementedError("OAuth2 authentication is not yet supported.") + raise ValueError("Unable to create authentication from provided credentials: {credentials}") + + def is_valid_http_url(self, url: str) -> bool: + """Check if a URL is a valid HTTP/HTTPS URL.""" + r = urlparse(url) + return all([r.scheme in ["http", "https"], r.netloc]) + + +class LLMFunctionPayloadExtractor: + """ + Implements a recursive search for extracting LLM generated function payloads. + """ + def __init__(self, arguments_field_name: str): + self.arguments_field_name = arguments_field_name + + def extract_function_invocation(self, payload: Any) -> Dict[str, Any]: + """ + Extract the function invocation details from the payload. + """ + fields_and_values = self._search(payload) + if fields_and_values: + arguments = fields_and_values.get(self.arguments_field_name) + if not isinstance(arguments, (str, dict)): + raise ValueError( + f"Invalid {self.arguments_field_name} type {type(arguments)} for function call, expected str/dict" + ) + return { + "name": fields_and_values.get("name"), + "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, + } + return {} + + def _required_fields(self) -> List[str]: + return ["name", self.arguments_field_name] + + def _search(self, payload: Any) -> Dict[str, Any]: + if self._is_primitive(payload): + return {} + if dict_converter := self._get_dict_converter(payload): + payload = dict_converter() + elif dataclasses.is_dataclass(payload): + payload = dataclasses.asdict(payload) + if isinstance(payload, dict): + if all(field in payload for field in self._required_fields()): + # this is the payload we are looking for + return payload + for value in payload.values(): + result = self._search(value) + if result: + return result + elif isinstance(payload, list): + for item in payload: + result = self._search(item) + if result: + return result + return {} + + def _get_dict_converter( + self, obj: Any, method_names: Optional[List[str]] = None + ) -> Union[Callable[[], Dict[str, Any]], None]: + method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1 + for attr in method_names: + if hasattr(obj, attr) and callable(getattr(obj, attr)): + return getattr(obj, attr) + return None + + def _is_primitive(self, obj) -> bool: + return isinstance(obj, (int, float, str, bool, type(None))) + + +class ClientConfigurationBuilder: + """ + ClientConfigurationBuilder provides a fluent interface for constructing a `ClientConfiguration`. + + This builder allows for the step-by-step configuration of all necessary components to interact with an + API defined by an OpenAPI specification. + """ + + def __init__(self): + self._openapi_spec: Union[str, Path, Dict[str, Any], None] = None + self._credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None + self._http_client: Optional[HttpClient] = None + self._http_client_config: Optional[HttpClientConfig] = None + self._llm_provider: Optional[str] = None + + def with_openapi_spec(self, openapi_spec: Union[str, Path, Dict[str, Any]]) -> "ClientConfigurationBuilder": + """ + Sets the OpenAPI specification for the configuration. + + :param openapi_spec: The OpenAPI specification as a URL, file path, or dictionary. + :return: The instance of this builder to allow for method chaining. + """ + self._openapi_spec = openapi_spec + return self + + def with_credentials( + self, credentials: Union[str, Dict[str, Any], AuthenticationStrategy] + ) -> "ClientConfigurationBuilder": + """ + Specifies the credentials used for authenticating requests made by the client. + + :param credentials: Credentials as a string, dictionary, or an AuthenticationStrategy instance. + :return: The instance of this builder to allow for method chaining. + """ + self._credentials = credentials + return self + + def with_http_client(self, http_client: HttpClient) -> "ClientConfigurationBuilder": + """ + Specifies the HTTP client to be used for making API calls. + + :param http_client: The HTTP client implementation. + :return: The instance of this builder to allow for method chaining. + """ + self._http_client = http_client + return self + + def with_http_client_config(self, http_client_config: HttpClientConfig) -> "ClientConfigurationBuilder": + """ + Specifies the HTTP client configuration. + + If not set, the default configuration is used. + + :param http_client_config: Configuration settings for the HTTP client. + :return: The instance of this builder to allow for method chaining. + """ + self._http_client_config = http_client_config + return self + + def with_provider(self, llm_provider: str) -> "ClientConfigurationBuilder": + """ + Specifies the Large Language Model (LLM) provider to be used for generating function calls. + + :param llm_provider: The LLM provider name. + :return: The instance of this builder to allow for method chaining. + """ + self._llm_provider = llm_provider + return self + + def build(self) -> ClientConfiguration: + """ + Constructs a `ClientConfiguration` instance using the settings provided. + + It validates that an OpenAPI specification has been set before proceeding with the build. + + :return: A configured instance of ClientConfiguration. + :raises ValueError: If the OpenAPI specification is not set. + """ + if self._openapi_spec is None: + raise ValueError("OpenAPI specification must be provided to build a configuration.") + + return ClientConfiguration( + openapi_spec=self._openapi_spec, + credentials=self._credentials, + http_client=self._http_client, + http_client_config=self._http_client_config, + llm_provider=self._llm_provider or "openai", + ) + + +class RequestBuilder: + """ Builds an HTTP request based on an OpenAPI operation""" + def __init__(self, client_config: ClientConfiguration): + self.openapi_parser = client_config.get_openapi_spec() + self.http_client = client_config.get_http_client() + self.auth_config = client_config.get_auth_config() or PassThroughAuthentication() + + def build_request(self, operation: Operation, **kwargs) -> Any: + """ + Build an HTTP request based on the operation and arguments provided. + """ + url = self._build_url(operation, **kwargs) + method = operation.method.lower() + headers = self._build_headers(operation) + query_params = self._build_query_params(operation, **kwargs) + body = self._build_request_body(operation, **kwargs) + request = { + "url": url, + "method": method, + "headers": headers, + "params": query_params, + "json": body, + } + self._apply_authentication(operation, request) + return request + + def _build_headers(self, operation: Operation, **kwargs) -> Dict[str, str]: + headers = {} + for parameter in operation.get_parameters("header"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + headers[parameter["name"]] = str(param_value) + elif parameter.get("required", False): + raise ValueError(f"Missing required header parameter: {parameter['name']}") + return headers + + def _build_url(self, operation: Operation, **kwargs) -> str: + server_url = operation.get_server_url() + path = operation.path + for parameter in operation.get_parameters("path"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + path = path.replace(f"{{{parameter['name']}}}", str(param_value)) + elif parameter.get("required", False): + raise ValueError(f"Missing required path parameter: {parameter['name']}") + return server_url + path + + def _build_query_params(self, operation: Operation, **kwargs) -> Dict[str, Any]: + query_params = {} + + # Simplify query parameter assembly using _get_parameter_value + for parameter in operation.get_parameters("query"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + query_params[parameter["name"]] = param_value + elif parameter.get("required", False): + raise ValueError(f"Missing required query parameter: {parameter['name']}") + return query_params + + def _build_request_body(self, operation: Operation, **kwargs) -> Any: + request_body = operation.get_request_body() + if request_body: + content = request_body.get("content", {}) + if "application/json" in content: + return {**kwargs} + raise NotImplementedError("Request body content type not supported") + return None + + def _apply_authentication(self, operation: Operation, request: Dict[str, Any]): + # security requirements specify which authentication scheme to apply (the "what/which") + security_requirements = operation.get_security_requirements() + # security schemes define how to authenticate (the "how") + security_schemes = operation.spec_dict.get("components", {}).get("securitySchemes", {}) + if security_requirements: + for requirement in security_requirements: + for scheme_name in requirement: + if scheme_name in security_schemes: + security_scheme = security_schemes[scheme_name] + self.auth_config.apply_auth(security_scheme, request) + break + + +class OpenAPIServiceClient: + """ + A client for invoking operations on REST services defined by OpenAPI specifications. + + Together with the `ClientConfiguration`, its `ClientConfigurationBuilder`, the `OpenAPIServiceClient` + simplifies the process of (LLMs) with services defined by OpenAPI specifications. + """ + + def __init__(self, client_config: ClientConfiguration): + self.openapi_spec = client_config.get_openapi_spec() + self.http_client = client_config.get_http_client() + self.request_builder = RequestBuilder(client_config) + self.payload_extractor = client_config.get_payload_extractor() + + def invoke(self, function_payload: Any) -> Any: + """ + Invokes a function specified in the function payload. + + :param function_payload: The function payload containing the details of the function to be invoked. + :returns: The response from the service after invoking the function. + :raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload. + :raises HttpClientError: If an error occurs while sending the request and receiving the response. + """ + fn_invocation_payload = self.payload_extractor.extract_function_invocation(function_payload) + if not fn_invocation_payload: + raise OpenAPIClientError( + f"Failed to extract function invocation payload from {function_payload} using " + f"{self.payload_extractor.__class__.__name__}. Ensure the payload format matches the expected " + "structure for the designated LLM extractor." + ) + # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on + operation = self.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) + request = self.request_builder.build_request(operation, **fn_invocation_payload.get("arguments")) + return self.http_client.send_request(request) + + +class OpenAPIClientError(Exception): + """Exception raised for errors in the OpenAPI client.""" + + +def openai_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: + """ + Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + fn_definitions = _openapi_to_functions(resolved_schema, "parameters", _parse_endpoint_spec_openai) + return [{"type": "function", "function": fn} for fn in fn_definitions] + + +def anthropic_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: + """ + Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + return _openapi_to_functions(resolved_schema, "input_schema", _parse_endpoint_spec_openai) + + +def cohere_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: + """ + Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + return _openapi_to_functions(resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere) + + +def _openapi_to_functions(service_openapi_spec: Dict[str, Any], parameters_name: str, + parse_endpoint_fn: Callable[[Dict[str, Any], str], Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Extracts functions from the OpenAPI specification, converts them into a function schema. + """ + + # Doesn't enforce rigid spec validation because that would require a lot of dependencies + # We check the version and require minimal fields to be present, so we can extract functions + spec_version = service_openapi_spec.get("openapi") + if not spec_version: + raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}") + service_openapi_spec_version = int(spec_version.split(".")[0]) + # Compare the versions + if service_openapi_spec_version < MIN_REQUIRED_OPENAPI_SPEC_VERSION: + raise ValueError( + f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be " + f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}." + ) + functions: List[Dict[str, Any]] = [] + for paths in service_openapi_spec["paths"].values(): + for path_spec in paths.values(): + function_dict = parse_endpoint_fn(path_spec, parameters_name) + if function_dict: + functions.append(function_dict) + return functions + + +def _parse_endpoint_spec_openai(resolved_spec: Dict[str, Any], parameters_name: str) -> Dict[str, Any]: + """ + Parses an OpenAPI endpoint specification for OpenAI. + """ + if not isinstance(resolved_spec, dict): + logger.warning("Invalid OpenAPI spec format provided. Could not extract function.") + return {} + function_name = resolved_spec.get("operationId") + description = resolved_spec.get("description") or resolved_spec.get("summary", "") + schema: Dict[str, Any] = {"type": "object", "properties": {}} + # requestBody section + req_body_schema = ( + resolved_spec.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {}) + ) + if "properties" in req_body_schema: + for prop_name, prop_schema in req_body_schema["properties"].items(): + schema["properties"][prop_name] = _parse_property_attributes(prop_schema) + if "required" in req_body_schema: + schema.setdefault("required", []).extend(req_body_schema["required"]) + + # parameters section + for param in resolved_spec.get("parameters", []): + if "schema" in param: + schema_dict = _parse_property_attributes(param["schema"]) + # these attributes are not in param[schema] level but on param level + useful_attributes = ["description", "pattern", "enum"] + schema_dict.update({key: param[key] for key in useful_attributes if param.get(key)}) + schema["properties"][param["name"]] = schema_dict + if param.get("required", False): + schema.setdefault("required", []).append(param["name"]) + + if function_name and description and schema["properties"]: + return {"name": function_name, "description": description, parameters_name: schema} + logger.warning("Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec) + return {} + + +def _parse_property_attributes(property_schema: Dict[str, Any], include_attributes: Optional[List[str]] = None + ) -> Dict[str, Any]: + """ + Recursively parses the attributes of a property schema. + """ + include_attributes = include_attributes or ["description", "pattern", "enum"] + schema_type = property_schema.get("type") + parsed_schema = {"type": schema_type} if schema_type else {} + for attr in include_attributes: + if attr in property_schema: + parsed_schema[attr] = property_schema[attr] + if schema_type == "object": + properties = property_schema.get("properties", {}) + parsed_properties = { + prop_name: _parse_property_attributes(prop, include_attributes) + for prop_name, prop in properties.items() + } + parsed_schema["properties"] = parsed_properties + if "required" in property_schema: + parsed_schema["required"] = property_schema["required"] + elif schema_type == "array": + items = property_schema.get("items", {}) + parsed_schema["items"] = _parse_property_attributes(items, include_attributes) + return parsed_schema + + +def _parse_endpoint_spec_cohere(operation: Dict[str, Any], ignored_param: str) -> Dict[str, Any]: + """ + Parses an endpoint specification for Cohere. + """ + function_name = operation.get("operationId") + description = operation.get("description") or operation.get("summary", "") + parameter_definitions = _parse_parameters(operation) + if function_name: + return { + "name": function_name, + "description": description, + "parameter_definitions": parameter_definitions, + } + logger.warning("Operation missing operationId, cannot create function definition.") + return {} + + +def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: + """ + Parses the parameters from an operation specification. + """ + parameters = {} + for param in operation.get("parameters", []): + if "schema" in param: + parameters[param["name"]] = _parse_schema( + param["schema"], param.get("required", False), param.get("description", "") + ) + if "requestBody" in operation: + content = operation["requestBody"].get("content", {}).get("application/json", {}) + if "schema" in content: + schema_properties = content["schema"].get("properties", {}) + required_properties = content["schema"].get("required", []) + for name, schema in schema_properties.items(): + parameters[name] = _parse_schema( + schema, name in required_properties, schema.get("description", "") + ) + return parameters + + +def _parse_schema(schema: Dict[str, Any], required: bool, description: str) -> Dict[str, Any]: # noqa: FBT001 + """ + Parses a schema part of an operation specification. + """ + schema_type = _get_type(schema) + if schema_type == "object": + # Recursive call for complex types + properties = schema.get("properties", {}) + nested_parameters = { + name: _parse_schema( + schema=prop_schema, + required=bool(name in schema.get("required", False)), + description=prop_schema.get("description", ""), + ) + for name, prop_schema in properties.items() + } + return { + "type": schema_type, + "description": description, + "properties": nested_parameters, + "required": required, + } + return {"type": schema_type, "description": description, "required": required} + + +def _get_type(schema: Dict[str, Any]) -> str: + type_mapping = {"integer": "int", "string": "str", "boolean": "bool", "number": "float", "object": "object", + "array": "list"} + schema_type = schema.get("type", "object") + if schema_type not in type_mapping: + raise ValueError(f"Unsupported schema type {schema_type}") + return type_mapping[schema_type] diff --git a/pyproject.toml b/pyproject.toml index 78245150..94e4f9b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] -dependencies = ["haystack-ai"] +dependencies = ["jsonref", "haystack-ai"] [project.urls] "CI: GitHub" = "https://github.com/deepset-ai/haystack-experimental/actions" "GitHub: issues" = "https://github.com/deepset-ai/haystack-experimental/issues" @@ -39,6 +39,9 @@ dependencies = [ # Test "pytest", "pytest-cov", + "fastapi", + "cohere", + "anthropic", # Linting "pylint", "ruff", diff --git a/test/components/connectors/__init__.py b/test/components/connectors/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/connectors/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/connectors/test_openapi.py b/test/components/connectors/test_openapi.py new file mode 100644 index 00000000..f9b5e413 --- /dev/null +++ b/test/components/connectors/test_openapi.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import json +import os +from unittest.mock import patch + +import pytest + +from haystack_experimental.components.connectors import OpenAPIServiceConnector +from haystack.dataclasses import ChatMessage + + +class TestOpenAPIServiceConnector: + @pytest.fixture + def setup_mock(self): + with patch("haystack_experimental.components.connectors.openapi.OpenAPIServiceClient") as mock_client: + mock_client_instance = mock_client.return_value + mock_client_instance.invoke.return_value = {"service_response": "Yes, he was fired and rehired"} + yield mock_client_instance + + def test_init(self): + service_connector = OpenAPIServiceConnector() + assert service_connector is not None + assert service_connector.llm_provider == "openai" + + def test_init_with_anthropic_provider(self): + service_connector = OpenAPIServiceConnector(provider="anthropic") + assert service_connector is not None + assert service_connector.llm_provider == "anthropic" + + def test_run_with_mock(self, setup_mock, test_files_path): + fc_payload = [ + { + "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, + "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", + "type": "function", + } + ] + with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: + serperdev_openapi_spec = json.load(file) + + service_connector = OpenAPIServiceConnector() + result = service_connector.run( + messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], + service_openapi_spec=serperdev_openapi_spec, + service_credentials="fake_api_key", + ) + + assert "service_response" in result + assert len(result["service_response"]) == 1 + assert isinstance(result["service_response"][0], ChatMessage) + response_content = json.loads(result["service_response"][0].content) + assert response_content == {"service_response": "Yes, he was fired and rehired"} + + # verify invocation payload + setup_mock.invoke.assert_called_once() + invocation_payload = [ + { + "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, + "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", + "type": "function", + } + ] + setup_mock.invoke.assert_called_with(invocation_payload) + + @pytest.mark.integration + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SerperDev API key is not available") + def test_run(self, test_files_path): + fc_payload = [ + { + "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, + "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", + "type": "function", + } + ] + + with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: + serperdev_openapi_spec = json.load(file) + + service_connector = OpenAPIServiceConnector() + result = service_connector.run( + messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], + service_openapi_spec=serperdev_openapi_spec, + service_credentials=os.environ["SERPERDEV_API_KEY"], + ) + assert "service_response" in result + assert len(result["service_response"]) == 1 + assert isinstance(result["service_response"][0], ChatMessage) + response_text = result["service_response"][0].content + assert "Sam" in response_text or "Altman" in response_text + + @pytest.mark.integration + def test_run_no_credentials(self, test_files_path): + fc_payload = [ + { + "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, + "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", + "type": "function", + } + ] + + with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: + serperdev_openapi_spec = json.load(file) + + service_connector = OpenAPIServiceConnector() + result = service_connector.run( + messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], service_openapi_spec=serperdev_openapi_spec + ) + assert "service_response" in result + assert len(result["service_response"]) == 1 + assert isinstance(result["service_response"][0], ChatMessage) + response_text = result["service_response"][0].content + assert "403" in response_text diff --git a/test/components/converters/__init__.py b/test/components/converters/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/converters/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/converters/test_openapi.py b/test/components/converters/test_openapi.py new file mode 100644 index 00000000..fabb4624 --- /dev/null +++ b/test/components/converters/test_openapi.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import json +import sys +import tempfile + +import pytest + +from haystack_experimental.components.converters import OpenAPIServiceToFunctions +from haystack.dataclasses import ByteStream + + +@pytest.fixture +def json_serperdev_openapi_spec(): + serper_spec = """ + { + "openapi": "3.0.0", + "info": { + "title": "SerperDev", + "version": "1.0.0", + "description": "API for performing search queries" + }, + "servers": [ + { + "url": "https://google.serper.dev" + } + ], + "paths": { + "/search": { + "post": { + "operationId": "search", + "description": "Search the web with Google", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "q": { + "type": "string" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "searchParameters": { + "type": "undefined" + }, + "knowledgeGraph": { + "type": "undefined" + }, + "answerBox": { + "type": "undefined" + }, + "organic": { + "type": "undefined" + }, + "topStories": { + "type": "undefined" + }, + "peopleAlsoAsk": { + "type": "undefined" + }, + "relatedSearches": { + "type": "undefined" + } + } + } + } + } + } + }, + "security": [ + { + "apikey": [] + } + ] + } + } + }, + "components": { + "securitySchemes": { + "apikey": { + "type": "apiKey", + "name": "x-api-key", + "in": "header" + } + } + } + } + """ + return serper_spec + + +@pytest.fixture +def yaml_serperdev_openapi_spec(): + serper_spec = """ + openapi: 3.0.0 + info: + title: SerperDev + version: 1.0.0 + description: API for performing search queries + servers: + - url: 'https://google.serper.dev' + paths: + /search: + post: + operationId: search + description: Search the web with Google + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + q: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + type: object + properties: + searchParameters: + type: undefined + knowledgeGraph: + type: undefined + answerBox: + type: undefined + organic: + type: undefined + topStories: + type: undefined + peopleAlsoAsk: + type: undefined + relatedSearches: + type: undefined + security: + - apikey: [] + components: + securitySchemes: + apikey: + type: apiKey + name: x-api-key + in: header + """ + return serper_spec + + +@pytest.fixture +def fn_definition_transform(): + return lambda function_def: {"type": "function", "function": function_def} + + +class TestOpenAPIServiceToFunctions: + # test we can extract functions from openapi spec given + def test_run_with_bytestream_source(self, json_serperdev_openapi_spec, fn_definition_transform): + service = OpenAPIServiceToFunctions() + spec_stream = ByteStream.from_string(json_serperdev_openapi_spec) + result = service.run(sources=[spec_stream]) + assert len(result["functions"]) == 1 + fc = result["functions"][0] + + # check that fc definition is as expected + assert fc == fn_definition_transform( + { + "name": "search", + "description": "Search the web with Google", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + } + ) + + @pytest.mark.skipif( + sys.platform in ["win32", "cygwin"], + reason="Can't run on Windows Github CI, need access temp file but windows does not allow it", + ) + def test_run_with_file_source(self, json_serperdev_openapi_spec, fn_definition_transform): + # test we can extract functions from openapi spec given in file + service = OpenAPIServiceToFunctions() + # write the spec to NamedTemporaryFile and check that it is parsed correctly + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(json_serperdev_openapi_spec.encode("utf-8")) + tmp.seek(0) + result = service.run(sources=[tmp.name]) + assert len(result["functions"]) == 1 + fc = result["functions"][0] + + # check that fc definition is as expected + assert fc == fn_definition_transform( + { + "name": "search", + "description": "Search the web with Google", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + } + ) + + def test_run_with_invalid_bytestream_source(self, caplog): + # test invalid source + service = OpenAPIServiceToFunctions() + with pytest.raises(ValueError, match="Invalid OpenAPI specification"): + service.run(sources=[ByteStream.from_string("")]) + + def test_complex_types_conversion(self, test_files_path, fn_definition_transform): + # ensure that complex types from OpenAPI spec are converted to the expected format in OpenAI function calling + service = OpenAPIServiceToFunctions() + result = service.run(sources=[test_files_path / "json" / "complex_types_openapi_service.json"]) + assert len(result["functions"]) == 1 + + with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file: + desired_output = json.load(openai_spec_file) + assert result["functions"][0] == fn_definition_transform(desired_output) + + def test_simple_and_complex_at_once(self, test_files_path, json_serperdev_openapi_spec, fn_definition_transform): + # ensure multiple functions are extracted from multiple paths in OpenAPI spec + service = OpenAPIServiceToFunctions() + sources = [ + ByteStream.from_string(json_serperdev_openapi_spec), + test_files_path / "json" / "complex_types_openapi_service.json", + ] + result = service.run(sources=sources) + assert len(result["functions"]) == 2 + + with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file: + desired_output = json.load(openai_spec_file) + assert result["functions"][0] == fn_definition_transform( + { + "name": "search", + "description": "Search the web with Google", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + } + ) + assert result["functions"][1] == fn_definition_transform(desired_output) diff --git a/test/test_files/json/complex_types_openai_spec.json b/test/test_files/json/complex_types_openai_spec.json new file mode 100644 index 00000000..ebaf9556 --- /dev/null +++ b/test/test_files/json/complex_types_openai_spec.json @@ -0,0 +1,64 @@ +{ + "name": "processPayment", + "description": "Process a new payment using the specified payment method", + "parameters": { + "type": "object", + "properties": { + "transaction_amount": { + "type": "number", + "description": "The amount to be paid" + }, + "description": { + "type": "string", + "description": "A brief description of the payment" + }, + "payment_method_id": { + "type": "string", + "description": "The payment method to be used" + }, + "payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": { + "type": "string", + "description": "The payer's name" + }, + "email": { + "type": "string", + "description": "The payer's email address" + }, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "string", + "description": "The type of identification document (e.g., CPF, CNPJ)" + }, + "number": { + "type": "string", + "description": "The identification number" + } + }, + "required": [ + "type", + "number" + ] + } + }, + "required": [ + "name", + "email", + "identification" + ] + } + }, + "required": [ + "transaction_amount", + "description", + "payment_method_id", + "payer" + ] + } +} diff --git a/test/test_files/json/complex_types_openapi_service.json b/test/test_files/json/complex_types_openapi_service.json new file mode 100644 index 00000000..3ea04f8a --- /dev/null +++ b/test/test_files/json/complex_types_openapi_service.json @@ -0,0 +1,103 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Payment API", + "version": "1.0.0" + }, + "servers": [ + { + "url": "http://localhost:8080" + } + ], + "paths": { + "/new_payment": { + "post": { + "summary": "Process a new payment", + "description": "Process a new payment using the specified payment method", + "operationId": "processPayment", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "transaction_amount": { + "type": "number", + "description": "The amount to be paid" + }, + "description": { + "type": "string", + "description": "A brief description of the payment" + }, + "payment_method_id": { + "type": "string", + "description": "The payment method to be used" + }, + "payer": { + "$ref": "#/components/schemas/Payer" + } + }, + "required": [ + "transaction_amount", + "description", + "payment_method_id", + "payer" + ] + } + } + } + }, + "responses": { + "200": { + "description": "Payment processed successfully" + }, + "400": { + "description": "Invalid request" + } + } + } + } + }, + "components": { + "schemas": { + "Payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": { + "type": "string", + "description": "The payer's name" + }, + "email": { + "type": "string", + "description": "The payer's email address" + }, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "string", + "description": "The type of identification document (e.g., CPF, CNPJ)" + }, + "number": { + "type": "string", + "description": "The identification number" + } + }, + "required": [ + "type", + "number" + ] + } + }, + "required": [ + "name", + "email", + "identification" + ] + } + } + } +} diff --git a/test/test_files/json/openapi_order_service.json b/test/test_files/json/openapi_order_service.json new file mode 100644 index 00000000..3e24661a --- /dev/null +++ b/test/test_files/json/openapi_order_service.json @@ -0,0 +1,109 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Order Service", + "version": "1.0.0" + }, + "servers": [{"url": "http://localhost"}], + "paths": { + "/orders": { + "post": { + "summary": "Create a new order", + "operationId": "createOrder", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Order" + } + } + } + }, + "responses": { + "201": { + "description": "Created", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OrderResponse" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Order": { + "type": "object", + "properties": { + "customer": { + "$ref": "#/components/schemas/Customer" + }, + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OrderItem" + } + } + }, + "required": [ + "customer", + "items" + ] + }, + "Customer": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "email": { + "type": "string" + } + }, + "required": [ + "name", + "email" + ] + }, + "OrderItem": { + "type": "object", + "properties": { + "product": { + "type": "string" + }, + "quantity": { + "type": "integer" + } + }, + "required": [ + "product", + "quantity" + ] + }, + "OrderResponse": { + "type": "object", + "properties": { + "orderId": { + "type": "string" + }, + "status": { + "type": "string" + }, + "totalAmount": { + "type": "number" + } + }, + "required": [ + "orderId", + "status", + "totalAmount" + ] + } + } + } +} \ No newline at end of file diff --git a/test/test_files/json/serperdev_openapi_spec.json b/test/test_files/json/serperdev_openapi_spec.json new file mode 100644 index 00000000..123c993a --- /dev/null +++ b/test/test_files/json/serperdev_openapi_spec.json @@ -0,0 +1,62 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "SerperDev", + "version": "1.0.0", + "description": "API for performing search queries" + }, + "servers": [ + { + "url": "https://google.serper.dev" + } + ], + "paths": { + "/search": { + "post": { + "operationId": "search", + "description": "Search the web with Google", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "q": { + "type": "string" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object" + } + } + } + } + }, + "security": [ + { + "apikey": [] + } + ] + } + } + }, + "components": { + "securitySchemes": { + "apikey": { + "type": "apiKey", + "name": "x-api-key", + "in": "header" + } + } + } +} diff --git a/test/test_files/yaml/github_compare.yml b/test/test_files/yaml/github_compare.yml new file mode 100644 index 00000000..e14575b7 --- /dev/null +++ b/test/test_files/yaml/github_compare.yml @@ -0,0 +1,438 @@ +openapi: 3.1.0 +info: + title: Github API + description: Enables interaction with OpenAPI + version: v1.0.0 +servers: + - url: https://api.github.com +paths: + /repos/{owner}/{repo}/compare/{basehead}: + get: + summary: Compare two branches + description: Compares two branches against one another. + tags: + - repos + operationId: compare_branches + externalDocs: + description: API method documentation + url: >- + https://docs.github.com/enterprise-server@3.9/rest/commits/commits#compare-two-commits + parameters: + - name: basehead + description: >- + The base branch and head branch to compare. This parameter expects + the format `BASE...HEAD` + in: path + required: true + x-multi-segment: true + schema: + type: string + - name: owner + description: The repository owner, usually a company or orgnization + in: path + required: true + x-multi-segment: true + schema: + type: string + - name: repo + description: The repository itself, the project + in: path + required: true + x-multi-segment: true + schema: + type: string + responses: + '200': + description: Response + content: + application/json: + schema: + $ref: '#/components/schemas/commit-comparison' + x-github: + githubCloudOnly: false + enabledForGitHubApps: true + category: commits + subcategory: commits +components: + schemas: + commit-comparison: + title: Commit Comparison + description: Commit Comparison + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/compare/master...topic + html_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic + permalink_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/compare/octocat:bbcd538c8e72b8c175046e27cc8f907076331401...octocat:0328041d1152db8ae77652d1618a02e57f745f17 + diff_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic.diff + patch_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic.patch + base_commit: + $ref: '#/components/schemas/commit' + merge_base_commit: + $ref: '#/components/schemas/commit' + status: + type: string + enum: + - diverged + - ahead + - behind + - identical + example: ahead + ahead_by: + type: integer + example: 4 + behind_by: + type: integer + example: 5 + total_commits: + type: integer + example: 6 + commits: + type: array + items: + $ref: '#/components/schemas/commit' + files: + type: array + items: + $ref: '#/components/schemas/diff-entry' + required: + - url + - html_url + - permalink_url + - diff_url + - patch_url + - base_commit + - merge_base_commit + - status + - ahead_by + - behind_by + - total_commits + - commits + nullable-git-user: + title: Git User + description: Metaproperties for Git author/committer information. + type: object + properties: + name: + type: string + example: '"Chris Wanstrath"' + email: + type: string + example: '"chris@ozmm.org"' + date: + type: string + example: '"2007-10-29T02:42:39.000-07:00"' + nullable: true + nullable-simple-user: + title: Simple User + description: A GitHub user. + type: object + properties: + name: + nullable: true + type: string + email: + nullable: true + type: string + login: + type: string + example: octocat + id: + type: integer + example: 1 + node_id: + type: string + example: MDQ6VXNlcjE= + avatar_url: + type: string + format: uri + example: https://github.com/images/error/octocat_happy.gif + gravatar_id: + type: string + example: 41d064eb2195891e12d0413f63227ea7 + nullable: true + url: + type: string + format: uri + example: https://api.github.com/users/octocat + html_url: + type: string + format: uri + example: https://github.com/octocat + followers_url: + type: string + format: uri + example: https://api.github.com/users/octocat/followers + following_url: + type: string + example: https://api.github.com/users/octocat/following{/other_user} + gists_url: + type: string + example: https://api.github.com/users/octocat/gists{/gist_id} + starred_url: + type: string + example: https://api.github.com/users/octocat/starred{/owner}{/repo} + subscriptions_url: + type: string + format: uri + example: https://api.github.com/users/octocat/subscriptions + organizations_url: + type: string + format: uri + example: https://api.github.com/users/octocat/orgs + repos_url: + type: string + format: uri + example: https://api.github.com/users/octocat/repos + events_url: + type: string + example: https://api.github.com/users/octocat/events{/privacy} + received_events_url: + type: string + format: uri + example: https://api.github.com/users/octocat/received_events + type: + type: string + example: User + site_admin: + type: boolean + starred_at: + type: string + example: '"2020-07-09T00:17:55Z"' + required: + - avatar_url + - events_url + - followers_url + - following_url + - gists_url + - gravatar_id + - html_url + - id + - node_id + - login + - organizations_url + - received_events_url + - repos_url + - site_admin + - starred_url + - subscriptions_url + - type + - url + nullable: true + verification: + title: Verification + type: object + properties: + verified: + type: boolean + reason: + type: string + payload: + type: string + nullable: true + signature: + type: string + nullable: true + required: + - verified + - reason + - payload + - signature + diff-entry: + title: Diff Entry + description: Diff Entry + type: object + properties: + sha: + type: string + example: bbcd538c8e72b8c175046e27cc8f907076331401 + filename: + type: string + example: file1.txt + status: + type: string + enum: + - added + - removed + - modified + - renamed + - copied + - changed + - unchanged + example: added + additions: + type: integer + example: 103 + deletions: + type: integer + example: 21 + changes: + type: integer + example: 124 + blob_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/blob/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt + raw_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/raw/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt + contents_url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/contents/file1.txt?ref=6dcb09b5b57875f334f61aebed695e2e4193db5e + patch: + type: string + example: '@@ -132,7 +132,7 @@ module Test @@ -1000,7 +1000,7 @@ module Test' + previous_filename: + type: string + example: file.txt + required: + - additions + - blob_url + - changes + - contents_url + - deletions + - filename + - raw_url + - sha + - status + commit: + title: Commit + description: Commit + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e + sha: + type: string + example: 6dcb09b5b57875f334f61aebed695e2e4193db5e + node_id: + type: string + example: MDY6Q29tbWl0NmRjYjA5YjViNTc4NzVmMzM0ZjYxYWViZWQ2OTVlMmU0MTkzZGI1ZQ== + html_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/commit/6dcb09b5b57875f334f61aebed695e2e4193db5e + comments_url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e/comments + commit: + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e + author: + $ref: '#/components/schemas/nullable-git-user' + committer: + $ref: '#/components/schemas/nullable-git-user' + message: + type: string + example: Fix all the bugs + comment_count: + type: integer + example: 0 + tree: + type: object + properties: + sha: + type: string + example: 827efc6d56897b048c772eb4087f854f46256132 + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/tree/827efc6d56897b048c772eb4087f854f46256132 + required: + - sha + - url + verification: + $ref: '#/components/schemas/verification' + required: + - author + - committer + - comment_count + - message + - tree + - url + author: + $ref: '#/components/schemas/nullable-simple-user' + committer: + $ref: '#/components/schemas/nullable-simple-user' + parents: + type: array + items: + type: object + properties: + sha: + type: string + example: 7638417db6d59f3c431d3e1f261cc637155684cd + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/7638417db6d59f3c431d3e1f261cc637155684cd + html_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/commit/7638417db6d59f3c431d3e1f261cc637155684cd + required: + - sha + - url + stats: + type: object + properties: + additions: + type: integer + deletions: + type: integer + total: + type: integer + files: + type: array + items: + $ref: '#/components/schemas/diff-entry' + required: + - url + - sha + - node_id + - html_url + - comments_url + - commit + - author + - committer + - parents + securitySchemes: + apikey: + type: apiKey + name: x-api-key + in: header diff --git a/test/test_files/yaml/openapi_edge_cases.yml b/test/test_files/yaml/openapi_edge_cases.yml new file mode 100644 index 00000000..cef304c5 --- /dev/null +++ b/test/test_files/yaml/openapi_edge_cases.yml @@ -0,0 +1,13 @@ +openapi: 3.0.0 +info: + title: Edge Cases API + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /missing-operation-id: + get: + summary: Missing operationId + responses: + '200': + description: OK diff --git a/test/test_files/yaml/openapi_error_handling.yml b/test/test_files/yaml/openapi_error_handling.yml new file mode 100644 index 00000000..5cf23fe5 --- /dev/null +++ b/test/test_files/yaml/openapi_error_handling.yml @@ -0,0 +1,24 @@ +openapi: 3.0.0 +info: + title: Error Handling API + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /error/{status_code}: + get: + summary: Raise HTTP error + operationId: raiseHttpError + parameters: + - name: status_code + in: path + required: true + schema: + type: integer + responses: + '400': + description: Bad Request + '401': + description: Unauthorized + '404': + description: Not Found \ No newline at end of file diff --git a/test/test_files/yaml/openapi_greeting_service.yml b/test/test_files/yaml/openapi_greeting_service.yml new file mode 100644 index 00000000..701dee33 --- /dev/null +++ b/test/test_files/yaml/openapi_greeting_service.yml @@ -0,0 +1,272 @@ +openapi: 3.0.0 +info: + title: Greeting Service + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /greet/{name}: + post: + operationId: greet + parameters: + - name: name + in: path + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/MessageBody' + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-params/{name}: + get: + operationId: greetParams + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-body: + post: + operationId: greetBody + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/GreetBody' + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-api-key/{name}: + get: + operationId: greetApiKey + security: + - ApiKeyAuth: [] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-basic-auth/{name}: + get: + operationId: greetBasicAuth + security: + - BasicAuth: [] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /greet-api-key-query/{name}: + get: + operationId: greetApiKeyQuery + security: + - ApiKeyAuthQuery: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-api-key-cookie/{name}: + get: + operationId: greetApiKeyCookie + security: + - ApiKeyAuthCookie: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-bearer-auth/{name}: + get: + operationId: greetBearerAuth + security: + - BearerAuth: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-oauth/{name}: + get: + operationId: greetOAuth + security: + - OAuth2: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' +components: + securitySchemes: + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key + BasicAuth: + type: http + scheme: basic + ApiKeyAuthQuery: + type: apiKey + in: query + name: api_key + ApiKeyAuthCookie: + type: apiKey + in: cookie + name: api_key + BearerAuth: + type: http + scheme: bearer + OAuth2: + type: oauth2 + flows: + authorizationCode: + authorizationUrl: https://example.com/oauth/authorize + tokenUrl: https://example.com/oauth/token + scopes: + read:greet: Read access to greeting service + + schemas: + GreetBody: + type: object + properties: + message: + type: string + name: + type: string + required: + - message + - name + + MessageBody: + type: object + properties: + message: + type: string + required: + - message + + GreetingResponse: + type: object + properties: + greeting: + type: string + + ErrorResponse: + type: object + properties: + detail: + type: string \ No newline at end of file diff --git a/test/test_files/yaml/openapi_order_service.yml b/test/test_files/yaml/openapi_order_service.yml new file mode 100644 index 00000000..07360ea5 --- /dev/null +++ b/test/test_files/yaml/openapi_order_service.yml @@ -0,0 +1,75 @@ +openapi: 3.0.0 +info: + title: Order Service + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /orders: + post: + summary: Create a new order + operationId: createOrder + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Order' + responses: + '201': + description: Created + content: + application/json: + schema: + $ref: '#/components/schemas/OrderResponse' + +components: + schemas: + Order: + type: object + properties: + customer: + $ref: '#/components/schemas/Customer' + items: + type: array + items: + $ref: '#/components/schemas/OrderItem' + required: + - customer + - items + + Customer: + type: object + properties: + name: + type: string + email: + type: string + required: + - name + - email + + OrderItem: + type: object + properties: + product: + type: string + quantity: + type: integer + required: + - product + - quantity + + OrderResponse: + type: object + properties: + orderId: + type: string + status: + type: string + totalAmount: + type: number + required: + - orderId + - status + - totalAmount \ No newline at end of file diff --git a/test/test_files/yaml/serper.yml b/test/test_files/yaml/serper.yml new file mode 100644 index 00000000..9d5b1a85 --- /dev/null +++ b/test/test_files/yaml/serper.yml @@ -0,0 +1,39 @@ +openapi: 3.0.0 +info: + title: SerperDev + version: 1.0.0 + description: API for performing search queries +servers: + - url: https://google.serper.dev +paths: + /search: + post: + operationId: serperdev_search + description: Search the web with Google + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + q: + type: string + required: + - q + responses: + '200': + description: Successful response + content: + application/json: + schema: + type: object + additionalProperties: true + security: + - apikey: [] +components: + securitySchemes: + apikey: + type: apiKey + name: x-api-key + in: header diff --git a/test/util/__init__.py b/test/util/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/util/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/util/conftest.py b/test/util/conftest.py new file mode 100644 index 00000000..bb575703 --- /dev/null +++ b/test/util/conftest.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +from pathlib import Path +from urllib.parse import urlparse + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from haystack_experimental.util.openapi import HttpClient, HttpClientError + + +@pytest.fixture() +def test_files_path(): + return Path(__file__).parent.parent / "test_files" + + +class FastAPITestClient(HttpClient): + + def __init__(self, app: FastAPI): + self.app = app + self.client = TestClient(app) + + def strip_host(self, url: str) -> str: + parsed_url = urlparse(url) + new_path = parsed_url.path + if parsed_url.query: + new_path += "?" + parsed_url.query + return new_path + + def send_request(self, request: dict) -> dict: + method = request["method"] + # OAS spec will list a server URL, but FastAPI doesn't need it for local testing, in fact it will fail + # if the URL has a host. So we strip it here. + url = self.strip_host(request["url"]) + headers = request.get("headers", {}) + params = request.get("params", {}) + json_data = request.get("json", None) + auth = request.get("auth", None) + cookies = request.get("cookies", {}) + + try: + response = self.client.request( + method, + url, + headers=headers, + params=params, + json=json_data, + auth=auth, + cookies=cookies, + ) + response.raise_for_status() + return response.json() + except Exception as e: + # Handle HTTP errors + raise HttpClientError(f"HTTP error occurred: {e}") from e diff --git a/test/util/test_openapi_client.py b/test/util/test_openapi_client.py new file mode 100644 index 00000000..581dfdf8 --- /dev/null +++ b/test/util/test_openapi_client.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from test.util.conftest import FastAPITestClient + +""" +Tests OpenAPIServiceClient with three FastAPI apps for different parameter types: + +- **greet_mix_params_body**: A POST endpoint `/greet/` accepting a JSON payload with a message, returning a +greeting with the name from the URL and the message from the payload. + +- **greet_params_only**: A GET endpoint `/greet-params/` taking a URL parameter, returning a greeting with +the name from the URL. + +- **greet_request_body_only**: A POST endpoint `/greet-body` accepting a JSON payload with a name and message, +returning a greeting with both. + +OpenAPI specs for these endpoints are in `openapi_greeting_service.yml` in `test/test_files` directory. +""" + + +class GreetBody(BaseModel): + message: str + name: str + + +class MessageBody(BaseModel): + message: str + + +# FastAPI app definitions +def create_greet_mix_params_body_app() -> FastAPI: + app = FastAPI() + + @app.post("/greet/{name}") + def greet(name: str, body: MessageBody): + greeting = f"{body.message}, {name} from mix_params_body!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_params_only_app() -> FastAPI: + app = FastAPI() + + @app.get("/greet-params/{name}") + def greet_params(name: str): + greeting = f"Hello, {name} from params_only!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_request_body_only_app() -> FastAPI: + app = FastAPI() + + @app.post("/greet-body") + def greet_request_body(body: GreetBody): + greeting = f"{body.message}, {body.name} from request_body_only!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +class TestOpenAPI: + + def test_greet_mix_params_body(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_mix_params_body_app())) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John", "message": "Bonjour"}', + "name": "greet", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Bonjour, John from mix_params_body!"} + + def test_greet_params_only(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_params_only_app())) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetParams", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from params_only!"} + + def test_greet_request_body_only(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_request_body_only_app())) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John", "message": "Hola"}', + "name": "greetBody", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hola, John from request_body_only!"} diff --git a/test/util/test_openapi_client_auth.py b/test/util/test_openapi_client_auth.py new file mode 100644 index 00000000..d326d164 --- /dev/null +++ b/test/util/test_openapi_client_auth.py @@ -0,0 +1,238 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +from fastapi import Depends, FastAPI, HTTPException, status +from fastapi.responses import JSONResponse +from fastapi.security import ( + APIKeyCookie, + APIKeyHeader, + APIKeyQuery, + HTTPAuthorizationCredentials, + HTTPBasic, + HTTPBasicCredentials, + HTTPBearer, +) + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, ApiKeyAuthentication, \ + HTTPAuthentication +from test.util.conftest import FastAPITestClient + +API_KEY = "secret_api_key" +BASIC_AUTH_USERNAME = "admin" +BASIC_AUTH_PASSWORD = "secret_password" + +API_KEY_QUERY = "secret_api_key_query" +API_KEY_COOKIE = "secret_api_key_cookie" +BEARER_TOKEN = "secret_bearer_token" + +OAUTH_TOKEN = "secret-oauth-token" + +api_key_query = APIKeyQuery(name="api_key") +api_key_cookie = APIKeyCookie(name="api_key") +bearer_auth = HTTPBearer() + +api_key_header = APIKeyHeader(name="X-API-Key") +basic_auth_http = HTTPBasic() + + +def create_greet_api_key_query_app() -> FastAPI: + app = FastAPI() + + def api_key_query_auth(api_key: str = Depends(api_key_query)): + if api_key != API_KEY_QUERY: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return api_key + + @app.get("/greet-api-key-query/{name}") + def greet_api_key_query(name: str, api_key: str = Depends(api_key_query_auth)): + greeting = f"Hello, {name} from api_key_query_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_api_key_cookie_app() -> FastAPI: + app = FastAPI() + + def api_key_cookie_auth(api_key: str = Depends(api_key_cookie)): + if api_key != API_KEY_COOKIE: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return api_key + + @app.get("/greet-api-key-cookie/{name}") + def greet_api_key_cookie(name: str, api_key: str = Depends(api_key_cookie_auth)): + greeting = f"Hello, {name} from api_key_cookie_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_bearer_auth_app() -> FastAPI: + app = FastAPI() + + def bearer_auth_scheme( + credentials: HTTPAuthorizationCredentials = Depends(bearer_auth), # noqa: B008 + ): + if credentials.scheme != "Bearer" or credentials.credentials != BEARER_TOKEN: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + return credentials.credentials + + @app.get("/greet-bearer-auth/{name}") + def greet_bearer_auth(name: str, token: str = Depends(bearer_auth_scheme)): + greeting = f"Hello, {name} from bearer_auth, using {token}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_api_key_auth_app() -> FastAPI: + app = FastAPI() + + def api_key_auth(api_key: str = Depends(api_key_header)): + if api_key != API_KEY: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return api_key + + @app.get("/greet-api-key/{name}") + def greet_api_key(name: str, api_key: str = Depends(api_key_auth)): + greeting = f"Hello, {name} from api_key_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_basic_auth_app() -> FastAPI: + app = FastAPI() + + def basic_auth(credentials: HTTPBasicCredentials = Depends(basic_auth_http)): # noqa: B008 + if credentials.username != BASIC_AUTH_USERNAME or credentials.password != BASIC_AUTH_PASSWORD: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") + return credentials.username + + @app.get("/greet-basic-auth/{name}") + def greet_basic_auth(name: str, username: str = Depends(basic_auth)): + greeting = f"Hello, {name} from basic_auth, using {username}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_oauth_auth_app() -> FastAPI: + app = FastAPI() + + def oauth_auth(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): # noqa: B008 + if token.credentials != OAUTH_TOKEN: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + return token + + @app.get("/greet-oauth/{name}") + def greet_oauth(name: str, token: HTTPAuthorizationCredentials = Depends(oauth_auth)): # noqa: B008 + greeting = f"Hello, {name} from oauth_auth, using {token}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +class TestOpenAPIAuth: + + def test_greet_api_key_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_api_key_auth_app())) + .with_credentials(ApiKeyAuthentication(API_KEY)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKey", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_auth, using secret_api_key"} + + def test_greet_basic_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_basic_auth_app())) + .with_credentials(HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetBasicAuth", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from basic_auth, using admin"} + + def test_greet_api_key_query_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_api_key_query_app())) + .with_credentials(ApiKeyAuthentication(API_KEY_QUERY)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKeyQuery", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_query_auth, using secret_api_key_query"} + + def test_greet_api_key_cookie_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_api_key_cookie_app())) + .with_credentials(ApiKeyAuthentication(API_KEY_COOKIE)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKeyCookie", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_cookie_auth, using secret_api_key_cookie"} + + def test_greet_bearer_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_bearer_auth_app())) + .with_credentials(HTTPAuthentication(token=BEARER_TOKEN)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetBearerAuth", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from bearer_auth, using secret_bearer_token"} diff --git a/test/util/test_openapi_client_complex_request_body.py b/test/util/test_openapi_client_complex_request_body.py new file mode 100644 index 00000000..17702cb3 --- /dev/null +++ b/test/util/test_openapi_client_complex_request_body.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import json +from typing import List + +import pytest +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from test.util.conftest import FastAPITestClient + + +class Customer(BaseModel): + name: str + email: str + + +class OrderItem(BaseModel): + product: str + quantity: int + + +class Order(BaseModel): + customer: Customer + items: List[OrderItem] + + +class OrderResponse(BaseModel): + orderId: str # noqa: N815 + status: str + totalAmount: float # noqa: N815 + + +def create_order_app() -> FastAPI: + app = FastAPI() + + @app.post("/orders") + def create_order(order: Order): + total_amount = sum(item.quantity * 10 for item in order.items) + response = OrderResponse( + orderId="ORDER-001", + status="CREATED", + totalAmount=total_amount, + ) + return JSONResponse(content=response.model_dump(), status_code=201) + + return app + + +class TestComplexRequestBody: + + @pytest.mark.parametrize("spec_file_path", ["openapi_order_service.yml", "openapi_order_service.json"]) + def test_create_order(self, spec_file_path, test_files_path): + builder = ClientConfigurationBuilder() + path_element = "yaml" if spec_file_path.endswith(".yml") else "json" + config = ( + builder.with_openapi_spec(test_files_path / path_element / spec_file_path) + .with_http_client(FastAPITestClient(create_order_app())) + .build() + ) + client = OpenAPIServiceClient(config) + order_json = { + "customer": {"name": "John Doe", "email": "john@example.com"}, + "items": [ + {"product": "Product A", "quantity": 2}, + {"product": "Product B", "quantity": 1}, + ], + } + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": json.dumps(order_json), + "name": "createOrder", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == { + "orderId": "ORDER-001", + "status": "CREATED", + "totalAmount": 30, + } diff --git a/test/util/test_openapi_client_complex_request_body_mixed.py b/test/util/test_openapi_client_complex_request_body_mixed.py new file mode 100644 index 00000000..907520eb --- /dev/null +++ b/test/util/test_openapi_client_complex_request_body_mixed.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import json + +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from test.util.conftest import FastAPITestClient + + +class Identification(BaseModel): + type: str + number: str + + +class Payer(BaseModel): + name: str + email: str + identification: Identification + + +class PaymentRequest(BaseModel): + transaction_amount: float + description: str + payment_method_id: str + payer: Payer + + +class PaymentResponse(BaseModel): + transaction_id: str + status: str + message: str + + +def create_payment_app() -> FastAPI: + app = FastAPI() + + @app.post("/new_payment") + def process_payment(payment: PaymentRequest): + # sanity + assert payment.transaction_amount == 100.0 + response = PaymentResponse( + transaction_id="TRANS-12345", status="SUCCESS", message="Payment processed successfully." + ) + return JSONResponse(content=response.model_dump(), status_code=200) + + return app + + +# Write the unit test +class TestPaymentProcess: + + def test_process_payment(self, test_files_path): + config = ( + ClientConfigurationBuilder() + .with_openapi_spec(test_files_path / "json" / "complex_types_openapi_service.json") + .with_http_client(FastAPITestClient(create_payment_app())) + .build() + ) + client = OpenAPIServiceClient(config) + + payment_json = { + "transaction_amount": 100.0, + "description": "Test Payment", + "payment_method_id": "CARD-123", + "payer": { + "name": "Alice Smith", + "email": "alice@example.com", + "identification": {"type": "CPF", "number": "123.456.789-00"}, + }, + } + payload = { + "id": "call_uniqueID123", + "function": { + "arguments": json.dumps(payment_json), + "name": "processPayment", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == { + "transaction_id": "TRANS-12345", + "status": "SUCCESS", + "message": "Payment processed successfully.", + } diff --git a/test/util/test_openapi_client_edge_cases.py b/test/util/test_openapi_client_edge_cases.py new file mode 100644 index 00000000..f421e73e --- /dev/null +++ b/test/util/test_openapi_client_edge_cases.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import pytest + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from test.util.conftest import FastAPITestClient + + +class TestEdgeCases: + def test_invalid_openapi_spec(self): + builder = ClientConfigurationBuilder() + with pytest.raises(ValueError, match="Invalid OpenAPI specification"): + config = builder.with_openapi_spec("invalid_spec.yml").build() + OpenAPIServiceClient(config) + + def test_missing_operation_id(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml") + .with_http_client(FastAPITestClient(None)) + .build() + ) + client = OpenAPIServiceClient(config) + + payload = { + "type": "function", + "function": { + "arguments": '{"name": "John", "message": "Hola"}', + "name": "missingOperationId", + }, + } + with pytest.raises(ValueError, match="No operation found with operationId"): + client.invoke(payload) + + # TODO: Add more tests for edge cases diff --git a/test/util/test_openapi_client_error_handling.py b/test/util/test_openapi_client_error_handling.py new file mode 100644 index 00000000..3201ad65 --- /dev/null +++ b/test/util/test_openapi_client_error_handling.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import json + +import pytest +from fastapi import FastAPI, HTTPException + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, HttpClientError +from test.util.conftest import FastAPITestClient + + +def create_error_handling_app() -> FastAPI: + app = FastAPI() + + @app.get("/error/{status_code}") + def raise_http_error(status_code: int): + raise HTTPException(status_code=status_code, detail=f"HTTP {status_code} error") + + return app + + +class TestErrorHandling: + @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500]) + def test_http_error_handling(self, test_files_path, status_code): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_error_handling.yml") + .with_http_client(FastAPITestClient(create_error_handling_app())) + .build() + ) + client = OpenAPIServiceClient(config) + json_error = {"status_code": status_code} + payload = { + "type": "function", + "function": { + "arguments": json.dumps(json_error), + "name": "raiseHttpError", + }, + } + with pytest.raises(HttpClientError) as exc_info: + client.invoke(payload) + + assert str(status_code) in str(exc_info.value) diff --git a/test/util/test_openapi_client_live.py b/test/util/test_openapi_client_live.py new file mode 100644 index 00000000..869c56a8 --- /dev/null +++ b/test/util/test_openapi_client_live.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import os + +import pytest +import yaml +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient + + +class TestClientLive: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "serper.yml") + .with_credentials(os.getenv("SERPERDEV_API_KEY")) + .build() + ) + serper_api = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"q": "Who was Nikola Tesla?"}', + "name": "serperdev_search", + }, + "type": "function", + } + response = serper_api.invoke(payload) + assert "invention" in str(response) + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.integration + def test_serperdev_load_spec_first(self, test_files_path): + with open(test_files_path / "yaml" / "serper.yml") as file: + loaded_spec = yaml.safe_load(file) + + # use builder with dict spec + builder = ClientConfigurationBuilder() + config = builder.with_openapi_spec(loaded_spec).with_credentials(os.getenv("SERPERDEV_API_KEY")).build() + serper_api = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"q": "Who was Nikola Tesla?"}', + "name": "serperdev_search", + }, + "type": "function", + } + response = serper_api.invoke(payload) + assert "invention" in str(response) + + @pytest.mark.integration + def test_github(self, test_files_path): + builder = ClientConfigurationBuilder() + config = builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").build() + api = OpenAPIServiceClient(config) + + params = {"owner": "deepset-ai", "repo": "haystack", "basehead": "main...add_default_adapter_filters"} + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": json.dumps(params), + "name": "compare", + }, + "type": "function", + } + response = api.invoke(payload) + assert "deepset" in str(response) diff --git a/test/util/test_openapi_client_live_anthropic.py b/test/util/test_openapi_client_live_anthropic.py new file mode 100644 index 00000000..0ac5c09c --- /dev/null +++ b/test/util/test_openapi_client_live_anthropic.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os + +import anthropic +import pytest + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient + + +class TestClientLiveAnthropic: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec("https://bit.ly/serper_dev_spec_yaml") + .with_credentials(os.getenv("SERPERDEV_API_KEY")) + .with_provider("anthropic") + .build() + ) + client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + response = client.beta.tools.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + tools=config.get_tools_definitions(), + messages=[{"role": "user", "content": "Do a google search: Who was Nikola Tesla?"}], + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "inventions" in str(service_response) + + # make a few more requests to test the same tool + service_response = service_api.invoke(response) + assert "Serbian" in str(service_response) + + service_response = service_api.invoke(response) + assert "American" in str(service_response) + + @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_github(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml") + .with_provider("anthropic") + .build() + ) + + client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + response = client.beta.tools.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + tools=config.get_tools_definitions(), + messages=[ + { + "role": "user", + "content": "Compare branches main and add_default_adapter_filters in repo" + " haystack and owner deepset-ai", + } + ], + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "deepset" in str(service_response) diff --git a/test/util/test_openapi_client_live_cohere.py b/test/util/test_openapi_client_live_cohere.py new file mode 100644 index 00000000..901b474d --- /dev/null +++ b/test/util/test_openapi_client_live_cohere.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import cohere +import pytest + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient + +# Copied from Cohere's documentation +preamble = """ +## Task & Context +You help people answer their questions and other requests interactively. You will be asked a very wide array of + requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to + help you, which you use to research your answer. You should focus on serving the user's needs as best you can, + which will be wide-ranging. + +## Style Guide +Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and + spelling. +""" + + +class TestClientLiveCohere: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "serper.yml") + .with_credentials(os.getenv("SERPERDEV_API_KEY")) + .with_provider("cohere") + .build() + ) + client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) + response = client.chat( + model="command-r", + preamble=preamble, + tools=config.get_tools_definitions(), + message="Do a google search: Who was Nikola Tesla?", + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "inventions" in str(service_response) + + # make a few more requests to test the same tool + service_response = service_api.invoke(response) + assert "Serbian" in str(service_response) + + service_response = service_api.invoke(response) + assert "American" in str(service_response) + + @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") + @pytest.mark.integration + def test_github(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").with_provider("cohere").build() + ) + + client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) + response = client.chat( + model="command-r", + preamble=preamble, + tools=config.get_tools_definitions(), + message="Compare branches main and add_default_adapter_filters in repo haystack and owner deepset-ai", + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "deepset" in str(service_response) diff --git a/test/util/test_openapi_client_live_openai.py b/test/util/test_openapi_client_live_openai.py new file mode 100644 index 00000000..bfc00e06 --- /dev/null +++ b/test/util/test_openapi_client_live_openai.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +from openai import OpenAI + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient + + +class TestClientLiveOpenAPI: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec("https://bit.ly/serper_dev_spec_yaml") + .with_credentials(os.getenv("SERPERDEV_API_KEY")) + .build() + ) + client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Do a serperdev google search: Who was Nikola Tesla?"}], + tools=config.get_tools_definitions(), + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "inventions" in str(service_response) + + # make a few more requests to test the same tool + service_response = service_api.invoke(response) + assert "Serbian" in str(service_response) + + service_response = service_api.invoke(response) + assert "American" in str(service_response) + + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_github(self, test_files_path): + builder = ClientConfigurationBuilder() + config = builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").build() + + client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "Compare branches main and add_default_adapter_filters in repo" + " haystack and owner deepset-ai", + } + ], + tools=config.get_tools_definitions(), + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "deepset" in str(service_response) + + @pytest.mark.skipif("FIRECRAWL_API_KEY" not in os.environ, reason="FIRECRAWL_API_KEY not set") + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_firecrawl(self): + openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" + builder = ClientConfigurationBuilder() + config = builder.with_openapi_spec(openapi_spec_url).with_credentials(os.getenv("FIRECRAWL_API_KEY")).build() + + client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Scrape URL: https://news.ycombinator.com/"}], + tools=config.get_tools_definitions(), + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert isinstance(service_response, dict) + assert service_response.get("success", False), "Firecrawl scrape API call failed" + + # now test the same openapi service but different endpoint/tool + top_k = 2 + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": f"Search Google for `Why was Sam Altman ousted from OpenAI?`, limit to {top_k} results", + } + ], + tools=config.get_tools_definitions(), + ) + service_response = service_api.invoke(response) + assert isinstance(service_response, dict) + assert service_response.get("success", False), "Firecrawl search API call failed" + assert len(service_response.get("data", [])) == top_k + assert "Sam" in str(service_response) diff --git a/test/util/test_openapi_cohere_conversion.py b/test/util/test_openapi_cohere_conversion.py new file mode 100644 index 00000000..5afb1673 --- /dev/null +++ b/test/util/test_openapi_cohere_conversion.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.util.openapi import OpenAPISpecification, cohere_converter + + +class TestOpenAPISchemaConversion: + + def test_serperdev(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "serper.yml") + functions = cohere_converter(schema=spec) + + assert functions + assert len(functions) == 1 + function = functions[0] + assert function["name"] == "serperdev_search" + assert function["description"] == "Search the web with Google" + assert function["parameter_definitions"] == {"q": {"description": "", "type": "str", "required": True}} + + def test_github(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "github_compare.yml") + functions = cohere_converter(schema=spec) + assert functions + assert len(functions) == 1 + function = functions[0] + assert function["name"] == "compare_branches" + assert function["description"] == "Compares two branches against one another." + assert function["parameter_definitions"] == { + "basehead": { + "description": "The base branch and head branch to compare." + " This parameter expects the format `BASE...HEAD`", + "type": "str", + "required": True, + }, + "owner": { + "description": "The repository owner, usually a company or orgnization", + "type": "str", + "required": True, + }, + "repo": {"description": "The repository itself, the project", "type": "str", "required": True}, + } + + def test_complex_types(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json") + functions = cohere_converter(schema=spec) + + assert functions + assert len(functions) == 1 + function = functions[0] + assert function["name"] == "processPayment" + assert function["description"] == "Process a new payment using the specified payment method" + assert function["parameter_definitions"] == { + "transaction_amount": {"type": "float", "description": "The amount to be paid", "required": True}, + "description": {"type": "str", "description": "A brief description of the payment", "required": True}, + "payment_method_id": {"type": "str", "description": "The payment method to be used", "required": True}, + "payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": {"type": "str", "description": "The payer's name", "required": True}, + "email": {"type": "str", "description": "The payer's email address", "required": True}, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "str", + "description": "The type of identification document (e.g., CPF, CNPJ)", + "required": True, + }, + "number": {"type": "str", "description": "The identification number", "required": True}, + }, + "required": True, + }, + }, + "required": True, + }, + } diff --git a/test/util/test_openapi_openai_conversion.py b/test/util/test_openapi_openai_conversion.py new file mode 100644 index 00000000..fb3afa0e --- /dev/null +++ b/test/util/test_openapi_openai_conversion.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from haystack_experimental.util.openapi import openai_converter, anthropic_converter, OpenAPISpecification + + +class TestOpenAPISchemaConversion: + + @pytest.mark.parametrize("provider", ["openai", "anthropic"]) + def test_serperdev(self, test_files_path, provider): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "serper.yml") + functions = openai_converter(schema=spec) if provider == "openai" else anthropic_converter(schema=spec) + assert functions + assert len(functions) == 1 + function = functions[0]["function"] if provider == "openai" else functions[0] + assert function["name"] == "serperdev_search" + assert function["description"] == "Search the web with Google" + assert ( + function["parameters"] + if provider == "openai" + else function["input_schema"] + == {"type": "object", "properties": {"q": {"type": "string"}}, "required": ["q"]} + ) + + @pytest.mark.parametrize("provider", ["openai", "anthropic"]) + def test_github(self, test_files_path, provider: str): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "github_compare.yml") + functions = openai_converter(schema=spec) if provider == "openai" else anthropic_converter(schema=spec) + assert functions + assert len(functions) == 1 + function = functions[0]["function"] if provider == "openai" else functions[0] + assert function["name"] == "compare_branches" + assert function["description"] == "Compares two branches against one another." + assert ( + function["parameters"] + if provider == "openai" + else function["input_schema"] + == { + "type": "object", + "properties": { + "basehead": { + "type": "string", + "description": "The base branch and head branch to compare. " + "This parameter expects the format `BASE...HEAD`", + }, + "owner": { + "type": "string", + "description": "The repository owner, usually a company or orgnization", + }, + "repo": {"type": "string", "description": "The repository itself, the project"}, + }, + "required": ["basehead", "owner", "repo"], + } + ) + + @pytest.mark.parametrize("provider", ["openai", "anthropic"]) + def test_complex_types(self, test_files_path, provider: str): + spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json") + functions = openai_converter(schema=spec) if provider == "openai" else anthropic_converter(schema=spec) + + assert functions + assert len(functions) == 1 + function = functions[0]["function"] if provider == "openai" else functions[0] + assert function["name"] == "processPayment" + assert function["description"] == "Process a new payment using the specified payment method" + assert ( + function["parameters"] + if provider == "openai" + else function["input_schema"] + == { + "type": "object", + "properties": { + "transaction_amount": {"type": "number", "description": "The amount to be paid"}, + "description": {"type": "string", "description": "A brief description of the payment"}, + "payment_method_id": {"type": "string", "description": "The payment method to be used"}, + "payer": { + "type": "object", + "description": "Information about the payer, including their name, email, " + "and identification number", + "properties": { + "name": {"type": "string", "description": "The payer's name"}, + "email": {"type": "string", "description": "The payer's email address"}, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "string", + "description": "The type of identification document (e.g., CPF, CNPJ)", + }, + "number": {"type": "string", "description": "The identification number"}, + }, + "required": ["type", "number"], + }, + }, + "required": ["name", "email", "identification"], + }, + }, + "required": ["transaction_amount", "description", "payment_method_id", "payer"], + } + ) diff --git a/test/util/test_openapi_spec.py b/test/util/test_openapi_spec.py new file mode 100644 index 00000000..722946f1 --- /dev/null +++ b/test/util/test_openapi_spec.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json + +import pytest + +from haystack_experimental.util.openapi import OpenAPISpecification + + +class TestOpenAPISpecification: + + # can be initialized from a dictionary + def test_initialized_from_dictionary(self): + spec_dict = { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/users": { + "get": {"summary": "Get all users", "responses": {"200": {"description": "Successful response"}}} + } + }, + } + openapi_spec = OpenAPISpecification.from_dict(spec_dict) + assert openapi_spec.spec_dict == spec_dict + + # can be initialized from a string + def test_initialized_from_string(self): + content = """ + openapi: 3.0.0 + info: + title: Test API + version: 1.0.0 + servers: + - url: https://api.example.com + paths: + /users: + get: + summary: Get all users + responses: + '200': + description: Successful response + """ + openapi_spec = OpenAPISpecification.from_str(content) + assert openapi_spec.spec_dict == { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/users": { + "get": {"summary": "Get all users", "responses": {"200": {"description": "Successful response"}}} + } + }, + } + + # can be initialized from a file + def test_initialized_from_file(self, tmp_path): + content = """ + openapi: 3.0.0 + info: + title: Test API + version: 1.0.0 + servers: + - url: https://api.example.com + paths: + /users: + get: + summary: Get all users + responses: + '200': + description: Successful response + """ + file_path = tmp_path / "spec.yaml" + file_path.write_text(content) + openapi_spec = OpenAPISpecification.from_file(file_path) + assert openapi_spec.spec_dict == { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/users": { + "get": {"summary": "Get all users", "responses": {"200": {"description": "Successful response"}}} + } + }, + } + + # can get all paths + def test_get_all_paths(self): + spec_dict = { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": {"/users": {}, "/products": {}, "/orders": {}}, + } + openapi_spec = OpenAPISpecification(spec_dict) + paths = openapi_spec.get_paths() + assert paths == {"/users": {}, "/products": {}, "/orders": {}} + + # raises ValueError if initialized from an invalid schema + def test_raises_value_error_invalid_schema(self): + spec_dict = {"info": {"title": "Test API", "version": "1.0.0"}, "paths": {"/users": {}}} + with pytest.raises(ValueError): + OpenAPISpecification(spec_dict) + + # Should return the raw OpenAPI specification dictionary with resolved references. + def test_return_raw_spec_with_resolved_references(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json") + raw_spec = spec.to_dict(resolve_references=True) + + assert "$ref" not in str(raw_spec) + assert "#/" not in str(raw_spec) + + # verify that we can serialize the raw spec to a string + schema_ser = json.dumps(raw_spec, indent=2) + + # and that the serialized string does not contain any $ref or #/ references + assert "$ref" not in schema_ser + assert "#/" not in schema_ser + + # and that we can deserialize the serialized string back to a dictionary + schema = json.loads(schema_ser) + assert "$ref" not in schema + assert "#/" not in schema + + assert schema == raw_spec From 3fd99acd4d360cce823334d896c58ce9665770c5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 11:48:24 +0200 Subject: [PATCH 02/40] Refactoring step 1 --- .../components/connectors/__init__.py | 7 - .../components/connectors/openapi.py | 125 ----- .../components/converters/__init__.py | 7 - .../components/converters/openapi.py | 83 ---- .../components/tools/openapi}/__init__.py | 0 .../components/tools/openapi/openapi.py | 0 haystack_experimental/util/openapi.py | 431 +++++------------- test/components/connectors/test_openapi.py | 114 ----- test/components/converters/test_openapi.py | 246 ---------- test/util/conftest.py | 19 +- test/util/test_openapi_client.py | 26 +- test/util/test_openapi_client_auth.py | 58 +-- ...est_openapi_client_complex_request_body.py | 12 +- ...enapi_client_complex_request_body_mixed.py | 10 +- test/util/test_openapi_client_edge_cases.py | 15 +- .../test_openapi_client_error_handling.py | 11 +- test/util/test_openapi_client_live.py | 15 +- .../test_openapi_client_live_anthropic.py | 22 +- test/util/test_openapi_client_live_cohere.py | 18 +- test/util/test_openapi_client_live_openai.py | 21 +- test/util/test_openapi_spec.py | 12 - 21 files changed, 188 insertions(+), 1064 deletions(-) delete mode 100644 haystack_experimental/components/connectors/__init__.py delete mode 100644 haystack_experimental/components/connectors/openapi.py delete mode 100644 haystack_experimental/components/converters/__init__.py delete mode 100644 haystack_experimental/components/converters/openapi.py rename {test/components/connectors => haystack_experimental/components/tools/openapi}/__init__.py (100%) rename test/components/converters/__init__.py => haystack_experimental/components/tools/openapi/openapi.py (100%) delete mode 100644 test/components/connectors/test_openapi.py delete mode 100644 test/components/converters/test_openapi.py diff --git a/haystack_experimental/components/connectors/__init__.py b/haystack_experimental/components/connectors/__init__.py deleted file mode 100644 index 9c3d11ec..00000000 --- a/haystack_experimental/components/connectors/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -from haystack_experimental.components.connectors.openapi import OpenAPIServiceConnector - -__all__ = ["OpenAPIServiceConnector"] diff --git a/haystack_experimental/components/connectors/openapi.py b/haystack_experimental/components/connectors/openapi.py deleted file mode 100644 index a9afce5c..00000000 --- a/haystack_experimental/components/connectors/openapi.py +++ /dev/null @@ -1,125 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import json -from typing import Any, Dict, List, Optional, Union - -from haystack import component, logging -from haystack.dataclasses import ChatMessage, ChatRole - -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, validate_provider - -logger = logging.getLogger(__name__) - - -@component -class OpenAPIServiceConnector: - """ - The `OpenAPIServiceConnector` component connects the Haystack framework to OpenAPI services. - - It integrates with `ChatMessage` dataclass, where the payload in messages is used to determine the method to be - called and the parameters to be passed. The response from the service is returned as a `ChatMessage`. - - Function calling payloads from OpenAI, Anthropic, and Cohere LLMs are supported. - - Before using this component, users usually resolve function calling function definitions with a help of - `OpenAPIServiceToFunctions` component. - - The example below demonstrates how to use the `OpenAPIServiceConnector` to invoke a method on a - https://serper.dev/ service specified via OpenAPI specification. - - Note, however, that `OpenAPIServiceConnector` is usually not meant to be used directly, but rather as part of a - pipeline that includes the `OpenAPIServiceToFunctions` component and an `OpenAIChatGenerator` component using LLM - with the function calling capabilities. In the example below we use the function calling payload directly, but in a - real-world scenario, the function calling payload would usually be generated by the `OpenAIChatGenerator` - component. - - Usage example: - - ```python - import json - import requests - - from haystack_experimental.components.connectors import OpenAPIServiceConnector - from haystack.dataclasses import ChatMessage - - - fc_payload = [{'function': {'arguments': '{"q": "Why was Sam Altman ousted from OpenAI?"}', 'name': 'search'}, - 'id': 'call_PmEBYvZ7mGrQP5PUASA5m9wO', 'type': 'function'}] - - serper_token = - serperdev_openapi_spec = json.loads(requests.get("https://bit.ly/serper_dev_spec").text) - service_connector = OpenAPIServiceConnector() - result = service_connector.run(messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], - service_openapi_spec=serperdev_openapi_spec, service_credentials=serper_token) - print(result) - - >> {'service_response': [ChatMessage(content='{"searchParameters": {"q": "Why was Sam Altman ousted from OpenAI?", - >> "type": "search", "engine": "google"}, "answerBox": {"snippet": "Concerns over AI safety and OpenAI\'s role - >> in protecting were at the center of Altman\'s brief ouster from the company."... - ``` - - """ - - def __init__(self, provider: Optional[str] = None): - """ - Initializes the OpenAPIServiceConnector instance. - """ - self.llm_provider = validate_provider(provider or "openai") - - @component.output_types(service_response=Dict[str, Any]) - def run( - self, - messages: List[ChatMessage], - service_openapi_spec: Dict[str, Any], - service_credentials: Optional[Union[dict, str]] = None, - ) -> Dict[str, List[ChatMessage]]: - """ - Processes a list of chat messages to invoke a method on an OpenAPI service. - - It parses the last message in the list, expecting it to contain an OpenAI function calling descriptor - (name & parameters) in JSON format. - - :param messages: A list of `ChatMessage` objects containing the messages to be processed. The last message - should contain the function invocation payload in OpenAI function calling format. See the example in the class - docstring for the expected format. - :param service_openapi_spec: The OpenAPI JSON specification object of the service to be invoked. - :param service_credentials: The credentials to be used for authentication with the service. - Currently, only the http and apiKey OpenAPI security schemes are supported. - - :return: A dictionary with the following keys: - - `service_response`: a list of `ChatMessage` objects, each containing the response from the service. The - response is in JSON format, and the `content` attribute of the `ChatMessage` - contains the JSON string. - - :raises ValueError: If the last message is not from the assistant or if it does not contain the correct - payload to invoke a method on the service. - """ - - last_message = messages[-1] - if not last_message.is_from(ChatRole.ASSISTANT): - raise ValueError(f"{last_message} is not from the assistant.") - if not last_message.content: - raise ValueError("Function calling message content is empty.") - - builder = ClientConfigurationBuilder() - config_openapi = ( - builder.with_openapi_spec(service_openapi_spec) - .with_credentials(service_credentials or {}) - .with_provider(self.llm_provider) - .build() - ) - logger.debug(f"Invoking service {config_openapi.get_openapi_spec().get_name()} with {last_message.content}") - openapi_service = OpenAPIServiceClient(config_openapi) - try: - payload = ( - json.loads(last_message.content) if isinstance(last_message.content, str) else last_message.content - ) - service_response = openapi_service.invoke(payload) - except Exception as e: # pylint: disable=broad-exception-caught - logger.error(f"Error invoking OpenAPI endpoint. Error: {e}") - service_response = {"error": str(e)} - response_messages = [ChatMessage.from_user(json.dumps(service_response))] - - return {"service_response": response_messages} diff --git a/haystack_experimental/components/converters/__init__.py b/haystack_experimental/components/converters/__init__.py deleted file mode 100644 index 6fd23f7d..00000000 --- a/haystack_experimental/components/converters/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -from haystack_experimental.components.converters.openapi import OpenAPIServiceToFunctions - -__all__ = ["OpenAPIServiceToFunctions"] diff --git a/haystack_experimental/components/converters/openapi.py b/haystack_experimental/components/converters/openapi.py deleted file mode 100644 index 17d0eb77..00000000 --- a/haystack_experimental/components/converters/openapi.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -from haystack import component, logging -from haystack.dataclasses.byte_stream import ByteStream - -from haystack_experimental.util.openapi import ClientConfigurationBuilder, validate_provider - -logger = logging.getLogger(__name__) - - -@component -class OpenAPIServiceToFunctions: - """ - Converts OpenAPI service schemas to a format suitable for OpenAI, Anthropic, or Cohere function calling. - - The definition must respect OpenAPI specification 3.0.0 or higher. - It can be specified in JSON or YAML format. - Each function must have: - - unique operationId - - description - - requestBody and/or parameters - - schema for the requestBody and/or parameters - For more details on OpenAPI specification see the - [official documentation](https://github.com/OAI/OpenAPI-Specification). - - Usage example: - ```python - from haystack_experimental.components.converters import OpenAPIServiceToFunctions - - converter = OpenAPIServiceToFunctions() - result = converter.run(sources=["path/to/openapi_definition.yaml"]) - assert result["functions"] - ``` - """ - - MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 - - def __init__(self, provider: Optional[str] = None): - """ - Create an OpenAPIServiceToFunctions component. - - :param provider: The LLM provider to use, defaults to "openai". - """ - self.llm_provider = validate_provider(provider or "openai") - - @component.output_types(functions=List[Dict[str, Any]], openapi_specs=List[Dict[str, Any]]) - def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, Any]: - """ - Converts OpenAPI definitions into LLM specific function calling format. - - :param sources: - File paths or ByteStream objects of OpenAPI definitions (in JSON or YAML format). - - :returns: - A dictionary with the following keys: - - functions: Function definitions in JSON object format - - openapi_specs: OpenAPI specs in JSON/YAML object format with resolved references - - :raises RuntimeError: - If the OpenAPI definitions cannot be downloaded or processed. - :raises ValueError: - If the source type is not recognized or no functions are found in the OpenAPI definitions. - """ - all_extracted_fc_definitions: List[Dict[str, Any]] = [] - all_openapi_specs = [] - - builder = ClientConfigurationBuilder() - for source in sources: - source = source.to_string() if isinstance(source, ByteStream) else source - # to get tools definitions all we need is the openapi spec - config_openapi = builder.with_openapi_spec(source).with_provider(self.llm_provider).build() - - all_extracted_fc_definitions.extend(config_openapi.get_tools_definitions()) - all_openapi_specs.append(config_openapi.get_openapi_spec().to_dict(resolve_references=True)) - if not all_extracted_fc_definitions: - logger.warning("No OpenAI function definitions extracted from the provided OpenAPI specification sources.") - - return {"functions": all_extracted_fc_definitions, "openapi_specs": all_openapi_specs} diff --git a/test/components/connectors/__init__.py b/haystack_experimental/components/tools/openapi/__init__.py similarity index 100% rename from test/components/connectors/__init__.py rename to haystack_experimental/components/tools/openapi/__init__.py diff --git a/test/components/converters/__init__.py b/haystack_experimental/components/tools/openapi/openapi.py similarity index 100% rename from test/components/converters/__init__.py rename to haystack_experimental/components/tools/openapi/openapi.py diff --git a/haystack_experimental/util/openapi.py b/haystack_experimental/util/openapi.py index 552892ca..f6faad81 100644 --- a/haystack_experimental/util/openapi.py +++ b/haystack_experimental/util/openapi.py @@ -9,7 +9,7 @@ from base64 import b64encode from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Union, runtime_checkable +from typing import Any, Callable, Dict, List, Literal, Optional, Union from urllib.parse import urlparse import jsonref @@ -25,25 +25,11 @@ logger = logging.getLogger(__name__) -def validate_provider(provider: str) -> str: - """ - Check if the selected provider is supported. - - :param provider: The selected provider to validate. - :return: The validated provider. - :raises ValueError: If the selected provider is not supported. - """ - available_providers = ["openai", "anthropic", "cohere"] - if provider not in available_providers: - raise ValueError(f"LLM provider {provider} is not supported. Available providers: {available_providers}") - return provider - - -@runtime_checkable -class AuthenticationStrategy(Protocol): +class AuthenticationStrategy: """ Represents an authentication strategy that can be applied to an HTTP request. """ + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): """ Apply the authentication strategy to the given request. @@ -53,17 +39,10 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): """ -class PassThroughAuthentication(AuthenticationStrategy): - """No-op authentication strategy that does nothing.""" - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): - """ - No-op authentication strategy that does nothing. - """ - - @dataclass class ApiKeyAuthentication(AuthenticationStrategy): - """ API key authentication strategy.""" + """API key authentication strategy.""" + api_key: Optional[str] = None def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): @@ -88,7 +67,8 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): @dataclass class HTTPAuthentication(AuthenticationStrategy): - """ HTTP authentication strategy.""" + """HTTP authentication strategy.""" + username: Optional[str] = None password: Optional[str] = None token: Optional[str] = None @@ -126,7 +106,8 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): @dataclass class HttpClientConfig: - """ Configuration for the HTTP client. """ + """Configuration for the HTTP client.""" + timeout: int = 10 max_retries: int = 3 backoff_factor: float = 0.3 @@ -135,7 +116,8 @@ class HttpClientConfig: class HttpClient: - """ HTTP client for sending requests. """ + """HTTP client for sending requests.""" + def __init__(self, config: Optional[HttpClientConfig] = None): self.config = config or HttpClientConfig() self.session = requests.Session() @@ -159,13 +141,16 @@ def send_request(self, request: Dict[str, Any]) -> Any: :param request: A dictionary containing the request details. """ url = request["url"] - method = request["method"] headers = {**self.config.default_headers, **request.get("headers", {})} - params = request.get("params", {}) - json_data = request.get("json") - auth = request.get("auth") try: - response = self.session.request(method, url, headers=headers, params=params, json=json_data, auth=auth) + response = self.session.request( + request["method"], + request["url"], + headers=headers, + params=request.get("params", {}), + json=request.get("json"), + auth=request.get("auth"), + ) response.raise_for_status() return response.json() except requests.exceptions.HTTPError as e: @@ -183,70 +168,46 @@ class HttpClientError(Exception): """Exception raised for errors in the HTTP client.""" +@dataclass class Operation: - """ Represents an operation in an OpenAPI specification.""" - def __init__(self, path: str, method: str, operation_dict: Dict[str, Any], spec_dict: Dict[str, Any]): - if method.lower() not in VALID_HTTP_METHODS: - raise ValueError(f"Invalid HTTP method: {method}") - self.path = path - self.method = method.lower() - self.operation_dict = operation_dict - self.spec_dict = spec_dict + """Represents an operation in an OpenAPI specification.""" + path: str + method: str + operation_dict: Dict[str, Any] + spec_dict: Dict[str, Any] + security_requirements: List[Dict[str, List[str]]] = field(init=False) + request_body: Dict[str, Any] = field(init=False) + parameters: List[Dict[str, Any]] = field(init=False) + + def __post_init__(self): + if self.method.lower() not in VALID_HTTP_METHODS: + raise ValueError(f"Invalid HTTP method: {self.method}") + self.method = self.method.lower() + self.security_requirements = self.operation_dict.get("security", []) or self.spec_dict.get("security", []) + self.request_body = self.operation_dict.get("requestBody", {}) + self.parameters = self.operation_dict.get("parameters", []) + self.spec_dict.get("paths", {}).get( + self.path, {} + ).get("parameters", []) def get_parameters(self, location: Optional[Literal["header", "query", "path"]] = None) -> List[Dict[str, Any]]: """ Get the parameters for the operation. - - :param location: The location of the parameters to retrieve. If None, all parameters are returned. """ - parameters = self.operation_dict.get("parameters", []) - path_item = self.spec_dict.get("paths", {}).get(self.path, {}) - parameters.extend(path_item.get("parameters", [])) if location: - return [param for param in parameters if param["in"] == location] - return parameters - - def get_request_body(self) -> Dict[str, Any]: - """ - Get the request body for the operation. - """ - return self.operation_dict.get("requestBody", {}) - - def get_responses(self) -> Dict[str, Any]: - """ - Get the responses for the operation. - """ - return self.operation_dict.get("responses", {}) - - def get_security_requirements(self) -> List[Dict[str, List[str]]]: - """ - Get the security requirements for the operation. - """ - security_requirements = self.operation_dict.get("security", []) - if not security_requirements: - security_requirements = self.spec_dict.get("security", []) - return security_requirements - - def get_server_url(self) -> str: - """ - Get the server URL for the operation. - """ - servers = self.operation_dict.get("servers", []) - if not servers: - servers = self.spec_dict.get("servers", []) - if servers: - return servers[0].get("url", "") - return "" + return [param for param in self.parameters if param["in"] == location] + return self.parameters - def get_field(self, key: str, default: Any = None) -> Any: + def get_server(self) -> str: """ - Get a field from the operation dictionary. + Get the servers for the operation. """ - return self.operation_dict.get(key, default) + servers = self.operation_dict.get("servers", []) or self.spec_dict.get("servers", []) + return servers[0].get("url", "") # just use the first server from the list class OpenAPISpecification: - """ Represents an OpenAPI specification.""" + """Represents an OpenAPI specification.""" + def __init__(self, spec_dict: Dict[str, Any]): if not isinstance(spec_dict, Dict): raise ValueError(f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}") @@ -301,41 +262,13 @@ def from_url(cls, url: str) -> "OpenAPISpecification": raise ConnectionError(f"Failed to fetch the specification from URL: {url}. {e!s}") from e return cls.from_str(content) - def get_name(self) -> str: - """ - Get the title of the OpenAPI specification. - """ - return self.spec_dict.get("info", {}).get("title", "") - - def get_paths(self) -> Dict[str, Dict[str, Any]]: - """ - Get the paths from the OpenAPI specification. - """ - return self.spec_dict.get("paths", {}) - - def get_operation(self, path: str, method: Optional[str] = None) -> Operation: - """ - Retrieve an operation from the OpenAPI specification. - """ - path_item = self.get_paths().get(path, {}) - return self.get_operation_item(path, path_item, method) - - def find_operation_by_path_substring(self, path_partial: str, method: Optional[str] = None) -> Operation: - """ - Find an operation by a substring of the path. - """ - for path, path_item in self.get_paths().items(): - if path_partial in path: - return self.get_operation_item(path, path_item, method) - raise ValueError(f"No operation found with path containing {path_partial}") - def find_operation_by_id(self, op_id: str, method: Optional[str] = None) -> Operation: """ Find an operation by operationId. """ - for path, path_item in self.get_paths().items(): + for path, path_item in self.spec_dict.get("paths", {}).items(): op: Operation = self.get_operation_item(path, path_item, method) - if op_id in op.get_field("operationId", ""): + if op_id in op.operation_dict.get("operationId", ""): return self.get_operation_item(path, path_item, method) raise ValueError(f"No operation found with operationId {op_id}") @@ -359,17 +292,6 @@ def get_operation_item(self, path: str, path_item: Dict[str, Any], method: Optio raise ValueError(f"Multiple operations found at path {path}, method parameter is required.") raise ValueError(f"No operations found at path {path}.") - def get_operations(self) -> List[Operation]: - """ - Get all operations from the OpenAPI specification. - """ - operations = [] - for path, path_item in self.get_paths().items(): - for method, operation_dict in path_item.items(): - if method.lower() in VALID_HTTP_METHODS: - operations.append(Operation(path, method, operation_dict, self.spec_dict)) - return operations - def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: """ Get the security schemes from the OpenAPI specification. @@ -391,15 +313,14 @@ def to_dict(self, *, resolve_references: Optional[bool] = False) -> Dict[str, An class ClientConfiguration: - """ Configuration for the OpenAPI client. """ + """Configuration for the OpenAPI client.""" def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments - self, - openapi_spec: Union[str, Path, Dict[str, Any]], - credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None, - http_client: Optional[HttpClient] = None, - http_client_config: Optional[HttpClientConfig] = None, - llm_provider: Optional[str] = None, + self, + openapi_spec: Union[str, Path, Dict[str, Any]], + credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None, + http_client: Optional[HttpClient] = None, + llm_provider: Optional[str] = None, ): # noqa: PLR0913 if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) @@ -414,34 +335,15 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments raise ValueError("Invalid OpenAPI specification format. Expected file path or dictionary.") self.credentials = credentials - self.http_client = http_client or HttpClient(http_client_config) - self.http_client_config = http_client_config or HttpClientConfig() + self.http_client = http_client or HttpClient(HttpClientConfig()) self.llm_provider = llm_provider or "openai" - def get_openapi_spec(self) -> OpenAPISpecification: - """ - Get the OpenAPI specification. - """ - return self.openapi_spec - - def get_http_client(self) -> HttpClient: - """ - Get the HTTP client. - """ - return self.http_client - - def get_http_client_config(self) -> HttpClientConfig: - """ - Get the HTTP client configuration. - """ - return self.http_client_config - def get_auth_config(self) -> AuthenticationStrategy: """ Get the authentication configuration. """ if not self.credentials: - return PassThroughAuthentication() + return AuthenticationStrategy() if isinstance(self.credentials, AuthenticationStrategy): return self.credentials security_schemes = self.openapi_spec.get_security_schemes() @@ -469,7 +371,7 @@ def get_payload_extractor(self): return LLMFunctionPayloadExtractor(arguments_field_name=arguments_field_name) def _create_authentication_from_string( - self, credentials: str, security_schemes: Dict[str, Any] + self, credentials: str, security_schemes: Dict[str, Any] ) -> AuthenticationStrategy: for scheme in security_schemes.values(): if scheme["type"] == "apiKey": @@ -501,6 +403,7 @@ class LLMFunctionPayloadExtractor: """ Implements a recursive search for extracting LLM generated function payloads. """ + def __init__(self, arguments_field_name: str): self.arguments_field_name = arguments_field_name @@ -547,7 +450,7 @@ def _search(self, payload: Any) -> Dict[str, Any]: return {} def _get_dict_converter( - self, obj: Any, method_names: Optional[List[str]] = None + self, obj: Any, method_names: Optional[List[str]] = None ) -> Union[Callable[[], Dict[str, Any]], None]: method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1 for attr in method_names: @@ -559,120 +462,50 @@ def _is_primitive(self, obj) -> bool: return isinstance(obj, (int, float, str, bool, type(None))) -class ClientConfigurationBuilder: +class OpenAPIServiceClient: """ - ClientConfigurationBuilder provides a fluent interface for constructing a `ClientConfiguration`. + A client for invoking operations on REST services defined by OpenAPI specifications. - This builder allows for the step-by-step configuration of all necessary components to interact with an - API defined by an OpenAPI specification. + Together with the `ClientConfiguration`, its `ClientConfigurationBuilder`, the `OpenAPIServiceClient` + simplifies the process of (LLMs) with services defined by OpenAPI specifications. """ - def __init__(self): - self._openapi_spec: Union[str, Path, Dict[str, Any], None] = None - self._credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None - self._http_client: Optional[HttpClient] = None - self._http_client_config: Optional[HttpClientConfig] = None - self._llm_provider: Optional[str] = None - - def with_openapi_spec(self, openapi_spec: Union[str, Path, Dict[str, Any]]) -> "ClientConfigurationBuilder": - """ - Sets the OpenAPI specification for the configuration. - - :param openapi_spec: The OpenAPI specification as a URL, file path, or dictionary. - :return: The instance of this builder to allow for method chaining. - """ - self._openapi_spec = openapi_spec - return self - - def with_credentials( - self, credentials: Union[str, Dict[str, Any], AuthenticationStrategy] - ) -> "ClientConfigurationBuilder": - """ - Specifies the credentials used for authenticating requests made by the client. - - :param credentials: Credentials as a string, dictionary, or an AuthenticationStrategy instance. - :return: The instance of this builder to allow for method chaining. - """ - self._credentials = credentials - return self - - def with_http_client(self, http_client: HttpClient) -> "ClientConfigurationBuilder": - """ - Specifies the HTTP client to be used for making API calls. - - :param http_client: The HTTP client implementation. - :return: The instance of this builder to allow for method chaining. - """ - self._http_client = http_client - return self - - def with_http_client_config(self, http_client_config: HttpClientConfig) -> "ClientConfigurationBuilder": - """ - Specifies the HTTP client configuration. - - If not set, the default configuration is used. - - :param http_client_config: Configuration settings for the HTTP client. - :return: The instance of this builder to allow for method chaining. - """ - self._http_client_config = http_client_config - return self - - def with_provider(self, llm_provider: str) -> "ClientConfigurationBuilder": - """ - Specifies the Large Language Model (LLM) provider to be used for generating function calls. - - :param llm_provider: The LLM provider name. - :return: The instance of this builder to allow for method chaining. - """ - self._llm_provider = llm_provider - return self + def __init__(self, client_config: ClientConfiguration): + self.auth_config = client_config.get_auth_config() + self.openapi_spec = client_config.openapi_spec + self.http_client = client_config.http_client + self.payload_extractor = client_config.get_payload_extractor() - def build(self) -> ClientConfiguration: + def invoke(self, function_payload: Any) -> Any: """ - Constructs a `ClientConfiguration` instance using the settings provided. - - It validates that an OpenAPI specification has been set before proceeding with the build. + Invokes a function specified in the function payload. - :return: A configured instance of ClientConfiguration. - :raises ValueError: If the OpenAPI specification is not set. + :param function_payload: The function payload containing the details of the function to be invoked. + :returns: The response from the service after invoking the function. + :raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload. + :raises HttpClientError: If an error occurs while sending the request and receiving the response. """ - if self._openapi_spec is None: - raise ValueError("OpenAPI specification must be provided to build a configuration.") - - return ClientConfiguration( - openapi_spec=self._openapi_spec, - credentials=self._credentials, - http_client=self._http_client, - http_client_config=self._http_client_config, - llm_provider=self._llm_provider or "openai", - ) - - -class RequestBuilder: - """ Builds an HTTP request based on an OpenAPI operation""" - def __init__(self, client_config: ClientConfiguration): - self.openapi_parser = client_config.get_openapi_spec() - self.http_client = client_config.get_http_client() - self.auth_config = client_config.get_auth_config() or PassThroughAuthentication() + fn_invocation_payload = self.payload_extractor.extract_function_invocation(function_payload) + if not fn_invocation_payload: + raise OpenAPIClientError( + f"Failed to extract function invocation payload from {function_payload} using " + f"{self.payload_extractor.__class__.__name__}. Ensure the payload format matches the expected " + "structure for the designated LLM extractor." + ) + # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on + operation = self.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) + request = self._build_request(operation, **fn_invocation_payload.get("arguments")) + self._apply_authentication(self.auth_config, operation, request) + return self.http_client.send_request(request) - def build_request(self, operation: Operation, **kwargs) -> Any: - """ - Build an HTTP request based on the operation and arguments provided. - """ - url = self._build_url(operation, **kwargs) - method = operation.method.lower() - headers = self._build_headers(operation) - query_params = self._build_query_params(operation, **kwargs) - body = self._build_request_body(operation, **kwargs) + def _build_request(self, operation: Operation, **kwargs) -> Any: request = { - "url": url, - "method": method, - "headers": headers, - "params": query_params, - "json": body, + "url": self._build_url(operation, **kwargs), + "method": operation.method.lower(), + "headers": self._build_headers(operation, **kwargs), + "params": self._build_query_params(operation, **kwargs), + "json": self._build_request_body(operation, **kwargs), } - self._apply_authentication(operation, request) return request def _build_headers(self, operation: Operation, **kwargs) -> Dict[str, str]: @@ -686,7 +519,7 @@ def _build_headers(self, operation: Operation, **kwargs) -> Dict[str, str]: return headers def _build_url(self, operation: Operation, **kwargs) -> str: - server_url = operation.get_server_url() + server_url = operation.get_server() path = operation.path for parameter in operation.get_parameters("path"): param_value = kwargs.get(parameter["name"], None) @@ -698,8 +531,6 @@ def _build_url(self, operation: Operation, **kwargs) -> str: def _build_query_params(self, operation: Operation, **kwargs) -> Dict[str, Any]: query_params = {} - - # Simplify query parameter assembly using _get_parameter_value for parameter in operation.get_parameters("query"): param_value = kwargs.get(parameter["name"], None) if param_value: @@ -709,7 +540,7 @@ def _build_query_params(self, operation: Operation, **kwargs) -> Dict[str, Any]: return query_params def _build_request_body(self, operation: Operation, **kwargs) -> Any: - request_body = operation.get_request_body() + request_body = operation.request_body if request_body: content = request_body.get("content", {}) if "application/json" in content: @@ -717,56 +548,19 @@ def _build_request_body(self, operation: Operation, **kwargs) -> Any: raise NotImplementedError("Request body content type not supported") return None - def _apply_authentication(self, operation: Operation, request: Dict[str, Any]): - # security requirements specify which authentication scheme to apply (the "what/which") - security_requirements = operation.get_security_requirements() - # security schemes define how to authenticate (the "how") + def _apply_authentication(self, auth: AuthenticationStrategy, operation: Operation, request: Dict[str, Any]): + auth_config = auth or AuthenticationStrategy() + security_requirements = operation.security_requirements security_schemes = operation.spec_dict.get("components", {}).get("securitySchemes", {}) if security_requirements: for requirement in security_requirements: for scheme_name in requirement: if scheme_name in security_schemes: security_scheme = security_schemes[scheme_name] - self.auth_config.apply_auth(security_scheme, request) + auth_config.apply_auth(security_scheme, request) break -class OpenAPIServiceClient: - """ - A client for invoking operations on REST services defined by OpenAPI specifications. - - Together with the `ClientConfiguration`, its `ClientConfigurationBuilder`, the `OpenAPIServiceClient` - simplifies the process of (LLMs) with services defined by OpenAPI specifications. - """ - - def __init__(self, client_config: ClientConfiguration): - self.openapi_spec = client_config.get_openapi_spec() - self.http_client = client_config.get_http_client() - self.request_builder = RequestBuilder(client_config) - self.payload_extractor = client_config.get_payload_extractor() - - def invoke(self, function_payload: Any) -> Any: - """ - Invokes a function specified in the function payload. - - :param function_payload: The function payload containing the details of the function to be invoked. - :returns: The response from the service after invoking the function. - :raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload. - :raises HttpClientError: If an error occurs while sending the request and receiving the response. - """ - fn_invocation_payload = self.payload_extractor.extract_function_invocation(function_payload) - if not fn_invocation_payload: - raise OpenAPIClientError( - f"Failed to extract function invocation payload from {function_payload} using " - f"{self.payload_extractor.__class__.__name__}. Ensure the payload format matches the expected " - "structure for the designated LLM extractor." - ) - # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on - operation = self.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) - request = self.request_builder.build_request(operation, **fn_invocation_payload.get("arguments")) - return self.http_client.send_request(request) - - class OpenAPIClientError(Exception): """Exception raised for errors in the OpenAPI client.""" @@ -805,8 +599,11 @@ def cohere_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: return _openapi_to_functions(resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere) -def _openapi_to_functions(service_openapi_spec: Dict[str, Any], parameters_name: str, - parse_endpoint_fn: Callable[[Dict[str, Any], str], Dict[str, Any]]) -> List[Dict[str, Any]]: +def _openapi_to_functions( + service_openapi_spec: Dict[str, Any], + parameters_name: str, + parse_endpoint_fn: Callable[[Dict[str, Any], str], Dict[str, Any]], +) -> List[Dict[str, Any]]: """ Extracts functions from the OpenAPI specification, converts them into a function schema. """ @@ -869,8 +666,9 @@ def _parse_endpoint_spec_openai(resolved_spec: Dict[str, Any], parameters_name: return {} -def _parse_property_attributes(property_schema: Dict[str, Any], include_attributes: Optional[List[str]] = None - ) -> Dict[str, Any]: +def _parse_property_attributes( + property_schema: Dict[str, Any], include_attributes: Optional[List[str]] = None +) -> Dict[str, Any]: """ Recursively parses the attributes of a property schema. """ @@ -883,8 +681,7 @@ def _parse_property_attributes(property_schema: Dict[str, Any], include_attribut if schema_type == "object": properties = property_schema.get("properties", {}) parsed_properties = { - prop_name: _parse_property_attributes(prop, include_attributes) - for prop_name, prop in properties.items() + prop_name: _parse_property_attributes(prop, include_attributes) for prop_name, prop in properties.items() } parsed_schema["properties"] = parsed_properties if "required" in property_schema: @@ -928,9 +725,7 @@ def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: schema_properties = content["schema"].get("properties", {}) required_properties = content["schema"].get("required", []) for name, schema in schema_properties.items(): - parameters[name] = _parse_schema( - schema, name in required_properties, schema.get("description", "") - ) + parameters[name] = _parse_schema(schema, name in required_properties, schema.get("description", "")) return parameters @@ -960,8 +755,14 @@ def _parse_schema(schema: Dict[str, Any], required: bool, description: str) -> D def _get_type(schema: Dict[str, Any]) -> str: - type_mapping = {"integer": "int", "string": "str", "boolean": "bool", "number": "float", "object": "object", - "array": "list"} + type_mapping = { + "integer": "int", + "string": "str", + "boolean": "bool", + "number": "float", + "object": "object", + "array": "list", + } schema_type = schema.get("type", "object") if schema_type not in type_mapping: raise ValueError(f"Unsupported schema type {schema_type}") diff --git a/test/components/connectors/test_openapi.py b/test/components/connectors/test_openapi.py deleted file mode 100644 index f9b5e413..00000000 --- a/test/components/connectors/test_openapi.py +++ /dev/null @@ -1,114 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -import json -import os -from unittest.mock import patch - -import pytest - -from haystack_experimental.components.connectors import OpenAPIServiceConnector -from haystack.dataclasses import ChatMessage - - -class TestOpenAPIServiceConnector: - @pytest.fixture - def setup_mock(self): - with patch("haystack_experimental.components.connectors.openapi.OpenAPIServiceClient") as mock_client: - mock_client_instance = mock_client.return_value - mock_client_instance.invoke.return_value = {"service_response": "Yes, he was fired and rehired"} - yield mock_client_instance - - def test_init(self): - service_connector = OpenAPIServiceConnector() - assert service_connector is not None - assert service_connector.llm_provider == "openai" - - def test_init_with_anthropic_provider(self): - service_connector = OpenAPIServiceConnector(provider="anthropic") - assert service_connector is not None - assert service_connector.llm_provider == "anthropic" - - def test_run_with_mock(self, setup_mock, test_files_path): - fc_payload = [ - { - "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, - "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", - "type": "function", - } - ] - with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: - serperdev_openapi_spec = json.load(file) - - service_connector = OpenAPIServiceConnector() - result = service_connector.run( - messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], - service_openapi_spec=serperdev_openapi_spec, - service_credentials="fake_api_key", - ) - - assert "service_response" in result - assert len(result["service_response"]) == 1 - assert isinstance(result["service_response"][0], ChatMessage) - response_content = json.loads(result["service_response"][0].content) - assert response_content == {"service_response": "Yes, he was fired and rehired"} - - # verify invocation payload - setup_mock.invoke.assert_called_once() - invocation_payload = [ - { - "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, - "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", - "type": "function", - } - ] - setup_mock.invoke.assert_called_with(invocation_payload) - - @pytest.mark.integration - @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SerperDev API key is not available") - def test_run(self, test_files_path): - fc_payload = [ - { - "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, - "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", - "type": "function", - } - ] - - with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: - serperdev_openapi_spec = json.load(file) - - service_connector = OpenAPIServiceConnector() - result = service_connector.run( - messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], - service_openapi_spec=serperdev_openapi_spec, - service_credentials=os.environ["SERPERDEV_API_KEY"], - ) - assert "service_response" in result - assert len(result["service_response"]) == 1 - assert isinstance(result["service_response"][0], ChatMessage) - response_text = result["service_response"][0].content - assert "Sam" in response_text or "Altman" in response_text - - @pytest.mark.integration - def test_run_no_credentials(self, test_files_path): - fc_payload = [ - { - "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, - "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", - "type": "function", - } - ] - - with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: - serperdev_openapi_spec = json.load(file) - - service_connector = OpenAPIServiceConnector() - result = service_connector.run( - messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], service_openapi_spec=serperdev_openapi_spec - ) - assert "service_response" in result - assert len(result["service_response"]) == 1 - assert isinstance(result["service_response"][0], ChatMessage) - response_text = result["service_response"][0].content - assert "403" in response_text diff --git a/test/components/converters/test_openapi.py b/test/components/converters/test_openapi.py deleted file mode 100644 index fabb4624..00000000 --- a/test/components/converters/test_openapi.py +++ /dev/null @@ -1,246 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -import json -import sys -import tempfile - -import pytest - -from haystack_experimental.components.converters import OpenAPIServiceToFunctions -from haystack.dataclasses import ByteStream - - -@pytest.fixture -def json_serperdev_openapi_spec(): - serper_spec = """ - { - "openapi": "3.0.0", - "info": { - "title": "SerperDev", - "version": "1.0.0", - "description": "API for performing search queries" - }, - "servers": [ - { - "url": "https://google.serper.dev" - } - ], - "paths": { - "/search": { - "post": { - "operationId": "search", - "description": "Search the web with Google", - "requestBody": { - "required": true, - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": { - "q": { - "type": "string" - } - } - } - } - } - }, - "responses": { - "200": { - "description": "Successful response", - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": { - "searchParameters": { - "type": "undefined" - }, - "knowledgeGraph": { - "type": "undefined" - }, - "answerBox": { - "type": "undefined" - }, - "organic": { - "type": "undefined" - }, - "topStories": { - "type": "undefined" - }, - "peopleAlsoAsk": { - "type": "undefined" - }, - "relatedSearches": { - "type": "undefined" - } - } - } - } - } - } - }, - "security": [ - { - "apikey": [] - } - ] - } - } - }, - "components": { - "securitySchemes": { - "apikey": { - "type": "apiKey", - "name": "x-api-key", - "in": "header" - } - } - } - } - """ - return serper_spec - - -@pytest.fixture -def yaml_serperdev_openapi_spec(): - serper_spec = """ - openapi: 3.0.0 - info: - title: SerperDev - version: 1.0.0 - description: API for performing search queries - servers: - - url: 'https://google.serper.dev' - paths: - /search: - post: - operationId: search - description: Search the web with Google - requestBody: - required: true - content: - application/json: - schema: - type: object - properties: - q: - type: string - responses: - '200': - description: Successful response - content: - application/json: - schema: - type: object - properties: - searchParameters: - type: undefined - knowledgeGraph: - type: undefined - answerBox: - type: undefined - organic: - type: undefined - topStories: - type: undefined - peopleAlsoAsk: - type: undefined - relatedSearches: - type: undefined - security: - - apikey: [] - components: - securitySchemes: - apikey: - type: apiKey - name: x-api-key - in: header - """ - return serper_spec - - -@pytest.fixture -def fn_definition_transform(): - return lambda function_def: {"type": "function", "function": function_def} - - -class TestOpenAPIServiceToFunctions: - # test we can extract functions from openapi spec given - def test_run_with_bytestream_source(self, json_serperdev_openapi_spec, fn_definition_transform): - service = OpenAPIServiceToFunctions() - spec_stream = ByteStream.from_string(json_serperdev_openapi_spec) - result = service.run(sources=[spec_stream]) - assert len(result["functions"]) == 1 - fc = result["functions"][0] - - # check that fc definition is as expected - assert fc == fn_definition_transform( - { - "name": "search", - "description": "Search the web with Google", - "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, - } - ) - - @pytest.mark.skipif( - sys.platform in ["win32", "cygwin"], - reason="Can't run on Windows Github CI, need access temp file but windows does not allow it", - ) - def test_run_with_file_source(self, json_serperdev_openapi_spec, fn_definition_transform): - # test we can extract functions from openapi spec given in file - service = OpenAPIServiceToFunctions() - # write the spec to NamedTemporaryFile and check that it is parsed correctly - with tempfile.NamedTemporaryFile() as tmp: - tmp.write(json_serperdev_openapi_spec.encode("utf-8")) - tmp.seek(0) - result = service.run(sources=[tmp.name]) - assert len(result["functions"]) == 1 - fc = result["functions"][0] - - # check that fc definition is as expected - assert fc == fn_definition_transform( - { - "name": "search", - "description": "Search the web with Google", - "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, - } - ) - - def test_run_with_invalid_bytestream_source(self, caplog): - # test invalid source - service = OpenAPIServiceToFunctions() - with pytest.raises(ValueError, match="Invalid OpenAPI specification"): - service.run(sources=[ByteStream.from_string("")]) - - def test_complex_types_conversion(self, test_files_path, fn_definition_transform): - # ensure that complex types from OpenAPI spec are converted to the expected format in OpenAI function calling - service = OpenAPIServiceToFunctions() - result = service.run(sources=[test_files_path / "json" / "complex_types_openapi_service.json"]) - assert len(result["functions"]) == 1 - - with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file: - desired_output = json.load(openai_spec_file) - assert result["functions"][0] == fn_definition_transform(desired_output) - - def test_simple_and_complex_at_once(self, test_files_path, json_serperdev_openapi_spec, fn_definition_transform): - # ensure multiple functions are extracted from multiple paths in OpenAPI spec - service = OpenAPIServiceToFunctions() - sources = [ - ByteStream.from_string(json_serperdev_openapi_spec), - test_files_path / "json" / "complex_types_openapi_service.json", - ] - result = service.run(sources=sources) - assert len(result["functions"]) == 2 - - with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file: - desired_output = json.load(openai_spec_file) - assert result["functions"][0] == fn_definition_transform( - { - "name": "search", - "description": "Search the web with Google", - "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, - } - ) - assert result["functions"][1] == fn_definition_transform(desired_output) diff --git a/test/util/conftest.py b/test/util/conftest.py index bb575703..39caa850 100644 --- a/test/util/conftest.py +++ b/test/util/conftest.py @@ -32,25 +32,18 @@ def strip_host(self, url: str) -> str: return new_path def send_request(self, request: dict) -> dict: - method = request["method"] # OAS spec will list a server URL, but FastAPI doesn't need it for local testing, in fact it will fail # if the URL has a host. So we strip it here. url = self.strip_host(request["url"]) - headers = request.get("headers", {}) - params = request.get("params", {}) - json_data = request.get("json", None) - auth = request.get("auth", None) - cookies = request.get("cookies", {}) - try: response = self.client.request( - method, + request["method"], url, - headers=headers, - params=params, - json=json_data, - auth=auth, - cookies=cookies, + headers=request.get("headers", {}), + params=request.get("params", {}), + json=request.get("json", None), + auth=request.get("auth", None), + cookies=request.get("cookies", {}), ) response.raise_for_status() return response.json() diff --git a/test/util/test_openapi_client.py b/test/util/test_openapi_client.py index 581dfdf8..3769bd0c 100644 --- a/test/util/test_openapi_client.py +++ b/test/util/test_openapi_client.py @@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration from test.util.conftest import FastAPITestClient """ @@ -72,12 +72,8 @@ def greet_request_body(body: GreetBody): class TestOpenAPI: def test_greet_mix_params_body(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") - .with_http_client(FastAPITestClient(create_greet_mix_params_body_app())) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + http_client=FastAPITestClient(create_greet_mix_params_body_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -91,12 +87,8 @@ def test_greet_mix_params_body(self, test_files_path): assert response == {"greeting": "Bonjour, John from mix_params_body!"} def test_greet_params_only(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") - .with_http_client(FastAPITestClient(create_greet_params_only_app())) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + http_client=FastAPITestClient(create_greet_params_only_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -110,12 +102,8 @@ def test_greet_params_only(self, test_files_path): assert response == {"greeting": "Hello, John from params_only!"} def test_greet_request_body_only(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") - .with_http_client(FastAPITestClient(create_greet_request_body_only_app())) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + http_client=FastAPITestClient(create_greet_request_body_only_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", diff --git a/test/util/test_openapi_client_auth.py b/test/util/test_openapi_client_auth.py index d326d164..6797c2aa 100644 --- a/test/util/test_openapi_client_auth.py +++ b/test/util/test_openapi_client_auth.py @@ -15,8 +15,8 @@ HTTPBearer, ) -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, ApiKeyAuthentication, \ - HTTPAuthentication +from haystack_experimental.util.openapi import OpenAPIServiceClient, ApiKeyAuthentication, \ + HTTPAuthentication, ClientConfiguration from test.util.conftest import FastAPITestClient API_KEY = "secret_api_key" @@ -73,7 +73,7 @@ def create_greet_bearer_auth_app() -> FastAPI: app = FastAPI() def bearer_auth_scheme( - credentials: HTTPAuthorizationCredentials = Depends(bearer_auth), # noqa: B008 + credentials: HTTPAuthorizationCredentials = Depends(bearer_auth), # noqa: B008 ): if credentials.scheme != "Bearer" or credentials.credentials != BEARER_TOKEN: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") @@ -138,13 +138,9 @@ def greet_oauth(name: str, token: HTTPAuthorizationCredentials = Depends(oauth_a class TestOpenAPIAuth: def test_greet_api_key_auth(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") - .with_http_client(FastAPITestClient(create_greet_api_key_auth_app())) - .with_credentials(ApiKeyAuthentication(API_KEY)) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + http_client=FastAPITestClient(create_greet_api_key_auth_app()), + credentials=ApiKeyAuthentication(API_KEY)) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -158,13 +154,9 @@ def test_greet_api_key_auth(self, test_files_path): assert response == {"greeting": "Hello, John from api_key_auth, using secret_api_key"} def test_greet_basic_auth(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") - .with_http_client(FastAPITestClient(create_greet_basic_auth_app())) - .with_credentials(HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD)) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + http_client=FastAPITestClient(create_greet_basic_auth_app()), + credentials=HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD)) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -178,13 +170,9 @@ def test_greet_basic_auth(self, test_files_path): assert response == {"greeting": "Hello, John from basic_auth, using admin"} def test_greet_api_key_query_auth(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") - .with_http_client(FastAPITestClient(create_greet_api_key_query_app())) - .with_credentials(ApiKeyAuthentication(API_KEY_QUERY)) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + http_client=FastAPITestClient(create_greet_api_key_query_app()), + credentials=ApiKeyAuthentication(API_KEY_QUERY)) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -198,13 +186,11 @@ def test_greet_api_key_query_auth(self, test_files_path): assert response == {"greeting": "Hello, John from api_key_query_auth, using secret_api_key_query"} def test_greet_api_key_cookie_auth(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") - .with_http_client(FastAPITestClient(create_greet_api_key_cookie_app())) - .with_credentials(ApiKeyAuthentication(API_KEY_COOKIE)) - .build() - ) + + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + http_client=FastAPITestClient(create_greet_api_key_cookie_app()), + credentials=ApiKeyAuthentication(API_KEY_COOKIE)) + client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -218,13 +204,9 @@ def test_greet_api_key_cookie_auth(self, test_files_path): assert response == {"greeting": "Hello, John from api_key_cookie_auth, using secret_api_key_cookie"} def test_greet_bearer_auth(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") - .with_http_client(FastAPITestClient(create_greet_bearer_auth_app())) - .with_credentials(HTTPAuthentication(token=BEARER_TOKEN)) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + http_client=FastAPITestClient(create_greet_bearer_auth_app()), + credentials=HTTPAuthentication(token=BEARER_TOKEN)) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", diff --git a/test/util/test_openapi_client_complex_request_body.py b/test/util/test_openapi_client_complex_request_body.py index 17702cb3..79549177 100644 --- a/test/util/test_openapi_client_complex_request_body.py +++ b/test/util/test_openapi_client_complex_request_body.py @@ -11,7 +11,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration from test.util.conftest import FastAPITestClient @@ -56,13 +56,11 @@ class TestComplexRequestBody: @pytest.mark.parametrize("spec_file_path", ["openapi_order_service.yml", "openapi_order_service.json"]) def test_create_order(self, spec_file_path, test_files_path): - builder = ClientConfigurationBuilder() path_element = "yaml" if spec_file_path.endswith(".yml") else "json" - config = ( - builder.with_openapi_spec(test_files_path / path_element / spec_file_path) - .with_http_client(FastAPITestClient(create_order_app())) - .build() - ) + + config = ClientConfiguration(openapi_spec=test_files_path / path_element / spec_file_path, + http_client=FastAPITestClient(create_order_app())) + client = OpenAPIServiceClient(config) order_json = { "customer": {"name": "John Doe", "email": "john@example.com"}, diff --git a/test/util/test_openapi_client_complex_request_body_mixed.py b/test/util/test_openapi_client_complex_request_body_mixed.py index 907520eb..b769bbbf 100644 --- a/test/util/test_openapi_client_complex_request_body_mixed.py +++ b/test/util/test_openapi_client_complex_request_body_mixed.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration from test.util.conftest import FastAPITestClient @@ -56,12 +56,8 @@ def process_payment(payment: PaymentRequest): class TestPaymentProcess: def test_process_payment(self, test_files_path): - config = ( - ClientConfigurationBuilder() - .with_openapi_spec(test_files_path / "json" / "complex_types_openapi_service.json") - .with_http_client(FastAPITestClient(create_payment_app())) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "json" / "complex_types_openapi_service.json", + http_client=FastAPITestClient(create_payment_app())) client = OpenAPIServiceClient(config) payment_json = { diff --git a/test/util/test_openapi_client_edge_cases.py b/test/util/test_openapi_client_edge_cases.py index f421e73e..e1c36920 100644 --- a/test/util/test_openapi_client_edge_cases.py +++ b/test/util/test_openapi_client_edge_cases.py @@ -5,24 +5,15 @@ import pytest -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration from test.util.conftest import FastAPITestClient class TestEdgeCases: - def test_invalid_openapi_spec(self): - builder = ClientConfigurationBuilder() - with pytest.raises(ValueError, match="Invalid OpenAPI specification"): - config = builder.with_openapi_spec("invalid_spec.yml").build() - OpenAPIServiceClient(config) def test_missing_operation_id(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml") - .with_http_client(FastAPITestClient(None)) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_edge_cases.yml", + http_client=FastAPITestClient(None)) client = OpenAPIServiceClient(config) payload = { diff --git a/test/util/test_openapi_client_error_handling.py b/test/util/test_openapi_client_error_handling.py index 3201ad65..53e9478a 100644 --- a/test/util/test_openapi_client_error_handling.py +++ b/test/util/test_openapi_client_error_handling.py @@ -8,7 +8,8 @@ import pytest from fastapi import FastAPI, HTTPException -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, HttpClientError +from haystack_experimental.util.openapi import OpenAPIServiceClient, HttpClientError, \ + ClientConfiguration from test.util.conftest import FastAPITestClient @@ -25,12 +26,8 @@ def raise_http_error(status_code: int): class TestErrorHandling: @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500]) def test_http_error_handling(self, test_files_path, status_code): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "openapi_error_handling.yml") - .with_http_client(FastAPITestClient(create_error_handling_app())) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_error_handling.yml", + http_client=FastAPITestClient(create_error_handling_app())) client = OpenAPIServiceClient(config) json_error = {"status_code": status_code} payload = { diff --git a/test/util/test_openapi_client_live.py b/test/util/test_openapi_client_live.py index 869c56a8..ec80b83b 100644 --- a/test/util/test_openapi_client_live.py +++ b/test/util/test_openapi_client_live.py @@ -7,7 +7,7 @@ import pytest import yaml -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration class TestClientLive: @@ -15,12 +15,7 @@ class TestClientLive: @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") @pytest.mark.integration def test_serperdev(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "serper.yml") - .with_credentials(os.getenv("SERPERDEV_API_KEY")) - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", credentials=os.getenv("SERPERDEV_API_KEY")) serper_api = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -40,8 +35,7 @@ def test_serperdev_load_spec_first(self, test_files_path): loaded_spec = yaml.safe_load(file) # use builder with dict spec - builder = ClientConfigurationBuilder() - config = builder.with_openapi_spec(loaded_spec).with_credentials(os.getenv("SERPERDEV_API_KEY")).build() + config = ClientConfiguration(openapi_spec=loaded_spec, credentials=os.getenv("SERPERDEV_API_KEY")) serper_api = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -56,8 +50,7 @@ def test_serperdev_load_spec_first(self, test_files_path): @pytest.mark.integration def test_github(self, test_files_path): - builder = ClientConfigurationBuilder() - config = builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").build() + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml") api = OpenAPIServiceClient(config) params = {"owner": "deepset-ai", "repo": "haystack", "basehead": "main...add_default_adapter_filters"} diff --git a/test/util/test_openapi_client_live_anthropic.py b/test/util/test_openapi_client_live_anthropic.py index 0ac5c09c..62c2a853 100644 --- a/test/util/test_openapi_client_live_anthropic.py +++ b/test/util/test_openapi_client_live_anthropic.py @@ -7,7 +7,7 @@ import anthropic import pytest -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient class TestClientLiveAnthropic: @@ -15,14 +15,10 @@ class TestClientLiveAnthropic: @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") @pytest.mark.integration - def test_serperdev(self): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec("https://bit.ly/serper_dev_spec_yaml") - .with_credentials(os.getenv("SERPERDEV_API_KEY")) - .with_provider("anthropic") - .build() - ) + def test_serperdev(self, test_files_path): + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", + credentials=os.getenv("SERPERDEV_API_KEY"), + llm_provider="anthropic") client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) response = client.beta.tools.messages.create( model="claude-3-opus-20240229", @@ -44,12 +40,8 @@ def test_serperdev(self): @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") @pytest.mark.integration def test_github(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml") - .with_provider("anthropic") - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml", + llm_provider="anthropic") client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) response = client.beta.tools.messages.create( diff --git a/test/util/test_openapi_client_live_cohere.py b/test/util/test_openapi_client_live_cohere.py index 901b474d..4a316d83 100644 --- a/test/util/test_openapi_client_live_cohere.py +++ b/test/util/test_openapi_client_live_cohere.py @@ -6,7 +6,7 @@ import cohere import pytest -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient # Copied from Cohere's documentation preamble = """ @@ -28,13 +28,9 @@ class TestClientLiveCohere: @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") @pytest.mark.integration def test_serperdev(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "serper.yml") - .with_credentials(os.getenv("SERPERDEV_API_KEY")) - .with_provider("cohere") - .build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", + credentials=os.getenv("SERPERDEV_API_KEY"), + llm_provider="cohere") client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) response = client.chat( model="command-r", @@ -56,10 +52,8 @@ def test_serperdev(self, test_files_path): @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") @pytest.mark.integration def test_github(self, test_files_path): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").with_provider("cohere").build() - ) + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml", + llm_provider="cohere") client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) response = client.chat( diff --git a/test/util/test_openapi_client_live_openai.py b/test/util/test_openapi_client_live_openai.py index bfc00e06..cf723939 100644 --- a/test/util/test_openapi_client_live_openai.py +++ b/test/util/test_openapi_client_live_openai.py @@ -7,7 +7,7 @@ import pytest from openai import OpenAI -from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient class TestClientLiveOpenAPI: @@ -15,13 +15,10 @@ class TestClientLiveOpenAPI: @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") @pytest.mark.integration - def test_serperdev(self): - builder = ClientConfigurationBuilder() - config = ( - builder.with_openapi_spec("https://bit.ly/serper_dev_spec_yaml") - .with_credentials(os.getenv("SERPERDEV_API_KEY")) - .build() - ) + def test_serperdev(self, test_files_path): + + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", + credentials=os.getenv("SERPERDEV_API_KEY")) client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) response = client.chat.completions.create( model="gpt-3.5-turbo", @@ -42,9 +39,7 @@ def test_serperdev(self): @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") @pytest.mark.integration def test_github(self, test_files_path): - builder = ClientConfigurationBuilder() - config = builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").build() - + config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml") client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) response = client.chat.completions.create( model="gpt-3.5-turbo", @@ -66,9 +61,7 @@ def test_github(self, test_files_path): @pytest.mark.integration def test_firecrawl(self): openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" - builder = ClientConfigurationBuilder() - config = builder.with_openapi_spec(openapi_spec_url).with_credentials(os.getenv("FIRECRAWL_API_KEY")).build() - + config = ClientConfiguration(openapi_spec=openapi_spec_url, credentials=os.getenv("FIRECRAWL_API_KEY")) client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) response = client.chat.completions.create( model="gpt-3.5-turbo", diff --git a/test/util/test_openapi_spec.py b/test/util/test_openapi_spec.py index 722946f1..ca784169 100644 --- a/test/util/test_openapi_spec.py +++ b/test/util/test_openapi_spec.py @@ -86,18 +86,6 @@ def test_initialized_from_file(self, tmp_path): }, } - # can get all paths - def test_get_all_paths(self): - spec_dict = { - "openapi": "3.0.0", - "info": {"title": "Test API", "version": "1.0.0"}, - "servers": [{"url": "https://api.example.com"}], - "paths": {"/users": {}, "/products": {}, "/orders": {}}, - } - openapi_spec = OpenAPISpecification(spec_dict) - paths = openapi_spec.get_paths() - assert paths == {"/users": {}, "/products": {}, "/orders": {}} - # raises ValueError if initialized from an invalid schema def test_raises_value_error_invalid_schema(self): spec_dict = {"info": {"title": "Test API", "version": "1.0.0"}, "paths": {"/users": {}}} From 20be1e59c7658b16ccfb621cefd080de18ffc947 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 14:21:34 +0200 Subject: [PATCH 03/40] Refactoring step 2 --- haystack_experimental/util/openapi.py | 308 +----------------- .../util/payload_extraction.py | 67 ++++ .../util/schema_conversion.py | 218 +++++++++++++ test/util/test_openapi_spec.py | 22 -- 4 files changed, 297 insertions(+), 318 deletions(-) create mode 100644 haystack_experimental/util/payload_extraction.py create mode 100644 haystack_experimental/util/schema_conversion.py diff --git a/haystack_experimental/util/openapi.py b/haystack_experimental/util/openapi.py index f6faad81..9d7dc420 100644 --- a/haystack_experimental/util/openapi.py +++ b/haystack_experimental/util/openapi.py @@ -2,22 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 -import dataclasses import json import logging import os from base64 import b64encode from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from urllib.parse import urlparse -import jsonref import requests import yaml from requests.adapters import HTTPAdapter from urllib3 import Retry +from haystack_experimental.util.payload_extraction import create_function_payload_extractor +from haystack_experimental.util.schema_conversion import anthropic_converter, cohere_converter, openai_converter + VALID_HTTP_METHODS = ["get", "put", "post", "delete", "options", "head", "patch", "trace"] MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 @@ -121,9 +122,6 @@ class HttpClient: def __init__(self, config: Optional[HttpClientConfig] = None): self.config = config or HttpClientConfig() self.session = requests.Session() - self._initialize_session() - - def _initialize_session(self) -> None: retries = Retry( total=self.config.max_retries, backoff_factor=self.config.backoff_factor, @@ -145,7 +143,7 @@ def send_request(self, request: Dict[str, Any]) -> Any: try: response = self.session.request( request["method"], - request["url"], + url, headers=headers, params=request.get("params", {}), json=request.get("json"), @@ -299,18 +297,6 @@ def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: components = self.spec_dict.get("components", {}) return components.get("securitySchemes", {}) - def to_dict(self, *, resolve_references: Optional[bool] = False) -> Dict[str, Any]: - """ - Converts the OpenAPI specification to a dictionary format. - - Optionally resolves all $ref references within the spec, returning a fully resolved specification - dictionary if `resolve_references` is set to True. - - :param resolve_references: If True, resolve references in the specification. - :return: A dictionary representation of the OpenAPI specification, optionally fully resolved. - """ - return jsonref.replace_refs(self.spec_dict, proxies=False) if resolve_references else self.spec_dict - class ClientConfiguration: """Configuration for the OpenAPI client.""" @@ -368,7 +354,7 @@ def get_payload_extractor(self): provider_to_arguments_field_name = {"anthropic": "input", "cohere": "parameters"} # add more providers here # default to OpenAI "arguments" arguments_field_name = provider_to_arguments_field_name.get(self.llm_provider, "arguments") - return LLMFunctionPayloadExtractor(arguments_field_name=arguments_field_name) + return create_function_payload_extractor(arguments_field_name) def _create_authentication_from_string( self, credentials: str, security_schemes: Dict[str, Any] @@ -399,69 +385,6 @@ def is_valid_http_url(self, url: str) -> bool: return all([r.scheme in ["http", "https"], r.netloc]) -class LLMFunctionPayloadExtractor: - """ - Implements a recursive search for extracting LLM generated function payloads. - """ - - def __init__(self, arguments_field_name: str): - self.arguments_field_name = arguments_field_name - - def extract_function_invocation(self, payload: Any) -> Dict[str, Any]: - """ - Extract the function invocation details from the payload. - """ - fields_and_values = self._search(payload) - if fields_and_values: - arguments = fields_and_values.get(self.arguments_field_name) - if not isinstance(arguments, (str, dict)): - raise ValueError( - f"Invalid {self.arguments_field_name} type {type(arguments)} for function call, expected str/dict" - ) - return { - "name": fields_and_values.get("name"), - "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, - } - return {} - - def _required_fields(self) -> List[str]: - return ["name", self.arguments_field_name] - - def _search(self, payload: Any) -> Dict[str, Any]: - if self._is_primitive(payload): - return {} - if dict_converter := self._get_dict_converter(payload): - payload = dict_converter() - elif dataclasses.is_dataclass(payload): - payload = dataclasses.asdict(payload) - if isinstance(payload, dict): - if all(field in payload for field in self._required_fields()): - # this is the payload we are looking for - return payload - for value in payload.values(): - result = self._search(value) - if result: - return result - elif isinstance(payload, list): - for item in payload: - result = self._search(item) - if result: - return result - return {} - - def _get_dict_converter( - self, obj: Any, method_names: Optional[List[str]] = None - ) -> Union[Callable[[], Dict[str, Any]], None]: - method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1 - for attr in method_names: - if hasattr(obj, attr) and callable(getattr(obj, attr)): - return getattr(obj, attr) - return None - - def _is_primitive(self, obj) -> bool: - return isinstance(obj, (int, float, str, bool, type(None))) - - class OpenAPIServiceClient: """ A client for invoking operations on REST services defined by OpenAPI specifications. @@ -471,10 +394,8 @@ class OpenAPIServiceClient: """ def __init__(self, client_config: ClientConfiguration): - self.auth_config = client_config.get_auth_config() - self.openapi_spec = client_config.openapi_spec + self.client_config = client_config self.http_client = client_config.http_client - self.payload_extractor = client_config.get_payload_extractor() def invoke(self, function_payload: Any) -> Any: """ @@ -485,17 +406,16 @@ def invoke(self, function_payload: Any) -> Any: :raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload. :raises HttpClientError: If an error occurs while sending the request and receiving the response. """ - fn_invocation_payload = self.payload_extractor.extract_function_invocation(function_payload) + fn_extractor = self.client_config.get_payload_extractor() + fn_invocation_payload = fn_extractor(function_payload) if not fn_invocation_payload: raise OpenAPIClientError( - f"Failed to extract function invocation payload from {function_payload} using " - f"{self.payload_extractor.__class__.__name__}. Ensure the payload format matches the expected " - "structure for the designated LLM extractor." + f"Failed to extract function invocation payload from {function_payload}" ) # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on - operation = self.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) + operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) request = self._build_request(operation, **fn_invocation_payload.get("arguments")) - self._apply_authentication(self.auth_config, operation, request) + self._apply_authentication(self.client_config.get_auth_config(), operation, request) return self.http_client.send_request(request) def _build_request(self, operation: Operation, **kwargs) -> Any: @@ -563,207 +483,3 @@ def _apply_authentication(self, auth: AuthenticationStrategy, operation: Operati class OpenAPIClientError(Exception): """Exception raised for errors in the OpenAPI client.""" - - -def openai_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: - """ - Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. - - :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. - """ - resolved_schema = jsonref.replace_refs(schema.spec_dict) - fn_definitions = _openapi_to_functions(resolved_schema, "parameters", _parse_endpoint_spec_openai) - return [{"type": "function", "function": fn} for fn in fn_definitions] - - -def anthropic_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: - """ - Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. - - :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. - """ - resolved_schema = jsonref.replace_refs(schema.spec_dict) - return _openapi_to_functions(resolved_schema, "input_schema", _parse_endpoint_spec_openai) - - -def cohere_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: - """ - Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. - - :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. - """ - resolved_schema = jsonref.replace_refs(schema.spec_dict) - return _openapi_to_functions(resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere) - - -def _openapi_to_functions( - service_openapi_spec: Dict[str, Any], - parameters_name: str, - parse_endpoint_fn: Callable[[Dict[str, Any], str], Dict[str, Any]], -) -> List[Dict[str, Any]]: - """ - Extracts functions from the OpenAPI specification, converts them into a function schema. - """ - - # Doesn't enforce rigid spec validation because that would require a lot of dependencies - # We check the version and require minimal fields to be present, so we can extract functions - spec_version = service_openapi_spec.get("openapi") - if not spec_version: - raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}") - service_openapi_spec_version = int(spec_version.split(".")[0]) - # Compare the versions - if service_openapi_spec_version < MIN_REQUIRED_OPENAPI_SPEC_VERSION: - raise ValueError( - f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be " - f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}." - ) - functions: List[Dict[str, Any]] = [] - for paths in service_openapi_spec["paths"].values(): - for path_spec in paths.values(): - function_dict = parse_endpoint_fn(path_spec, parameters_name) - if function_dict: - functions.append(function_dict) - return functions - - -def _parse_endpoint_spec_openai(resolved_spec: Dict[str, Any], parameters_name: str) -> Dict[str, Any]: - """ - Parses an OpenAPI endpoint specification for OpenAI. - """ - if not isinstance(resolved_spec, dict): - logger.warning("Invalid OpenAPI spec format provided. Could not extract function.") - return {} - function_name = resolved_spec.get("operationId") - description = resolved_spec.get("description") or resolved_spec.get("summary", "") - schema: Dict[str, Any] = {"type": "object", "properties": {}} - # requestBody section - req_body_schema = ( - resolved_spec.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {}) - ) - if "properties" in req_body_schema: - for prop_name, prop_schema in req_body_schema["properties"].items(): - schema["properties"][prop_name] = _parse_property_attributes(prop_schema) - if "required" in req_body_schema: - schema.setdefault("required", []).extend(req_body_schema["required"]) - - # parameters section - for param in resolved_spec.get("parameters", []): - if "schema" in param: - schema_dict = _parse_property_attributes(param["schema"]) - # these attributes are not in param[schema] level but on param level - useful_attributes = ["description", "pattern", "enum"] - schema_dict.update({key: param[key] for key in useful_attributes if param.get(key)}) - schema["properties"][param["name"]] = schema_dict - if param.get("required", False): - schema.setdefault("required", []).append(param["name"]) - - if function_name and description and schema["properties"]: - return {"name": function_name, "description": description, parameters_name: schema} - logger.warning("Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec) - return {} - - -def _parse_property_attributes( - property_schema: Dict[str, Any], include_attributes: Optional[List[str]] = None -) -> Dict[str, Any]: - """ - Recursively parses the attributes of a property schema. - """ - include_attributes = include_attributes or ["description", "pattern", "enum"] - schema_type = property_schema.get("type") - parsed_schema = {"type": schema_type} if schema_type else {} - for attr in include_attributes: - if attr in property_schema: - parsed_schema[attr] = property_schema[attr] - if schema_type == "object": - properties = property_schema.get("properties", {}) - parsed_properties = { - prop_name: _parse_property_attributes(prop, include_attributes) for prop_name, prop in properties.items() - } - parsed_schema["properties"] = parsed_properties - if "required" in property_schema: - parsed_schema["required"] = property_schema["required"] - elif schema_type == "array": - items = property_schema.get("items", {}) - parsed_schema["items"] = _parse_property_attributes(items, include_attributes) - return parsed_schema - - -def _parse_endpoint_spec_cohere(operation: Dict[str, Any], ignored_param: str) -> Dict[str, Any]: - """ - Parses an endpoint specification for Cohere. - """ - function_name = operation.get("operationId") - description = operation.get("description") or operation.get("summary", "") - parameter_definitions = _parse_parameters(operation) - if function_name: - return { - "name": function_name, - "description": description, - "parameter_definitions": parameter_definitions, - } - logger.warning("Operation missing operationId, cannot create function definition.") - return {} - - -def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: - """ - Parses the parameters from an operation specification. - """ - parameters = {} - for param in operation.get("parameters", []): - if "schema" in param: - parameters[param["name"]] = _parse_schema( - param["schema"], param.get("required", False), param.get("description", "") - ) - if "requestBody" in operation: - content = operation["requestBody"].get("content", {}).get("application/json", {}) - if "schema" in content: - schema_properties = content["schema"].get("properties", {}) - required_properties = content["schema"].get("required", []) - for name, schema in schema_properties.items(): - parameters[name] = _parse_schema(schema, name in required_properties, schema.get("description", "")) - return parameters - - -def _parse_schema(schema: Dict[str, Any], required: bool, description: str) -> Dict[str, Any]: # noqa: FBT001 - """ - Parses a schema part of an operation specification. - """ - schema_type = _get_type(schema) - if schema_type == "object": - # Recursive call for complex types - properties = schema.get("properties", {}) - nested_parameters = { - name: _parse_schema( - schema=prop_schema, - required=bool(name in schema.get("required", False)), - description=prop_schema.get("description", ""), - ) - for name, prop_schema in properties.items() - } - return { - "type": schema_type, - "description": description, - "properties": nested_parameters, - "required": required, - } - return {"type": schema_type, "description": description, "required": required} - - -def _get_type(schema: Dict[str, Any]) -> str: - type_mapping = { - "integer": "int", - "string": "str", - "boolean": "bool", - "number": "float", - "object": "object", - "array": "list", - } - schema_type = schema.get("type", "object") - if schema_type not in type_mapping: - raise ValueError(f"Unsupported schema type {schema_type}") - return type_mapping[schema_type] diff --git a/haystack_experimental/util/payload_extraction.py b/haystack_experimental/util/payload_extraction.py new file mode 100644 index 00000000..202d11c1 --- /dev/null +++ b/haystack_experimental/util/payload_extraction.py @@ -0,0 +1,67 @@ +import dataclasses +import json +from typing import Any, Callable, Dict, List, Optional, Union + + +def _get_dict_converter(obj: Any, + method_names: Optional[List[str]] = None) -> Union[Callable[[], Dict[str, Any]], None]: + method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1 + for attr in method_names: + if hasattr(obj, attr) and callable(getattr(obj, attr)): + return getattr(obj, attr) + return None + + +def _is_primitive(obj) -> bool: + return isinstance(obj, (int, float, str, bool, type(None))) + + +def _required_fields(arguments_field_name: str) -> List[str]: + return ["name", arguments_field_name] + + +def _search(payload: Any, arguments_field_name: str) -> Dict[str, Any]: + if _is_primitive(payload): + return {} + if dict_converter := _get_dict_converter(payload): + payload = dict_converter() + elif dataclasses.is_dataclass(payload): + payload = dataclasses.asdict(payload) + if isinstance(payload, dict): + if all(field in payload for field in _required_fields(arguments_field_name)): + # this is the payload we are looking for + return payload + for value in payload.values(): + result = _search(value, arguments_field_name) + if result: + return result + elif isinstance(payload, list): + for item in payload: + result = _search(item, arguments_field_name) + if result: + return result + return {} + + +def create_function_payload_extractor(arguments_field_name: str) -> Callable[[Any], Dict[str, Any]]: + """ + Extracts invocation payload from a given LLM completion containing function invocation. + """ + def _extract_function_invocation(payload: Any) -> Dict[str, Any]: + """ + Extract the function invocation details from the payload. + """ + fields_and_values = _search(payload, arguments_field_name) + if fields_and_values: + arguments = fields_and_values.get(arguments_field_name) + if not isinstance(arguments, (str, dict)): + raise ValueError( + f"Invalid {arguments_field_name} type {type(arguments)} for function call, expected str/dict" + ) + return { + "name": fields_and_values.get("name"), + "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, + } + return {} + + return _extract_function_invocation diff --git a/haystack_experimental/util/schema_conversion.py b/haystack_experimental/util/schema_conversion.py new file mode 100644 index 00000000..bca96c80 --- /dev/null +++ b/haystack_experimental/util/schema_conversion.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import logging +from typing import Any, Callable, Dict, List, Optional + +import jsonref + +VALID_HTTP_METHODS = ["get", "put", "post", "delete", "options", "head", "patch", "trace"] + +MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 + +logger = logging.getLogger(__name__) + + +def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 + """ + Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + fn_definitions = _openapi_to_functions(resolved_schema, "parameters", _parse_endpoint_spec_openai) + return [{"type": "function", "function": fn} for fn in fn_definitions] + + +def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 + """ + Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + return _openapi_to_functions(resolved_schema, "input_schema", _parse_endpoint_spec_openai) + + +def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 + """ + Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + return _openapi_to_functions(resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere) + + +def _openapi_to_functions( + service_openapi_spec: Dict[str, Any], + parameters_name: str, + parse_endpoint_fn: Callable[[Dict[str, Any], str], Dict[str, Any]], +) -> List[Dict[str, Any]]: + """ + Extracts functions from the OpenAPI specification, converts them into a function schema. + """ + + # Doesn't enforce rigid spec validation because that would require a lot of dependencies + # We check the version and require minimal fields to be present, so we can extract functions + spec_version = service_openapi_spec.get("openapi") + if not spec_version: + raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}") + service_openapi_spec_version = int(spec_version.split(".")[0]) + # Compare the versions + if service_openapi_spec_version < MIN_REQUIRED_OPENAPI_SPEC_VERSION: + raise ValueError( + f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be " + f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}." + ) + functions: List[Dict[str, Any]] = [] + for paths in service_openapi_spec["paths"].values(): + for path_spec in paths.values(): + function_dict = parse_endpoint_fn(path_spec, parameters_name) + if function_dict: + functions.append(function_dict) + return functions + + +def _parse_endpoint_spec_openai(resolved_spec: Dict[str, Any], parameters_name: str) -> Dict[str, Any]: + """ + Parses an OpenAPI endpoint specification for OpenAI. + """ + if not isinstance(resolved_spec, dict): + logger.warning("Invalid OpenAPI spec format provided. Could not extract function.") + return {} + function_name = resolved_spec.get("operationId") + description = resolved_spec.get("description") or resolved_spec.get("summary", "") + schema: Dict[str, Any] = {"type": "object", "properties": {}} + # requestBody section + req_body_schema = ( + resolved_spec.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {}) + ) + if "properties" in req_body_schema: + for prop_name, prop_schema in req_body_schema["properties"].items(): + schema["properties"][prop_name] = _parse_property_attributes(prop_schema) + if "required" in req_body_schema: + schema.setdefault("required", []).extend(req_body_schema["required"]) + + # parameters section + for param in resolved_spec.get("parameters", []): + if "schema" in param: + schema_dict = _parse_property_attributes(param["schema"]) + # these attributes are not in param[schema] level but on param level + useful_attributes = ["description", "pattern", "enum"] + schema_dict.update({key: param[key] for key in useful_attributes if param.get(key)}) + schema["properties"][param["name"]] = schema_dict + if param.get("required", False): + schema.setdefault("required", []).append(param["name"]) + + if function_name and description and schema["properties"]: + return {"name": function_name, "description": description, parameters_name: schema} + logger.warning("Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec) + return {} + + +def _parse_property_attributes( + property_schema: Dict[str, Any], include_attributes: Optional[List[str]] = None +) -> Dict[str, Any]: + """ + Recursively parses the attributes of a property schema. + """ + include_attributes = include_attributes or ["description", "pattern", "enum"] + schema_type = property_schema.get("type") + parsed_schema = {"type": schema_type} if schema_type else {} + for attr in include_attributes: + if attr in property_schema: + parsed_schema[attr] = property_schema[attr] + if schema_type == "object": + properties = property_schema.get("properties", {}) + parsed_properties = { + prop_name: _parse_property_attributes(prop, include_attributes) for prop_name, prop in properties.items() + } + parsed_schema["properties"] = parsed_properties + if "required" in property_schema: + parsed_schema["required"] = property_schema["required"] + elif schema_type == "array": + items = property_schema.get("items", {}) + parsed_schema["items"] = _parse_property_attributes(items, include_attributes) + return parsed_schema + + +def _parse_endpoint_spec_cohere(operation: Dict[str, Any], ignored_param: str) -> Dict[str, Any]: + """ + Parses an endpoint specification for Cohere. + """ + function_name = operation.get("operationId") + description = operation.get("description") or operation.get("summary", "") + parameter_definitions = _parse_parameters(operation) + if function_name: + return { + "name": function_name, + "description": description, + "parameter_definitions": parameter_definitions, + } + logger.warning("Operation missing operationId, cannot create function definition.") + return {} + + +def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: + """ + Parses the parameters from an operation specification. + """ + parameters = {} + for param in operation.get("parameters", []): + if "schema" in param: + parameters[param["name"]] = _parse_schema( + param["schema"], param.get("required", False), param.get("description", "") + ) + if "requestBody" in operation: + content = operation["requestBody"].get("content", {}).get("application/json", {}) + if "schema" in content: + schema_properties = content["schema"].get("properties", {}) + required_properties = content["schema"].get("required", []) + for name, schema in schema_properties.items(): + parameters[name] = _parse_schema(schema, name in required_properties, schema.get("description", "")) + return parameters + + +def _parse_schema(schema: Dict[str, Any], required: bool, description: str) -> Dict[str, Any]: # noqa: FBT001 + """ + Parses a schema part of an operation specification. + """ + schema_type = _get_type(schema) + if schema_type == "object": + # Recursive call for complex types + properties = schema.get("properties", {}) + nested_parameters = { + name: _parse_schema( + schema=prop_schema, + required=bool(name in schema.get("required", False)), + description=prop_schema.get("description", ""), + ) + for name, prop_schema in properties.items() + } + return { + "type": schema_type, + "description": description, + "properties": nested_parameters, + "required": required, + } + return {"type": schema_type, "description": description, "required": required} + + +def _get_type(schema: Dict[str, Any]) -> str: + type_mapping = { + "integer": "int", + "string": "str", + "boolean": "bool", + "number": "float", + "object": "object", + "array": "list", + } + schema_type = schema.get("type", "object") + if schema_type not in type_mapping: + raise ValueError(f"Unsupported schema type {schema_type}") + return type_mapping[schema_type] diff --git a/test/util/test_openapi_spec.py b/test/util/test_openapi_spec.py index ca784169..5567b2e7 100644 --- a/test/util/test_openapi_spec.py +++ b/test/util/test_openapi_spec.py @@ -91,25 +91,3 @@ def test_raises_value_error_invalid_schema(self): spec_dict = {"info": {"title": "Test API", "version": "1.0.0"}, "paths": {"/users": {}}} with pytest.raises(ValueError): OpenAPISpecification(spec_dict) - - # Should return the raw OpenAPI specification dictionary with resolved references. - def test_return_raw_spec_with_resolved_references(self, test_files_path): - spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json") - raw_spec = spec.to_dict(resolve_references=True) - - assert "$ref" not in str(raw_spec) - assert "#/" not in str(raw_spec) - - # verify that we can serialize the raw spec to a string - schema_ser = json.dumps(raw_spec, indent=2) - - # and that the serialized string does not contain any $ref or #/ references - assert "$ref" not in schema_ser - assert "#/" not in schema_ser - - # and that we can deserialize the serialized string back to a dictionary - schema = json.loads(schema_ser) - assert "$ref" not in schema - assert "#/" not in schema - - assert schema == raw_spec From adb96c50febad56cbdfe13accab2e3ae9662e705 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 15:06:51 +0200 Subject: [PATCH 04/40] Refactor step 3 --- haystack_experimental/util/openapi.py | 100 +++++++----------- .../util/payload_extraction.py | 48 ++++----- test/util/conftest.py | 6 +- test/util/test_openapi_client.py | 6 +- test/util/test_openapi_client_auth.py | 10 +- ...est_openapi_client_complex_request_body.py | 2 +- ...enapi_client_complex_request_body_mixed.py | 2 +- test/util/test_openapi_client_edge_cases.py | 2 +- .../test_openapi_client_error_handling.py | 2 +- 9 files changed, 75 insertions(+), 103 deletions(-) diff --git a/haystack_experimental/util/openapi.py b/haystack_experimental/util/openapi.py index 9d7dc420..fef95de7 100644 --- a/haystack_experimental/util/openapi.py +++ b/haystack_experimental/util/openapi.py @@ -8,13 +8,11 @@ from base64 import b64encode from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Union from urllib.parse import urlparse import requests import yaml -from requests.adapters import HTTPAdapter -from urllib3 import Retry from haystack_experimental.util.payload_extraction import create_function_payload_extractor from haystack_experimental.util.schema_conversion import anthropic_converter, cohere_converter, openai_converter @@ -105,63 +103,6 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): raise ValueError("HTTPAuthentication strategy received a non-HTTP security scheme.") -@dataclass -class HttpClientConfig: - """Configuration for the HTTP client.""" - - timeout: int = 10 - max_retries: int = 3 - backoff_factor: float = 0.3 - retry_on_status: set = field(default_factory=lambda: {500, 502, 503, 504}) - default_headers: Dict[str, str] = field(default_factory=dict) - - -class HttpClient: - """HTTP client for sending requests.""" - - def __init__(self, config: Optional[HttpClientConfig] = None): - self.config = config or HttpClientConfig() - self.session = requests.Session() - retries = Retry( - total=self.config.max_retries, - backoff_factor=self.config.backoff_factor, - status_forcelist=self.config.retry_on_status, - ) - adapter = HTTPAdapter(max_retries=retries) - self.session.mount("http://", adapter) - self.session.mount("https://", adapter) - self.session.headers.update(self.config.default_headers) - - def send_request(self, request: Dict[str, Any]) -> Any: - """ - Send an HTTP request using the provided request dictionary. - - :param request: A dictionary containing the request details. - """ - url = request["url"] - headers = {**self.config.default_headers, **request.get("headers", {})} - try: - response = self.session.request( - request["method"], - url, - headers=headers, - params=request.get("params", {}), - json=request.get("json"), - auth=request.get("auth"), - ) - response.raise_for_status() - return response.json() - except requests.exceptions.HTTPError as e: - logger.warning("HTTP error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except requests.exceptions.RequestException as e: - logger.warning("Request error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except Exception as e: - logger.warning("An error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"An error occurred: {e}") from e - - class HttpClientError(Exception): """Exception raised for errors in the HTTP client.""" @@ -305,7 +246,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments self, openapi_spec: Union[str, Path, Dict[str, Any]], credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None, - http_client: Optional[HttpClient] = None, + request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, llm_provider: Optional[str] = None, ): # noqa: PLR0913 if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): @@ -321,7 +262,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments raise ValueError("Invalid OpenAPI specification format. Expected file path or dictionary.") self.credentials = credentials - self.http_client = http_client or HttpClient(HttpClientConfig()) + self.request_sender = request_sender self.llm_provider = llm_provider or "openai" def get_auth_config(self) -> AuthenticationStrategy: @@ -395,7 +336,7 @@ class OpenAPIServiceClient: def __init__(self, client_config: ClientConfiguration): self.client_config = client_config - self.http_client = client_config.http_client + self.request_sender = client_config.request_sender or self._request_sender() def invoke(self, function_payload: Any) -> Any: """ @@ -416,7 +357,38 @@ def invoke(self, function_payload: Any) -> Any: operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) request = self._build_request(operation, **fn_invocation_payload.get("arguments")) self._apply_authentication(self.client_config.get_auth_config(), operation, request) - return self.http_client.send_request(request) + return self.request_sender(request) + + def _request_sender(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: + """ + Returns a callable that sends the request using the HTTP client. + """ + def send_request(request: Dict[str, Any]) -> Dict[str, Any]: + url = request["url"] + headers = {**request.get("headers", {})} + try: + response = requests.request( + request["method"], + url, + headers=headers, + params=request.get("params", {}), + json=request.get("json"), + auth=request.get("auth"), + timeout=10, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + logger.warning("HTTP error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except requests.exceptions.RequestException as e: + logger.warning("Request error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except Exception as e: + logger.warning("An error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"An error occurred: {e}") from e + + return send_request def _build_request(self, operation: Operation, **kwargs) -> Any: request = { diff --git a/haystack_experimental/util/payload_extraction.py b/haystack_experimental/util/payload_extraction.py index 202d11c1..b98969c1 100644 --- a/haystack_experimental/util/payload_extraction.py +++ b/haystack_experimental/util/payload_extraction.py @@ -3,6 +3,30 @@ from typing import Any, Callable, Dict, List, Optional, Union +def create_function_payload_extractor(arguments_field_name: str) -> Callable[[Any], Dict[str, Any]]: + """ + Extracts invocation payload from a given LLM completion containing function invocation. + """ + def _extract_function_invocation(payload: Any) -> Dict[str, Any]: + """ + Extract the function invocation details from the payload. + """ + fields_and_values = _search(payload, arguments_field_name) + if fields_and_values: + arguments = fields_and_values.get(arguments_field_name) + if not isinstance(arguments, (str, dict)): + raise ValueError( + f"Invalid {arguments_field_name} type {type(arguments)} for function call, expected str/dict" + ) + return { + "name": fields_and_values.get("name"), + "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, + } + return {} + + return _extract_function_invocation + + def _get_dict_converter(obj: Any, method_names: Optional[List[str]] = None) -> Union[Callable[[], Dict[str, Any]], None]: method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1 @@ -41,27 +65,3 @@ def _search(payload: Any, arguments_field_name: str) -> Dict[str, Any]: if result: return result return {} - - -def create_function_payload_extractor(arguments_field_name: str) -> Callable[[Any], Dict[str, Any]]: - """ - Extracts invocation payload from a given LLM completion containing function invocation. - """ - def _extract_function_invocation(payload: Any) -> Dict[str, Any]: - """ - Extract the function invocation details from the payload. - """ - fields_and_values = _search(payload, arguments_field_name) - if fields_and_values: - arguments = fields_and_values.get(arguments_field_name) - if not isinstance(arguments, (str, dict)): - raise ValueError( - f"Invalid {arguments_field_name} type {type(arguments)} for function call, expected str/dict" - ) - return { - "name": fields_and_values.get("name"), - "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, - } - return {} - - return _extract_function_invocation diff --git a/test/util/conftest.py b/test/util/conftest.py index 39caa850..05bd940a 100644 --- a/test/util/conftest.py +++ b/test/util/conftest.py @@ -10,7 +10,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from haystack_experimental.util.openapi import HttpClient, HttpClientError +from haystack_experimental.util.openapi import HttpClientError @pytest.fixture() @@ -18,7 +18,7 @@ def test_files_path(): return Path(__file__).parent.parent / "test_files" -class FastAPITestClient(HttpClient): +class FastAPITestClient: def __init__(self, app: FastAPI): self.app = app @@ -31,7 +31,7 @@ def strip_host(self, url: str) -> str: new_path += "?" + parsed_url.query return new_path - def send_request(self, request: dict) -> dict: + def __call__(self, request: dict) -> dict: # OAS spec will list a server URL, but FastAPI doesn't need it for local testing, in fact it will fail # if the URL has a host. So we strip it here. url = self.strip_host(request["url"]) diff --git a/test/util/test_openapi_client.py b/test/util/test_openapi_client.py index 3769bd0c..66a74056 100644 --- a/test/util/test_openapi_client.py +++ b/test/util/test_openapi_client.py @@ -73,7 +73,7 @@ class TestOpenAPI: def test_greet_mix_params_body(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_mix_params_body_app())) + request_sender=FastAPITestClient(create_greet_mix_params_body_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -88,7 +88,7 @@ def test_greet_mix_params_body(self, test_files_path): def test_greet_params_only(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_params_only_app())) + request_sender=FastAPITestClient(create_greet_params_only_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -103,7 +103,7 @@ def test_greet_params_only(self, test_files_path): def test_greet_request_body_only(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_request_body_only_app())) + request_sender=FastAPITestClient(create_greet_request_body_only_app())) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", diff --git a/test/util/test_openapi_client_auth.py b/test/util/test_openapi_client_auth.py index 6797c2aa..014edde5 100644 --- a/test/util/test_openapi_client_auth.py +++ b/test/util/test_openapi_client_auth.py @@ -139,7 +139,7 @@ class TestOpenAPIAuth: def test_greet_api_key_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_api_key_auth_app()), + request_sender=FastAPITestClient(create_greet_api_key_auth_app()), credentials=ApiKeyAuthentication(API_KEY)) client = OpenAPIServiceClient(config) payload = { @@ -155,7 +155,7 @@ def test_greet_api_key_auth(self, test_files_path): def test_greet_basic_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_basic_auth_app()), + request_sender=FastAPITestClient(create_greet_basic_auth_app()), credentials=HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD)) client = OpenAPIServiceClient(config) payload = { @@ -171,7 +171,7 @@ def test_greet_basic_auth(self, test_files_path): def test_greet_api_key_query_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_api_key_query_app()), + request_sender=FastAPITestClient(create_greet_api_key_query_app()), credentials=ApiKeyAuthentication(API_KEY_QUERY)) client = OpenAPIServiceClient(config) payload = { @@ -188,7 +188,7 @@ def test_greet_api_key_query_auth(self, test_files_path): def test_greet_api_key_cookie_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_api_key_cookie_app()), + request_sender=FastAPITestClient(create_greet_api_key_cookie_app()), credentials=ApiKeyAuthentication(API_KEY_COOKIE)) client = OpenAPIServiceClient(config) @@ -205,7 +205,7 @@ def test_greet_api_key_cookie_auth(self, test_files_path): def test_greet_bearer_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - http_client=FastAPITestClient(create_greet_bearer_auth_app()), + request_sender=FastAPITestClient(create_greet_bearer_auth_app()), credentials=HTTPAuthentication(token=BEARER_TOKEN)) client = OpenAPIServiceClient(config) payload = { diff --git a/test/util/test_openapi_client_complex_request_body.py b/test/util/test_openapi_client_complex_request_body.py index 79549177..ccfb63df 100644 --- a/test/util/test_openapi_client_complex_request_body.py +++ b/test/util/test_openapi_client_complex_request_body.py @@ -59,7 +59,7 @@ def test_create_order(self, spec_file_path, test_files_path): path_element = "yaml" if spec_file_path.endswith(".yml") else "json" config = ClientConfiguration(openapi_spec=test_files_path / path_element / spec_file_path, - http_client=FastAPITestClient(create_order_app())) + request_sender=FastAPITestClient(create_order_app())) client = OpenAPIServiceClient(config) order_json = { diff --git a/test/util/test_openapi_client_complex_request_body_mixed.py b/test/util/test_openapi_client_complex_request_body_mixed.py index b769bbbf..4dc532fc 100644 --- a/test/util/test_openapi_client_complex_request_body_mixed.py +++ b/test/util/test_openapi_client_complex_request_body_mixed.py @@ -57,7 +57,7 @@ class TestPaymentProcess: def test_process_payment(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "json" / "complex_types_openapi_service.json", - http_client=FastAPITestClient(create_payment_app())) + request_sender=FastAPITestClient(create_payment_app())) client = OpenAPIServiceClient(config) payment_json = { diff --git a/test/util/test_openapi_client_edge_cases.py b/test/util/test_openapi_client_edge_cases.py index e1c36920..888580d9 100644 --- a/test/util/test_openapi_client_edge_cases.py +++ b/test/util/test_openapi_client_edge_cases.py @@ -13,7 +13,7 @@ class TestEdgeCases: def test_missing_operation_id(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_edge_cases.yml", - http_client=FastAPITestClient(None)) + request_sender=FastAPITestClient(None)) client = OpenAPIServiceClient(config) payload = { diff --git a/test/util/test_openapi_client_error_handling.py b/test/util/test_openapi_client_error_handling.py index 53e9478a..d826e33b 100644 --- a/test/util/test_openapi_client_error_handling.py +++ b/test/util/test_openapi_client_error_handling.py @@ -27,7 +27,7 @@ class TestErrorHandling: @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500]) def test_http_error_handling(self, test_files_path, status_code): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_error_handling.yml", - http_client=FastAPITestClient(create_error_handling_app())) + request_sender=FastAPITestClient(create_error_handling_app())) client = OpenAPIServiceClient(config) json_error = {"status_code": status_code} payload = { From 221df1bd7a2735f08120a66eb62c0be318e7edef Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 15:25:31 +0200 Subject: [PATCH 05/40] Refactoring step 4 --- haystack_experimental/util/openapi.py | 55 +++++++++------------------ 1 file changed, 19 insertions(+), 36 deletions(-) diff --git a/haystack_experimental/util/openapi.py b/haystack_experimental/util/openapi.py index fef95de7..94b7614b 100644 --- a/haystack_experimental/util/openapi.py +++ b/haystack_experimental/util/openapi.py @@ -18,9 +18,7 @@ from haystack_experimental.util.schema_conversion import anthropic_converter, cohere_converter, openai_converter VALID_HTTP_METHODS = ["get", "put", "post", "delete", "options", "head", "patch", "trace"] - MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 - logger = logging.getLogger(__name__) @@ -28,7 +26,6 @@ class AuthenticationStrategy: """ Represents an authentication strategy that can be applied to an HTTP request. """ - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): """ Apply the authentication strategy to the given request. @@ -41,7 +38,6 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): @dataclass class ApiKeyAuthentication(AuthenticationStrategy): """API key authentication strategy.""" - api_key: Optional[str] = None def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): @@ -67,7 +63,6 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): @dataclass class HTTPAuthentication(AuthenticationStrategy): """HTTP authentication strategy.""" - username: Optional[str] = None password: Optional[str] = None token: Optional[str] = None @@ -235,8 +230,7 @@ def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: """ Get the security schemes from the OpenAPI specification. """ - components = self.spec_dict.get("components", {}) - return components.get("securitySchemes", {}) + return self.spec_dict.get("components", {}).get("securitySchemes", {}) class ClientConfiguration: @@ -391,27 +385,7 @@ def send_request(request: Dict[str, Any]) -> Dict[str, Any]: return send_request def _build_request(self, operation: Operation, **kwargs) -> Any: - request = { - "url": self._build_url(operation, **kwargs), - "method": operation.method.lower(), - "headers": self._build_headers(operation, **kwargs), - "params": self._build_query_params(operation, **kwargs), - "json": self._build_request_body(operation, **kwargs), - } - return request - - def _build_headers(self, operation: Operation, **kwargs) -> Dict[str, str]: - headers = {} - for parameter in operation.get_parameters("header"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - headers[parameter["name"]] = str(param_value) - elif parameter.get("required", False): - raise ValueError(f"Missing required header parameter: {parameter['name']}") - return headers - - def _build_url(self, operation: Operation, **kwargs) -> str: - server_url = operation.get_server() + # url path = operation.path for parameter in operation.get_parameters("path"): param_value = kwargs.get(parameter["name"], None) @@ -419,9 +393,18 @@ def _build_url(self, operation: Operation, **kwargs) -> str: path = path.replace(f"{{{parameter['name']}}}", str(param_value)) elif parameter.get("required", False): raise ValueError(f"Missing required path parameter: {parameter['name']}") - return server_url + path - - def _build_query_params(self, operation: Operation, **kwargs) -> Dict[str, Any]: + url = operation.get_server() + path + # method + method = operation.method.lower() + # headers + headers = {} + for parameter in operation.get_parameters("header"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + headers[parameter["name"]] = str(param_value) + elif parameter.get("required", False): + raise ValueError(f"Missing required header parameter: {parameter['name']}") + # query params query_params = {} for parameter in operation.get_parameters("query"): param_value = kwargs.get(parameter["name"], None) @@ -429,16 +412,16 @@ def _build_query_params(self, operation: Operation, **kwargs) -> Dict[str, Any]: query_params[parameter["name"]] = param_value elif parameter.get("required", False): raise ValueError(f"Missing required query parameter: {parameter['name']}") - return query_params - def _build_request_body(self, operation: Operation, **kwargs) -> Any: + json_payload = None request_body = operation.request_body if request_body: content = request_body.get("content", {}) if "application/json" in content: - return {**kwargs} - raise NotImplementedError("Request body content type not supported") - return None + json_payload = {**kwargs} + else: + raise NotImplementedError("Request body content type not supported") + return {"url": url, "method": method, "headers": headers, "params": query_params, "json": json_payload} def _apply_authentication(self, auth: AuthenticationStrategy, operation: Operation, request: Dict[str, Any]): auth_config = auth or AuthenticationStrategy() From e2ecd8d77c58883f516b90c639250c986947478a Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 16:50:10 +0200 Subject: [PATCH 06/40] Add OpenAPITool initial impl --- .../components/tools/openapi/__init__.py | 3 + .../tools/openapi/generator_factory.py | 85 +++++++++++++ .../components/tools/openapi/openapi.py | 118 ++++++++++++++++++ 3 files changed, 206 insertions(+) create mode 100644 haystack_experimental/components/tools/openapi/generator_factory.py diff --git a/haystack_experimental/components/tools/openapi/__init__.py b/haystack_experimental/components/tools/openapi/__init__.py index c1764a6e..8668d6f9 100644 --- a/haystack_experimental/components/tools/openapi/__init__.py +++ b/haystack_experimental/components/tools/openapi/__init__.py @@ -1,3 +1,6 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from haystack_experimental.components.tools.openapi.openapi import OpenAPITool + +__all__ = ["OpenAPITool"] diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py new file mode 100644 index 00000000..0084a3e3 --- /dev/null +++ b/haystack_experimental/components/tools/openapi/generator_factory.py @@ -0,0 +1,85 @@ +import importlib +import re +from dataclasses import dataclass +from enum import Enum +from typing import Any, Optional, Tuple + + +class LLMProvider(Enum): + OPENAI = "openai" + ANTHROPIC = "anthropic" + COHERE = "cohere" + + +PROVIDER_DETAILS = { + LLMProvider.OPENAI: { + "class_path": "haystack.components.generators.chat.openai.OpenAIChatGenerator", + "patterns": [re.compile(r"^gpt.*")], + }, + LLMProvider.ANTHROPIC: { + "class_path": "haystack_integrations.components.generators.anthropic.AnthropicChatGenerator", + "patterns": [re.compile(r"^claude.*")], + }, + LLMProvider.COHERE: { + "class_path": "haystack_integrations.components.generators.cohere.CohereChatGenerator", + "patterns": [re.compile(r"^command-r.*")], + }, +} + + +def load_class(full_class_path: str): + """ + Load a class from a string representation of its path e.g. "module.submodule.class_name" + """ + module_path, _, class_name = full_class_path.rpartition(".") + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +@dataclass +class LLMIdentifier: + provider: LLMProvider + model_name: str + + def __post_init__(self): + if not isinstance(self.provider, LLMProvider): + raise ValueError(f"Invalid provider: {self.provider}") + + if not isinstance(self.model_name, str): + raise ValueError(f"Model name must be a string: {self.model_name}") + + details = PROVIDER_DETAILS.get(self.provider) + if not details or not any( + pattern.match(self.model_name) for pattern in details["patterns"] + ): + raise ValueError( + f"Invalid combination of provider {self.provider} and model name {self.model_name}" + ) + + +def create_generator( + model_name: str, provider: Optional[str] = None, **model_kwargs +) -> Tuple[LLMIdentifier, Any]: + """ + Create ChatGenerator instance based on the model name and provider. + """ + if provider: + try: + provider_enum = LLMProvider[provider.lower()] + except KeyError: + raise ValueError(f"Invalid provider: {provider}") + else: + provider_enum = None + for prov, details in PROVIDER_DETAILS.items(): + if any(pattern.match(model_name) for pattern in details["patterns"]): + provider_enum = prov + break + + if provider_enum is None: + raise ValueError(f"Could not infer provider for model name: {model_name}") + + llm_identifier = LLMIdentifier(provider=provider_enum, model_name=model_name) + class_path = PROVIDER_DETAILS[llm_identifier.provider]["class_path"] + return llm_identifier, load_class(class_path)( + model=llm_identifier.model_name, **model_kwargs + ) diff --git a/haystack_experimental/components/tools/openapi/openapi.py b/haystack_experimental/components/tools/openapi/openapi.py index c1764a6e..e5cac83a 100644 --- a/haystack_experimental/components/tools/openapi/openapi.py +++ b/haystack_experimental/components/tools/openapi/openapi.py @@ -1,3 +1,121 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_experimental.components.tools.openapi.generator_factory import ( + create_generator, +) +from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient + +logger = logging.getLogger(__name__) + + +@component +class OpenAPITool: + """ + The OpenAPITool calls an OpenAPI service using payloads generated by the chat generator from human instructions. + + Here is an example of how to use the OpenAPITool component to scrape a URL using the FireCrawl API: + + ```python + from haystack.components.tools import OpenAPITool + from haystack.components.generators.chat.openai import OpenAIChatGenerator + from haystack.dataclasses import ChatMessage + + tool = OpenAPITool(llm_provider="openai", model="gpt-3.5-turbo", + tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json", + tool_credentials="") + + results = tool.run(messages=[ChatMessage.from_user("Scrape URL: https://news.ycombinator.com/")]) + print(results) + ``` + + Similarly, you can use the OpenAPITool component to use any OpenAPI service/tool by providing the OpenAPI + specification and credentials. + """ + + def __init__( + self, + model: str, + tool_spec: Optional[Union[str, Path]] = None, + tool_credentials: Optional[Union[str, Dict[str, Any]]] = None, + ): + self.llm_id, self.chat_generator = create_generator(model) + self.config_openapi = ( + ClientConfiguration( + openapi_spec=tool_spec, + credentials=tool_credentials, + llm_provider=self.llm_id.provider.value, + ) + if tool_spec + else None + ) + + @component.output_types(service_response=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + fc_generator_kwargs: Optional[Dict[str, Any]] = None, + tool_spec: Optional[Union[str, Path, Dict[str, Any]]] = None, + tool_credentials: Optional[Union[dict, str]] = None, + ) -> Dict[str, List[ChatMessage]]: + """ + Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator. + + :param messages: List of ChatMessages to generate function calling payload (e.g. human instructions). + :param fc_generator_kwargs: Additional arguments for the function calling payload generation process. + :param tool_spec: OpenAPI specification for the tool/service. + :param tool_credentials: Credentials for the tool/service. + :returns: a dictionary containing the service response with the following key: + - `service_response`: List of ChatMessages containing the service response. + """ + last_message = messages[-1] + if not last_message.is_from(ChatRole.USER): + raise ValueError(f"{last_message} not from the user") + if not last_message.content: + raise ValueError("Function calling instruction message content is empty.") + + # build a new ClientConfiguration if a runtime tool_spec is provided + config_openapi = ( + ClientConfiguration( + openapi_spec=tool_spec, + credentials=tool_credentials, + llm_provider=self.llm_id.provider.value, + ) + if tool_spec + else self.config_openapi + ) + + if not config_openapi: + raise ValueError( + "OpenAPI specification not provided. Please provide an OpenAPI specification either at initialization " + "or during runtime." + ) + # merge fc_generator_kwargs, tools definitions comes from the OpenAPI spec, other kwargs are passed by the user + fc_generator_kwargs = { + "tools": config_openapi.get_tools_definitions(), + **(fc_generator_kwargs or {}), + } + + # generate function calling payload with the chat generator + logger.debug( + f"Invoking chat generator with {last_message.content} to generate function calling payload." + ) + fc_payload = self.chat_generator.run(messages, fc_generator_kwargs) + + openapi_service = OpenAPIServiceClient(config_openapi) + try: + invocation_payload = json.loads(fc_payload["replies"][0].content) + logger.debug(f"Invoking tool with {invocation_payload}") + service_response = openapi_service.invoke(invocation_payload) + except Exception as e: + logger.error(f"Error invoking OpenAPI endpoint. Error: {e}") + service_response = {"error": str(e)} + response_messages = [ChatMessage.from_user(json.dumps(service_response))] + return {"service_response": response_messages} From 992ef6ef437ec5e01d31c260a10488f250072162 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 16:57:01 +0200 Subject: [PATCH 07/40] Add headers --- haystack_experimental/components/tools/openapi/__init__.py | 1 + .../components/tools/openapi/generator_factory.py | 4 ++++ haystack_experimental/components/tools/openapi/openapi.py | 1 + haystack_experimental/util/payload_extraction.py | 4 ++++ haystack_experimental/util/schema_conversion.py | 4 ++++ 5 files changed, 14 insertions(+) diff --git a/haystack_experimental/components/tools/openapi/__init__.py b/haystack_experimental/components/tools/openapi/__init__.py index 8668d6f9..7b68a7a2 100644 --- a/haystack_experimental/components/tools/openapi/__init__.py +++ b/haystack_experimental/components/tools/openapi/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from haystack_experimental.components.tools.openapi.openapi import OpenAPITool __all__ = ["OpenAPITool"] diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py index 0084a3e3..41da1923 100644 --- a/haystack_experimental/components/tools/openapi/generator_factory.py +++ b/haystack_experimental/components/tools/openapi/generator_factory.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import importlib import re from dataclasses import dataclass diff --git a/haystack_experimental/components/tools/openapi/openapi.py b/haystack_experimental/components/tools/openapi/openapi.py index e5cac83a..ffaf8f3f 100644 --- a/haystack_experimental/components/tools/openapi/openapi.py +++ b/haystack_experimental/components/tools/openapi/openapi.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import json from pathlib import Path from typing import Any, Dict, List, Optional, Union diff --git a/haystack_experimental/util/payload_extraction.py b/haystack_experimental/util/payload_extraction.py index b98969c1..bffd7815 100644 --- a/haystack_experimental/util/payload_extraction.py +++ b/haystack_experimental/util/payload_extraction.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import dataclasses import json from typing import Any, Callable, Dict, List, Optional, Union diff --git a/haystack_experimental/util/schema_conversion.py b/haystack_experimental/util/schema_conversion.py index bca96c80..6d17e344 100644 --- a/haystack_experimental/util/schema_conversion.py +++ b/haystack_experimental/util/schema_conversion.py @@ -4,6 +4,10 @@ import logging from typing import Any, Callable, Dict, List, Optional +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import jsonref From 910a7b252ee1b2b563cd75cd0e71ef5c8b857e8c Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 17:27:07 +0200 Subject: [PATCH 08/40] Refactoring step 5 - move things around --- .../components/tools/__init__.py | 2 +- .../components/tools/openapi/__init__.py | 2 +- .../tools/openapi/generator_factory.py | 9 +- .../components/tools/openapi/openapi.py | 608 +++++++++++++++--- .../components/tools/openapi/openapi_tool.py | 123 ++++ .../tools/openapi}/payload_extraction.py | 19 +- .../tools/openapi}/schema_conversion.py | 78 ++- haystack_experimental/util/openapi.py | 440 ------------- test/components/tools/openapi/__init__.py | 3 + .../tools/openapi}/conftest.py | 4 +- .../tools/openapi}/test_openapi_client.py | 4 +- .../openapi}/test_openapi_client_auth.py | 4 +- ...est_openapi_client_complex_request_body.py | 4 +- ...enapi_client_complex_request_body_mixed.py | 4 +- .../test_openapi_client_edge_cases.py | 4 +- .../test_openapi_client_error_handling.py | 4 +- .../openapi}/test_openapi_client_live.py | 2 +- .../test_openapi_client_live_anthropic.py | 2 +- .../test_openapi_client_live_cohere.py | 2 +- .../test_openapi_client_live_openai.py | 2 +- .../test_openapi_cohere_conversion.py | 2 +- .../test_openapi_openai_conversion.py | 2 +- .../tools/openapi}/test_openapi_spec.py | 4 +- 23 files changed, 742 insertions(+), 586 deletions(-) create mode 100644 haystack_experimental/components/tools/openapi/openapi_tool.py rename haystack_experimental/{util => components/tools/openapi}/payload_extraction.py (83%) rename haystack_experimental/{util => components/tools/openapi}/schema_conversion.py (79%) delete mode 100644 haystack_experimental/util/openapi.py create mode 100644 test/components/tools/openapi/__init__.py rename test/{util => components/tools/openapi}/conftest.py (90%) rename test/{util => components/tools/openapi}/test_openapi_client.py (95%) rename test/{util => components/tools/openapi}/test_openapi_client_auth.py (97%) rename test/{util => components/tools/openapi}/test_openapi_client_complex_request_body.py (92%) rename test/{util => components/tools/openapi}/test_openapi_client_complex_request_body_mixed.py (92%) rename test/{util => components/tools/openapi}/test_openapi_client_edge_cases.py (82%) rename test/{util => components/tools/openapi}/test_openapi_client_error_handling.py (87%) rename test/{util => components/tools/openapi}/test_openapi_client_live.py (95%) rename test/{util => components/tools/openapi}/test_openapi_client_live_anthropic.py (95%) rename test/{util => components/tools/openapi}/test_openapi_client_live_cohere.py (96%) rename test/{util => components/tools/openapi}/test_openapi_client_live_openai.py (97%) rename test/{util => components/tools/openapi}/test_openapi_cohere_conversion.py (97%) rename test/{util => components/tools/openapi}/test_openapi_openai_conversion.py (97%) rename test/{util => components/tools/openapi}/test_openapi_spec.py (97%) diff --git a/haystack_experimental/components/tools/__init__.py b/haystack_experimental/components/tools/__init__.py index 65434145..be1a6773 100644 --- a/haystack_experimental/components/tools/__init__.py +++ b/haystack_experimental/components/tools/__init__.py @@ -4,4 +4,4 @@ from .openai.function_caller import OpenAIFunctionCaller -_all_ = ["OpenAIFunctionCaller"] +_all_ = ["OpenAIFunctionCaller"] \ No newline at end of file diff --git a/haystack_experimental/components/tools/openapi/__init__.py b/haystack_experimental/components/tools/openapi/__init__.py index 7b68a7a2..c109a7d3 100644 --- a/haystack_experimental/components/tools/openapi/__init__.py +++ b/haystack_experimental/components/tools/openapi/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from haystack_experimental.components.tools.openapi.openapi import OpenAPITool +from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool __all__ = ["OpenAPITool"] diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py index 41da1923..78bb429f 100644 --- a/haystack_experimental/components/tools/openapi/generator_factory.py +++ b/haystack_experimental/components/tools/openapi/generator_factory.py @@ -10,6 +10,9 @@ class LLMProvider(Enum): + """ + Enum for different LLM providers + """ OPENAI = "openai" ANTHROPIC = "anthropic" COHERE = "cohere" @@ -42,6 +45,7 @@ def load_class(full_class_path: str): @dataclass class LLMIdentifier: + """ Dataclass to hold the LLM provider and model name""" provider: LLMProvider model_name: str @@ -68,10 +72,9 @@ def create_generator( Create ChatGenerator instance based on the model name and provider. """ if provider: - try: - provider_enum = LLMProvider[provider.lower()] - except KeyError: + if provider.lower() not in LLMProvider.__members__: raise ValueError(f"Invalid provider: {provider}") + provider_enum = LLMProvider[provider.lower()] else: provider_enum = None for prov, details in PROVIDER_DETAILS.items(): diff --git a/haystack_experimental/components/tools/openapi/openapi.py b/haystack_experimental/components/tools/openapi/openapi.py index ffaf8f3f..3764d37e 100644 --- a/haystack_experimental/components/tools/openapi/openapi.py +++ b/haystack_experimental/components/tools/openapi/openapi.py @@ -3,120 +3,548 @@ # SPDX-License-Identifier: Apache-2.0 import json +import logging +import os +from base64 import b64encode +from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Union +from urllib.parse import urlparse -from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, ChatRole +import requests +import yaml -from haystack_experimental.components.tools.openapi.generator_factory import ( - create_generator, +from haystack_experimental.components.tools.openapi.payload_extraction import ( + create_function_payload_extractor, +) +from haystack_experimental.components.tools.openapi.schema_conversion import ( + anthropic_converter, + cohere_converter, + openai_converter, ) -from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient +VALID_HTTP_METHODS = [ + "get", + "put", + "post", + "delete", + "options", + "head", + "patch", + "trace", +] +MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 logger = logging.getLogger(__name__) -@component -class OpenAPITool: +class AuthenticationStrategy: + """ + Represents an authentication strategy that can be applied to an HTTP request. """ - The OpenAPITool calls an OpenAPI service using payloads generated by the chat generator from human instructions. - Here is an example of how to use the OpenAPITool component to scrape a URL using the FireCrawl API: + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the authentication strategy to the given request. - ```python - from haystack.components.tools import OpenAPITool - from haystack.components.generators.chat.openai import OpenAIChatGenerator - from haystack.dataclasses import ChatMessage + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ - tool = OpenAPITool(llm_provider="openai", model="gpt-3.5-turbo", - tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json", - tool_credentials="") - results = tool.run(messages=[ChatMessage.from_user("Scrape URL: https://news.ycombinator.com/")]) - print(results) - ``` +@dataclass +class ApiKeyAuthentication(AuthenticationStrategy): + """API key authentication strategy.""" - Similarly, you can use the OpenAPITool component to use any OpenAPI service/tool by providing the OpenAPI - specification and credentials. - """ + api_key: Optional[str] = None - def __init__( - self, - model: str, - tool_spec: Optional[Union[str, Path]] = None, - tool_credentials: Optional[Union[str, Dict[str, Any]]] = None, - ): - self.llm_id, self.chat_generator = create_generator(model) - self.config_openapi = ( - ClientConfiguration( - openapi_spec=tool_spec, - credentials=tool_credentials, - llm_provider=self.llm_id.provider.value, + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the API key authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + if security_scheme["in"] == "header": + request.setdefault("headers", {})[security_scheme["name"]] = self.api_key + elif security_scheme["in"] == "query": + request.setdefault("params", {})[security_scheme["name"]] = self.api_key + elif security_scheme["in"] == "cookie": + request.setdefault("cookies", {})[security_scheme["name"]] = self.api_key + else: + raise ValueError( + f"Unsupported apiKey authentication location: {security_scheme['in']}, " + f"must be one of 'header', 'query', or 'cookie'" ) - if tool_spec - else None + + +@dataclass +class HTTPAuthentication(AuthenticationStrategy): + """HTTP authentication strategy.""" + + username: Optional[str] = None + password: Optional[str] = None + token: Optional[str] = None + + def __post_init__(self): + if not self.token and (not self.username or not self.password): + raise ValueError( + "For HTTP Basic Auth, both username and password must be provided. " + "For Bearer Auth, a token must be provided." + ) + + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the HTTP authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + if security_scheme["type"] == "http": + if security_scheme["scheme"].lower() == "basic": + if not self.username or not self.password: + raise ValueError( + "Username and password must be provided for Basic Auth." + ) + credentials = f"{self.username}:{self.password}" + encoded_credentials = b64encode(credentials.encode("utf-8")).decode( + "utf-8" + ) + request.setdefault("headers", {})[ + "Authorization" + ] = f"Basic {encoded_credentials}" + elif security_scheme["scheme"].lower() == "bearer": + if not self.token: + raise ValueError("Token must be provided for Bearer Auth.") + request.setdefault("headers", {})[ + "Authorization" + ] = f"Bearer {self.token}" + else: + raise ValueError( + f"Unsupported HTTP authentication scheme: {security_scheme['scheme']}" + ) + else: + raise ValueError( + "HTTPAuthentication strategy received a non-HTTP security scheme." + ) + + +class HttpClientError(Exception): + """Exception raised for errors in the HTTP client.""" + + +@dataclass +class Operation: + """Represents an operation in an OpenAPI specification.""" + + path: str + method: str + operation_dict: Dict[str, Any] + spec_dict: Dict[str, Any] + security_requirements: List[Dict[str, List[str]]] = field(init=False) + request_body: Dict[str, Any] = field(init=False) + parameters: List[Dict[str, Any]] = field(init=False) + + def __post_init__(self): + if self.method.lower() not in VALID_HTTP_METHODS: + raise ValueError(f"Invalid HTTP method: {self.method}") + self.method = self.method.lower() + self.security_requirements = self.operation_dict.get( + "security", [] + ) or self.spec_dict.get("security", []) + self.request_body = self.operation_dict.get("requestBody", {}) + self.parameters = self.operation_dict.get( + "parameters", [] + ) + self.spec_dict.get("paths", {}).get(self.path, {}).get("parameters", []) + + def get_parameters( + self, location: Optional[Literal["header", "query", "path"]] = None + ) -> List[Dict[str, Any]]: + """ + Get the parameters for the operation. + """ + if location: + return [param for param in self.parameters if param["in"] == location] + return self.parameters + + def get_server(self) -> str: + """ + Get the servers for the operation. + """ + servers = self.operation_dict.get("servers", []) or self.spec_dict.get( + "servers", [] ) + return servers[0].get("url", "") # just use the first server from the list + + +class OpenAPISpecification: + """Represents an OpenAPI specification.""" + + def __init__(self, spec_dict: Dict[str, Any]): + if not isinstance(spec_dict, Dict): + raise ValueError( + f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}" + ) + # just a crude sanity check, by no means a full validation + if ( + "openapi" not in spec_dict + or "paths" not in spec_dict + or "servers" not in spec_dict + ): + raise ValueError( + "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.", + spec_dict, + ) + self.spec_dict = spec_dict + + @classmethod + def from_dict(cls, spec_dict: Dict[str, Any]) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a dictionary. + """ + parser = cls(spec_dict) + return parser + + @classmethod + def from_str(cls, content: str) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a string. + """ + try: + loaded_spec = json.loads(content) + except json.JSONDecodeError: + try: + loaded_spec = yaml.safe_load(content) + except yaml.YAMLError as e: + raise ValueError( + "Content cannot be decoded as JSON or YAML: " + str(e) + ) from e + return cls(loaded_spec) + + @classmethod + def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a file. + """ + with open(spec_file, encoding="utf-8") as file: + content = file.read() + return cls.from_str(content) - @component.output_types(service_response=List[ChatMessage]) - def run( + @classmethod + def from_url(cls, url: str) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a URL. + """ + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + content = response.text + except requests.RequestException as e: + raise ConnectionError( + f"Failed to fetch the specification from URL: {url}. {e!s}" + ) from e + return cls.from_str(content) + + def find_operation_by_id( + self, op_id: str, method: Optional[str] = None + ) -> Operation: + """ + Find an operation by operationId. + """ + for path, path_item in self.spec_dict.get("paths", {}).items(): + op: Operation = self.get_operation_item(path, path_item, method) + if op_id in op.operation_dict.get("operationId", ""): + return self.get_operation_item(path, path_item, method) + raise ValueError(f"No operation found with operationId {op_id}") + + def get_operation_item( + self, path: str, path_item: Dict[str, Any], method: Optional[str] = None + ) -> Operation: + """ + Get an operation item from the OpenAPI specification. + + :param path: The path of the operation. + :param path_item: The path item from the OpenAPI specification. + :param method: The HTTP method of the operation. + """ + if method: + operation_dict = path_item.get(method.lower(), {}) + if not operation_dict: + raise ValueError( + f"No operation found for method {method} at path {path}" + ) + return Operation(path, method.lower(), operation_dict, self.spec_dict) + if len(path_item) == 1: + method, operation_dict = next(iter(path_item.items())) + return Operation(path, method, operation_dict, self.spec_dict) + if len(path_item) > 1: + raise ValueError( + f"Multiple operations found at path {path}, method parameter is required." + ) + raise ValueError(f"No operations found at path {path}.") + + def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: + """ + Get the security schemes from the OpenAPI specification. + """ + return self.spec_dict.get("components", {}).get("securitySchemes", {}) + + +class ClientConfiguration: + """Configuration for the OpenAPI client.""" + + def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments self, - messages: List[ChatMessage], - fc_generator_kwargs: Optional[Dict[str, Any]] = None, - tool_spec: Optional[Union[str, Path, Dict[str, Any]]] = None, - tool_credentials: Optional[Union[dict, str]] = None, - ) -> Dict[str, List[ChatMessage]]: - """ - Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator. - - :param messages: List of ChatMessages to generate function calling payload (e.g. human instructions). - :param fc_generator_kwargs: Additional arguments for the function calling payload generation process. - :param tool_spec: OpenAPI specification for the tool/service. - :param tool_credentials: Credentials for the tool/service. - :returns: a dictionary containing the service response with the following key: - - `service_response`: List of ChatMessages containing the service response. - """ - last_message = messages[-1] - if not last_message.is_from(ChatRole.USER): - raise ValueError(f"{last_message} not from the user") - if not last_message.content: - raise ValueError("Function calling instruction message content is empty.") - - # build a new ClientConfiguration if a runtime tool_spec is provided - config_openapi = ( - ClientConfiguration( - openapi_spec=tool_spec, - credentials=tool_credentials, - llm_provider=self.llm_id.provider.value, + openapi_spec: Union[str, Path, Dict[str, Any]], + credentials: Optional[ + Union[str, Dict[str, Any], AuthenticationStrategy] + ] = None, + request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + llm_provider: Optional[str] = None, + ): # noqa: PLR0913 + if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): + self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) + elif isinstance(openapi_spec, dict): + self.openapi_spec = OpenAPISpecification.from_dict(openapi_spec) + elif isinstance(openapi_spec, str): + if self.is_valid_http_url(openapi_spec): + self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) + else: + self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) + else: + raise ValueError( + "Invalid OpenAPI specification format. Expected file path or dictionary." ) - if tool_spec - else self.config_openapi + + self.credentials = credentials + self.request_sender = request_sender + self.llm_provider = llm_provider or "openai" + + def get_auth_config(self) -> AuthenticationStrategy: + """ + Get the authentication configuration. + """ + if not self.credentials: + return AuthenticationStrategy() + if isinstance(self.credentials, AuthenticationStrategy): + return self.credentials + security_schemes = self.openapi_spec.get_security_schemes() + if isinstance(self.credentials, str): + return self._create_authentication_from_string( + self.credentials, security_schemes + ) + if isinstance(self.credentials, dict): + return self._create_authentication_from_dict(self.credentials) + raise ValueError(f"Unsupported credentials type: {type(self.credentials)}") + + def get_tools_definitions(self) -> List[Dict[str, Any]]: + """ + Get the tools definitions used as tools LLM parameter. + """ + provider_to_converter = { + "anthropic": anthropic_converter, + "cohere": cohere_converter, + } + converter = provider_to_converter.get(self.llm_provider, openai_converter) + return converter(self.openapi_spec) + + def get_payload_extractor(self): + """ + Get the payload extractor for the LLM provider. + """ + provider_to_arguments_field_name = { + "anthropic": "input", + "cohere": "parameters", + } # add more providers here + # default to OpenAI "arguments" + arguments_field_name = provider_to_arguments_field_name.get( + self.llm_provider, "arguments" ) + return create_function_payload_extractor(arguments_field_name) - if not config_openapi: - raise ValueError( - "OpenAPI specification not provided. Please provide an OpenAPI specification either at initialization " - "or during runtime." + def _create_authentication_from_string( + self, credentials: str, security_schemes: Dict[str, Any] + ) -> AuthenticationStrategy: + for scheme in security_schemes.values(): + if scheme["type"] == "apiKey": + return ApiKeyAuthentication(api_key=credentials) + if scheme["type"] == "http": + return HTTPAuthentication(token=credentials) + if scheme["type"] == "oauth2": + raise NotImplementedError("OAuth2 authentication is not yet supported.") + raise ValueError( + f"Unable to create authentication from provided credentials: {credentials}" + ) + + def _create_authentication_from_dict( + self, credentials: Dict[str, Any] + ) -> AuthenticationStrategy: + if "username" in credentials and "password" in credentials: + return HTTPAuthentication( + username=credentials["username"], password=credentials["password"] + ) + if "api_key" in credentials: + return ApiKeyAuthentication(api_key=credentials["api_key"]) + if "token" in credentials: + return HTTPAuthentication(token=credentials["token"]) + if "access_token" in credentials: + raise NotImplementedError("OAuth2 authentication is not yet supported.") + raise ValueError( + "Unable to create authentication from provided credentials: {credentials}" + ) + + def is_valid_http_url(self, url: str) -> bool: + """Check if a URL is a valid HTTP/HTTPS URL.""" + r = urlparse(url) + return all([r.scheme in ["http", "https"], r.netloc]) + + +class OpenAPIServiceClient: + """ + A client for invoking operations on REST services defined by OpenAPI specifications. + + Together with the `ClientConfiguration`, its `ClientConfigurationBuilder`, the `OpenAPIServiceClient` + simplifies the process of (LLMs) with services defined by OpenAPI specifications. + """ + + def __init__(self, client_config: ClientConfiguration): + self.client_config = client_config + self.request_sender = client_config.request_sender or self._request_sender() + + def invoke(self, function_payload: Any) -> Any: + """ + Invokes a function specified in the function payload. + + :param function_payload: The function payload containing the details of the function to be invoked. + :returns: The response from the service after invoking the function. + :raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload. + :raises HttpClientError: If an error occurs while sending the request and receiving the response. + """ + fn_extractor = self.client_config.get_payload_extractor() + fn_invocation_payload = fn_extractor(function_payload) + if not fn_invocation_payload: + raise OpenAPIClientError( + f"Failed to extract function invocation payload from {function_payload}" ) - # merge fc_generator_kwargs, tools definitions comes from the OpenAPI spec, other kwargs are passed by the user - fc_generator_kwargs = { - "tools": config_openapi.get_tools_definitions(), - **(fc_generator_kwargs or {}), + # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on + operation = self.client_config.openapi_spec.find_operation_by_id( + fn_invocation_payload.get("name") + ) + request = self._build_request( + operation, **fn_invocation_payload.get("arguments") + ) + self._apply_authentication( + self.client_config.get_auth_config(), operation, request + ) + return self.request_sender(request) + + def _request_sender(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: + """ + Returns a callable that sends the request using the HTTP client. + """ + + def send_request(request: Dict[str, Any]) -> Dict[str, Any]: + url = request["url"] + headers = {**request.get("headers", {})} + try: + response = requests.request( + request["method"], + url, + headers=headers, + params=request.get("params", {}), + json=request.get("json"), + auth=request.get("auth"), + timeout=10, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + logger.warning( + "HTTP error occurred: %s while sending request to %s", e, url + ) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except requests.exceptions.RequestException as e: + logger.warning( + "Request error occurred: %s while sending request to %s", e, url + ) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except Exception as e: + logger.warning( + "An error occurred: %s while sending request to %s", e, url + ) + raise HttpClientError(f"An error occurred: {e}") from e + + return send_request + + def _build_request(self, operation: Operation, **kwargs) -> Any: + # url + path = operation.path + for parameter in operation.get_parameters("path"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + path = path.replace(f"{{{parameter['name']}}}", str(param_value)) + elif parameter.get("required", False): + raise ValueError( + f"Missing required path parameter: {parameter['name']}" + ) + url = operation.get_server() + path + # method + method = operation.method.lower() + # headers + headers = {} + for parameter in operation.get_parameters("header"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + headers[parameter["name"]] = str(param_value) + elif parameter.get("required", False): + raise ValueError( + f"Missing required header parameter: {parameter['name']}" + ) + # query params + query_params = {} + for parameter in operation.get_parameters("query"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + query_params[parameter["name"]] = param_value + elif parameter.get("required", False): + raise ValueError( + f"Missing required query parameter: {parameter['name']}" + ) + + json_payload = None + request_body = operation.request_body + if request_body: + content = request_body.get("content", {}) + if "application/json" in content: + json_payload = {**kwargs} + else: + raise NotImplementedError("Request body content type not supported") + return { + "url": url, + "method": method, + "headers": headers, + "params": query_params, + "json": json_payload, } - # generate function calling payload with the chat generator - logger.debug( - f"Invoking chat generator with {last_message.content} to generate function calling payload." + def _apply_authentication( + self, + auth: AuthenticationStrategy, + operation: Operation, + request: Dict[str, Any], + ): + auth_config = auth or AuthenticationStrategy() + security_requirements = operation.security_requirements + security_schemes = operation.spec_dict.get("components", {}).get( + "securitySchemes", {} ) - fc_payload = self.chat_generator.run(messages, fc_generator_kwargs) + if security_requirements: + for requirement in security_requirements: + for scheme_name in requirement: + if scheme_name in security_schemes: + security_scheme = security_schemes[scheme_name] + auth_config.apply_auth(security_scheme, request) + break - openapi_service = OpenAPIServiceClient(config_openapi) - try: - invocation_payload = json.loads(fc_payload["replies"][0].content) - logger.debug(f"Invoking tool with {invocation_payload}") - service_response = openapi_service.invoke(invocation_payload) - except Exception as e: - logger.error(f"Error invoking OpenAPI endpoint. Error: {e}") - service_response = {"error": str(e)} - response_messages = [ChatMessage.from_user(json.dumps(service_response))] - return {"service_response": response_messages} + +class OpenAPIClientError(Exception): + """Exception raised for errors in the OpenAPI client.""" diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py new file mode 100644 index 00000000..0b2e1ac7 --- /dev/null +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from haystack import component, logging +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_experimental.components.tools.openapi.generator_factory import create_generator +from haystack_experimental.components.tools.openapi.openapi import ( + ClientConfiguration, + OpenAPIServiceClient, +) + +logger = logging.getLogger(__name__) + + +@component +class OpenAPITool: + """ + The OpenAPITool calls an OpenAPI service using payloads generated by the chat generator from human instructions. + + Here is an example of how to use the OpenAPITool component to scrape a URL using the FireCrawl API: + + ```python + from haystack_experimental.components.tools import OpenAPITool + from haystack.components.generators.chat.openai import OpenAIChatGenerator + from haystack.dataclasses import ChatMessage + + tool = OpenAPITool(llm_provider="openai", model="gpt-3.5-turbo", + tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json", + tool_credentials="") + + results = tool.run(messages=[ChatMessage.from_user("Scrape URL: https://news.ycombinator.com/")]) + print(results) + ``` + + Similarly, you can use the OpenAPITool component to use any OpenAPI service/tool by providing the OpenAPI + specification and credentials. + """ + + def __init__( + self, + model: str, + tool_spec: Optional[Union[str, Path]] = None, + tool_credentials: Optional[Union[str, Dict[str, Any]]] = None, + ): + self.llm_id, self.chat_generator = create_generator(model) + self.config_openapi = ( + ClientConfiguration( + openapi_spec=tool_spec, + credentials=tool_credentials, + llm_provider=self.llm_id.provider.value, + ) + if tool_spec + else None + ) + + @component.output_types(service_response=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + fc_generator_kwargs: Optional[Dict[str, Any]] = None, + tool_spec: Optional[Union[str, Path, Dict[str, Any]]] = None, + tool_credentials: Optional[Union[dict, str]] = None, + ) -> Dict[str, List[ChatMessage]]: + """ + Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator. + + :param messages: List of ChatMessages to generate function calling payload (e.g. human instructions). + :param fc_generator_kwargs: Additional arguments for the function calling payload generation process. + :param tool_spec: OpenAPI specification for the tool/service. + :param tool_credentials: Credentials for the tool/service. + :returns: a dictionary containing the service response with the following key: + - `service_response`: List of ChatMessages containing the service response. + """ + last_message = messages[-1] + if not last_message.is_from(ChatRole.USER): + raise ValueError(f"{last_message} not from the user") + if not last_message.content: + raise ValueError("Function calling instruction message content is empty.") + + # build a new ClientConfiguration if a runtime tool_spec is provided + config_openapi = ( + ClientConfiguration( + openapi_spec=tool_spec, + credentials=tool_credentials, + llm_provider=self.llm_id.provider.value, + ) + if tool_spec + else self.config_openapi + ) + + if not config_openapi: + raise ValueError( + "OpenAPI specification not provided. Please provide an OpenAPI specification either at initialization " + "or during runtime." + ) + # merge fc_generator_kwargs, tools definitions comes from the OpenAPI spec, other kwargs are passed by the user + fc_generator_kwargs = { + "tools": config_openapi.get_tools_definitions(), + **(fc_generator_kwargs or {}), + } + + # generate function calling payload with the chat generator + logger.debug( + f"Invoking chat generator with {last_message.content} to generate function calling payload." + ) + fc_payload = self.chat_generator.run(messages, fc_generator_kwargs) + + openapi_service = OpenAPIServiceClient(config_openapi) + try: + invocation_payload = json.loads(fc_payload["replies"][0].content) + logger.debug(f"Invoking tool with {invocation_payload}") + service_response = openapi_service.invoke(invocation_payload) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error(f"Error invoking OpenAPI endpoint. Error: {e}") + service_response = {"error": str(e)} + response_messages = [ChatMessage.from_user(json.dumps(service_response))] + return {"service_response": response_messages} diff --git a/haystack_experimental/util/payload_extraction.py b/haystack_experimental/components/tools/openapi/payload_extraction.py similarity index 83% rename from haystack_experimental/util/payload_extraction.py rename to haystack_experimental/components/tools/openapi/payload_extraction.py index bffd7815..157070ae 100644 --- a/haystack_experimental/util/payload_extraction.py +++ b/haystack_experimental/components/tools/openapi/payload_extraction.py @@ -7,10 +7,13 @@ from typing import Any, Callable, Dict, List, Optional, Union -def create_function_payload_extractor(arguments_field_name: str) -> Callable[[Any], Dict[str, Any]]: +def create_function_payload_extractor( + arguments_field_name: str, +) -> Callable[[Any], Dict[str, Any]]: """ Extracts invocation payload from a given LLM completion containing function invocation. """ + def _extract_function_invocation(payload: Any) -> Dict[str, Any]: """ Extract the function invocation details from the payload. @@ -24,16 +27,22 @@ def _extract_function_invocation(payload: Any) -> Dict[str, Any]: ) return { "name": fields_and_values.get("name"), - "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, + "arguments": json.loads(arguments) + if isinstance(arguments, str) + else arguments, } return {} return _extract_function_invocation -def _get_dict_converter(obj: Any, - method_names: Optional[List[str]] = None) -> Union[Callable[[], Dict[str, Any]], None]: - method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1 +def _get_dict_converter( + obj: Any, method_names: Optional[List[str]] = None +) -> Union[Callable[[], Dict[str, Any]], None]: + method_names = method_names or [ + "model_dump", + "dict", + ] # search for pydantic v2 then v1 for attr in method_names: if hasattr(obj, attr) and callable(getattr(obj, attr)): return getattr(obj, attr) diff --git a/haystack_experimental/util/schema_conversion.py b/haystack_experimental/components/tools/openapi/schema_conversion.py similarity index 79% rename from haystack_experimental/util/schema_conversion.py rename to haystack_experimental/components/tools/openapi/schema_conversion.py index 6d17e344..5b7e24e8 100644 --- a/haystack_experimental/util/schema_conversion.py +++ b/haystack_experimental/components/tools/openapi/schema_conversion.py @@ -4,21 +4,18 @@ import logging from typing import Any, Callable, Dict, List, Optional + # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 - - import jsonref -VALID_HTTP_METHODS = ["get", "put", "post", "delete", "options", "head", "patch", "trace"] - MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 logger = logging.getLogger(__name__) -def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 +def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 """ Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. @@ -26,11 +23,13 @@ def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # :return: A list of dictionaries, each representing a function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) - fn_definitions = _openapi_to_functions(resolved_schema, "parameters", _parse_endpoint_spec_openai) + fn_definitions = _openapi_to_functions( + resolved_schema, "parameters", _parse_endpoint_spec_openai + ) return [{"type": "function", "function": fn} for fn in fn_definitions] -def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 +def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 """ Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. @@ -38,10 +37,12 @@ def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: :return: A list of dictionaries, each representing a function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) - return _openapi_to_functions(resolved_schema, "input_schema", _parse_endpoint_spec_openai) + return _openapi_to_functions( + resolved_schema, "input_schema", _parse_endpoint_spec_openai + ) -def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 +def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 """ Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. @@ -49,7 +50,9 @@ def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # :return: A list of dictionaries, each representing a function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) - return _openapi_to_functions(resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere) + return _openapi_to_functions( + resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere + ) def _openapi_to_functions( @@ -65,7 +68,9 @@ def _openapi_to_functions( # We check the version and require minimal fields to be present, so we can extract functions spec_version = service_openapi_spec.get("openapi") if not spec_version: - raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}") + raise ValueError( + f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}" + ) service_openapi_spec_version = int(spec_version.split(".")[0]) # Compare the versions if service_openapi_spec_version < MIN_REQUIRED_OPENAPI_SPEC_VERSION: @@ -82,19 +87,26 @@ def _openapi_to_functions( return functions -def _parse_endpoint_spec_openai(resolved_spec: Dict[str, Any], parameters_name: str) -> Dict[str, Any]: +def _parse_endpoint_spec_openai( + resolved_spec: Dict[str, Any], parameters_name: str +) -> Dict[str, Any]: """ Parses an OpenAPI endpoint specification for OpenAI. """ if not isinstance(resolved_spec, dict): - logger.warning("Invalid OpenAPI spec format provided. Could not extract function.") + logger.warning( + "Invalid OpenAPI spec format provided. Could not extract function." + ) return {} function_name = resolved_spec.get("operationId") description = resolved_spec.get("description") or resolved_spec.get("summary", "") schema: Dict[str, Any] = {"type": "object", "properties": {}} # requestBody section req_body_schema = ( - resolved_spec.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {}) + resolved_spec.get("requestBody", {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}) ) if "properties" in req_body_schema: for prop_name, prop_schema in req_body_schema["properties"].items(): @@ -108,14 +120,23 @@ def _parse_endpoint_spec_openai(resolved_spec: Dict[str, Any], parameters_name: schema_dict = _parse_property_attributes(param["schema"]) # these attributes are not in param[schema] level but on param level useful_attributes = ["description", "pattern", "enum"] - schema_dict.update({key: param[key] for key in useful_attributes if param.get(key)}) + schema_dict.update( + {key: param[key] for key in useful_attributes if param.get(key)} + ) schema["properties"][param["name"]] = schema_dict if param.get("required", False): schema.setdefault("required", []).append(param["name"]) if function_name and description and schema["properties"]: - return {"name": function_name, "description": description, parameters_name: schema} - logger.warning("Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec) + return { + "name": function_name, + "description": description, + parameters_name: schema, + } + logger.warning( + "Invalid OpenAPI spec format provided. Could not extract function from %s", + resolved_spec, + ) return {} @@ -134,7 +155,8 @@ def _parse_property_attributes( if schema_type == "object": properties = property_schema.get("properties", {}) parsed_properties = { - prop_name: _parse_property_attributes(prop, include_attributes) for prop_name, prop in properties.items() + prop_name: _parse_property_attributes(prop, include_attributes) + for prop_name, prop in properties.items() } parsed_schema["properties"] = parsed_properties if "required" in property_schema: @@ -145,7 +167,9 @@ def _parse_property_attributes( return parsed_schema -def _parse_endpoint_spec_cohere(operation: Dict[str, Any], ignored_param: str) -> Dict[str, Any]: +def _parse_endpoint_spec_cohere( + operation: Dict[str, Any], ignored_param: str +) -> Dict[str, Any]: """ Parses an endpoint specification for Cohere. """ @@ -170,19 +194,27 @@ def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: for param in operation.get("parameters", []): if "schema" in param: parameters[param["name"]] = _parse_schema( - param["schema"], param.get("required", False), param.get("description", "") + param["schema"], + param.get("required", False), + param.get("description", ""), ) if "requestBody" in operation: - content = operation["requestBody"].get("content", {}).get("application/json", {}) + content = ( + operation["requestBody"].get("content", {}).get("application/json", {}) + ) if "schema" in content: schema_properties = content["schema"].get("properties", {}) required_properties = content["schema"].get("required", []) for name, schema in schema_properties.items(): - parameters[name] = _parse_schema(schema, name in required_properties, schema.get("description", "")) + parameters[name] = _parse_schema( + schema, name in required_properties, schema.get("description", "") + ) return parameters -def _parse_schema(schema: Dict[str, Any], required: bool, description: str) -> Dict[str, Any]: # noqa: FBT001 +def _parse_schema( + schema: Dict[str, Any], required: bool, description: str +) -> Dict[str, Any]: # noqa: FBT001 """ Parses a schema part of an operation specification. """ diff --git a/haystack_experimental/util/openapi.py b/haystack_experimental/util/openapi.py deleted file mode 100644 index 94b7614b..00000000 --- a/haystack_experimental/util/openapi.py +++ /dev/null @@ -1,440 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import json -import logging -import os -from base64 import b64encode -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Union -from urllib.parse import urlparse - -import requests -import yaml - -from haystack_experimental.util.payload_extraction import create_function_payload_extractor -from haystack_experimental.util.schema_conversion import anthropic_converter, cohere_converter, openai_converter - -VALID_HTTP_METHODS = ["get", "put", "post", "delete", "options", "head", "patch", "trace"] -MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 -logger = logging.getLogger(__name__) - - -class AuthenticationStrategy: - """ - Represents an authentication strategy that can be applied to an HTTP request. - """ - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): - """ - Apply the authentication strategy to the given request. - - :param security_scheme: the security scheme from the OpenAPI spec. - :param request: the request to apply the authentication to. - """ - - -@dataclass -class ApiKeyAuthentication(AuthenticationStrategy): - """API key authentication strategy.""" - api_key: Optional[str] = None - - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): - """ - Apply the API key authentication strategy to the given request. - - :param security_scheme: the security scheme from the OpenAPI spec. - :param request: the request to apply the authentication to. - """ - if security_scheme["in"] == "header": - request.setdefault("headers", {})[security_scheme["name"]] = self.api_key - elif security_scheme["in"] == "query": - request.setdefault("params", {})[security_scheme["name"]] = self.api_key - elif security_scheme["in"] == "cookie": - request.setdefault("cookies", {})[security_scheme["name"]] = self.api_key - else: - raise ValueError( - f"Unsupported apiKey authentication location: {security_scheme['in']}, " - f"must be one of 'header', 'query', or 'cookie'" - ) - - -@dataclass -class HTTPAuthentication(AuthenticationStrategy): - """HTTP authentication strategy.""" - username: Optional[str] = None - password: Optional[str] = None - token: Optional[str] = None - - def __post_init__(self): - if not self.token and (not self.username or not self.password): - raise ValueError( - "For HTTP Basic Auth, both username and password must be provided. " - "For Bearer Auth, a token must be provided." - ) - - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): - """ - Apply the HTTP authentication strategy to the given request. - - :param security_scheme: the security scheme from the OpenAPI spec. - :param request: the request to apply the authentication to. - """ - if security_scheme["type"] == "http": - if security_scheme["scheme"].lower() == "basic": - if not self.username or not self.password: - raise ValueError("Username and password must be provided for Basic Auth.") - credentials = f"{self.username}:{self.password}" - encoded_credentials = b64encode(credentials.encode("utf-8")).decode("utf-8") - request.setdefault("headers", {})["Authorization"] = f"Basic {encoded_credentials}" - elif security_scheme["scheme"].lower() == "bearer": - if not self.token: - raise ValueError("Token must be provided for Bearer Auth.") - request.setdefault("headers", {})["Authorization"] = f"Bearer {self.token}" - else: - raise ValueError(f"Unsupported HTTP authentication scheme: {security_scheme['scheme']}") - else: - raise ValueError("HTTPAuthentication strategy received a non-HTTP security scheme.") - - -class HttpClientError(Exception): - """Exception raised for errors in the HTTP client.""" - - -@dataclass -class Operation: - """Represents an operation in an OpenAPI specification.""" - path: str - method: str - operation_dict: Dict[str, Any] - spec_dict: Dict[str, Any] - security_requirements: List[Dict[str, List[str]]] = field(init=False) - request_body: Dict[str, Any] = field(init=False) - parameters: List[Dict[str, Any]] = field(init=False) - - def __post_init__(self): - if self.method.lower() not in VALID_HTTP_METHODS: - raise ValueError(f"Invalid HTTP method: {self.method}") - self.method = self.method.lower() - self.security_requirements = self.operation_dict.get("security", []) or self.spec_dict.get("security", []) - self.request_body = self.operation_dict.get("requestBody", {}) - self.parameters = self.operation_dict.get("parameters", []) + self.spec_dict.get("paths", {}).get( - self.path, {} - ).get("parameters", []) - - def get_parameters(self, location: Optional[Literal["header", "query", "path"]] = None) -> List[Dict[str, Any]]: - """ - Get the parameters for the operation. - """ - if location: - return [param for param in self.parameters if param["in"] == location] - return self.parameters - - def get_server(self) -> str: - """ - Get the servers for the operation. - """ - servers = self.operation_dict.get("servers", []) or self.spec_dict.get("servers", []) - return servers[0].get("url", "") # just use the first server from the list - - -class OpenAPISpecification: - """Represents an OpenAPI specification.""" - - def __init__(self, spec_dict: Dict[str, Any]): - if not isinstance(spec_dict, Dict): - raise ValueError(f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}") - # just a crude sanity check, by no means a full validation - if "openapi" not in spec_dict or "paths" not in spec_dict or "servers" not in spec_dict: - raise ValueError( - "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.", spec_dict - ) - self.spec_dict = spec_dict - - @classmethod - def from_dict(cls, spec_dict: Dict[str, Any]) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a dictionary. - """ - parser = cls(spec_dict) - return parser - - @classmethod - def from_str(cls, content: str) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a string. - """ - try: - loaded_spec = json.loads(content) - except json.JSONDecodeError: - try: - loaded_spec = yaml.safe_load(content) - except yaml.YAMLError as e: - raise ValueError("Content cannot be decoded as JSON or YAML: " + str(e)) from e - return cls(loaded_spec) - - @classmethod - def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a file. - """ - with open(spec_file, encoding="utf-8") as file: - content = file.read() - return cls.from_str(content) - - @classmethod - def from_url(cls, url: str) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a URL. - """ - try: - response = requests.get(url, timeout=10) - response.raise_for_status() - content = response.text - except requests.RequestException as e: - raise ConnectionError(f"Failed to fetch the specification from URL: {url}. {e!s}") from e - return cls.from_str(content) - - def find_operation_by_id(self, op_id: str, method: Optional[str] = None) -> Operation: - """ - Find an operation by operationId. - """ - for path, path_item in self.spec_dict.get("paths", {}).items(): - op: Operation = self.get_operation_item(path, path_item, method) - if op_id in op.operation_dict.get("operationId", ""): - return self.get_operation_item(path, path_item, method) - raise ValueError(f"No operation found with operationId {op_id}") - - def get_operation_item(self, path: str, path_item: Dict[str, Any], method: Optional[str] = None) -> Operation: - """ - Get an operation item from the OpenAPI specification. - - :param path: The path of the operation. - :param path_item: The path item from the OpenAPI specification. - :param method: The HTTP method of the operation. - """ - if method: - operation_dict = path_item.get(method.lower(), {}) - if not operation_dict: - raise ValueError(f"No operation found for method {method} at path {path}") - return Operation(path, method.lower(), operation_dict, self.spec_dict) - if len(path_item) == 1: - method, operation_dict = next(iter(path_item.items())) - return Operation(path, method, operation_dict, self.spec_dict) - if len(path_item) > 1: - raise ValueError(f"Multiple operations found at path {path}, method parameter is required.") - raise ValueError(f"No operations found at path {path}.") - - def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: - """ - Get the security schemes from the OpenAPI specification. - """ - return self.spec_dict.get("components", {}).get("securitySchemes", {}) - - -class ClientConfiguration: - """Configuration for the OpenAPI client.""" - - def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments - self, - openapi_spec: Union[str, Path, Dict[str, Any]], - credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None, - request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, - llm_provider: Optional[str] = None, - ): # noqa: PLR0913 - if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): - self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) - elif isinstance(openapi_spec, dict): - self.openapi_spec = OpenAPISpecification.from_dict(openapi_spec) - elif isinstance(openapi_spec, str): - if self.is_valid_http_url(openapi_spec): - self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) - else: - self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) - else: - raise ValueError("Invalid OpenAPI specification format. Expected file path or dictionary.") - - self.credentials = credentials - self.request_sender = request_sender - self.llm_provider = llm_provider or "openai" - - def get_auth_config(self) -> AuthenticationStrategy: - """ - Get the authentication configuration. - """ - if not self.credentials: - return AuthenticationStrategy() - if isinstance(self.credentials, AuthenticationStrategy): - return self.credentials - security_schemes = self.openapi_spec.get_security_schemes() - if isinstance(self.credentials, str): - return self._create_authentication_from_string(self.credentials, security_schemes) - if isinstance(self.credentials, dict): - return self._create_authentication_from_dict(self.credentials) - raise ValueError(f"Unsupported credentials type: {type(self.credentials)}") - - def get_tools_definitions(self) -> List[Dict[str, Any]]: - """ - Get the tools definitions used as tools LLM parameter. - """ - provider_to_converter = {"anthropic": anthropic_converter, "cohere": cohere_converter} - converter = provider_to_converter.get(self.llm_provider, openai_converter) - return converter(self.openapi_spec) - - def get_payload_extractor(self): - """ - Get the payload extractor for the LLM provider. - """ - provider_to_arguments_field_name = {"anthropic": "input", "cohere": "parameters"} # add more providers here - # default to OpenAI "arguments" - arguments_field_name = provider_to_arguments_field_name.get(self.llm_provider, "arguments") - return create_function_payload_extractor(arguments_field_name) - - def _create_authentication_from_string( - self, credentials: str, security_schemes: Dict[str, Any] - ) -> AuthenticationStrategy: - for scheme in security_schemes.values(): - if scheme["type"] == "apiKey": - return ApiKeyAuthentication(api_key=credentials) - if scheme["type"] == "http": - return HTTPAuthentication(token=credentials) - if scheme["type"] == "oauth2": - raise NotImplementedError("OAuth2 authentication is not yet supported.") - raise ValueError(f"Unable to create authentication from provided credentials: {credentials}") - - def _create_authentication_from_dict(self, credentials: Dict[str, Any]) -> AuthenticationStrategy: - if "username" in credentials and "password" in credentials: - return HTTPAuthentication(username=credentials["username"], password=credentials["password"]) - if "api_key" in credentials: - return ApiKeyAuthentication(api_key=credentials["api_key"]) - if "token" in credentials: - return HTTPAuthentication(token=credentials["token"]) - if "access_token" in credentials: - raise NotImplementedError("OAuth2 authentication is not yet supported.") - raise ValueError("Unable to create authentication from provided credentials: {credentials}") - - def is_valid_http_url(self, url: str) -> bool: - """Check if a URL is a valid HTTP/HTTPS URL.""" - r = urlparse(url) - return all([r.scheme in ["http", "https"], r.netloc]) - - -class OpenAPIServiceClient: - """ - A client for invoking operations on REST services defined by OpenAPI specifications. - - Together with the `ClientConfiguration`, its `ClientConfigurationBuilder`, the `OpenAPIServiceClient` - simplifies the process of (LLMs) with services defined by OpenAPI specifications. - """ - - def __init__(self, client_config: ClientConfiguration): - self.client_config = client_config - self.request_sender = client_config.request_sender or self._request_sender() - - def invoke(self, function_payload: Any) -> Any: - """ - Invokes a function specified in the function payload. - - :param function_payload: The function payload containing the details of the function to be invoked. - :returns: The response from the service after invoking the function. - :raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload. - :raises HttpClientError: If an error occurs while sending the request and receiving the response. - """ - fn_extractor = self.client_config.get_payload_extractor() - fn_invocation_payload = fn_extractor(function_payload) - if not fn_invocation_payload: - raise OpenAPIClientError( - f"Failed to extract function invocation payload from {function_payload}" - ) - # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on - operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) - request = self._build_request(operation, **fn_invocation_payload.get("arguments")) - self._apply_authentication(self.client_config.get_auth_config(), operation, request) - return self.request_sender(request) - - def _request_sender(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: - """ - Returns a callable that sends the request using the HTTP client. - """ - def send_request(request: Dict[str, Any]) -> Dict[str, Any]: - url = request["url"] - headers = {**request.get("headers", {})} - try: - response = requests.request( - request["method"], - url, - headers=headers, - params=request.get("params", {}), - json=request.get("json"), - auth=request.get("auth"), - timeout=10, - ) - response.raise_for_status() - return response.json() - except requests.exceptions.HTTPError as e: - logger.warning("HTTP error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except requests.exceptions.RequestException as e: - logger.warning("Request error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except Exception as e: - logger.warning("An error occurred: %s while sending request to %s", e, url) - raise HttpClientError(f"An error occurred: {e}") from e - - return send_request - - def _build_request(self, operation: Operation, **kwargs) -> Any: - # url - path = operation.path - for parameter in operation.get_parameters("path"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - path = path.replace(f"{{{parameter['name']}}}", str(param_value)) - elif parameter.get("required", False): - raise ValueError(f"Missing required path parameter: {parameter['name']}") - url = operation.get_server() + path - # method - method = operation.method.lower() - # headers - headers = {} - for parameter in operation.get_parameters("header"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - headers[parameter["name"]] = str(param_value) - elif parameter.get("required", False): - raise ValueError(f"Missing required header parameter: {parameter['name']}") - # query params - query_params = {} - for parameter in operation.get_parameters("query"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - query_params[parameter["name"]] = param_value - elif parameter.get("required", False): - raise ValueError(f"Missing required query parameter: {parameter['name']}") - - json_payload = None - request_body = operation.request_body - if request_body: - content = request_body.get("content", {}) - if "application/json" in content: - json_payload = {**kwargs} - else: - raise NotImplementedError("Request body content type not supported") - return {"url": url, "method": method, "headers": headers, "params": query_params, "json": json_payload} - - def _apply_authentication(self, auth: AuthenticationStrategy, operation: Operation, request: Dict[str, Any]): - auth_config = auth or AuthenticationStrategy() - security_requirements = operation.security_requirements - security_schemes = operation.spec_dict.get("components", {}).get("securitySchemes", {}) - if security_requirements: - for requirement in security_requirements: - for scheme_name in requirement: - if scheme_name in security_schemes: - security_scheme = security_schemes[scheme_name] - auth_config.apply_auth(security_scheme, request) - break - - -class OpenAPIClientError(Exception): - """Exception raised for errors in the OpenAPI client.""" diff --git a/test/components/tools/openapi/__init__.py b/test/components/tools/openapi/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/tools/openapi/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/util/conftest.py b/test/components/tools/openapi/conftest.py similarity index 90% rename from test/util/conftest.py rename to test/components/tools/openapi/conftest.py index 05bd940a..2df4f76c 100644 --- a/test/util/conftest.py +++ b/test/components/tools/openapi/conftest.py @@ -10,12 +10,12 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from haystack_experimental.util.openapi import HttpClientError +from haystack_experimental.components.tools.openapi.openapi import HttpClientError @pytest.fixture() def test_files_path(): - return Path(__file__).parent.parent / "test_files" + return Path(__file__).parent.parent.parent.parent / "test_files" class FastAPITestClient: diff --git a/test/util/test_openapi_client.py b/test/components/tools/openapi/test_openapi_client.py similarity index 95% rename from test/util/test_openapi_client.py rename to test/components/tools/openapi/test_openapi_client.py index 66a74056..4842b75e 100644 --- a/test/util/test_openapi_client.py +++ b/test/components/tools/openapi/test_openapi_client.py @@ -7,8 +7,8 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration -from test.util.conftest import FastAPITestClient +from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from test.components.tools.openapi.conftest import FastAPITestClient """ Tests OpenAPIServiceClient with three FastAPI apps for different parameter types: diff --git a/test/util/test_openapi_client_auth.py b/test/components/tools/openapi/test_openapi_client_auth.py similarity index 97% rename from test/util/test_openapi_client_auth.py rename to test/components/tools/openapi/test_openapi_client_auth.py index 014edde5..de91855e 100644 --- a/test/util/test_openapi_client_auth.py +++ b/test/components/tools/openapi/test_openapi_client_auth.py @@ -15,9 +15,9 @@ HTTPBearer, ) -from haystack_experimental.util.openapi import OpenAPIServiceClient, ApiKeyAuthentication, \ +from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ApiKeyAuthentication, \ HTTPAuthentication, ClientConfiguration -from test.util.conftest import FastAPITestClient +from test.components.tools.openapi.conftest import FastAPITestClient API_KEY = "secret_api_key" BASIC_AUTH_USERNAME = "admin" diff --git a/test/util/test_openapi_client_complex_request_body.py b/test/components/tools/openapi/test_openapi_client_complex_request_body.py similarity index 92% rename from test/util/test_openapi_client_complex_request_body.py rename to test/components/tools/openapi/test_openapi_client_complex_request_body.py index ccfb63df..e6007efb 100644 --- a/test/util/test_openapi_client_complex_request_body.py +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body.py @@ -11,8 +11,8 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration -from test.util.conftest import FastAPITestClient +from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from test.components.tools.openapi.conftest import FastAPITestClient class Customer(BaseModel): diff --git a/test/util/test_openapi_client_complex_request_body_mixed.py b/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py similarity index 92% rename from test/util/test_openapi_client_complex_request_body_mixed.py rename to test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py index 4dc532fc..33e95387 100644 --- a/test/util/test_openapi_client_complex_request_body_mixed.py +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py @@ -9,8 +9,8 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration -from test.util.conftest import FastAPITestClient +from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from test.components.tools.openapi.conftest import FastAPITestClient class Identification(BaseModel): diff --git a/test/util/test_openapi_client_edge_cases.py b/test/components/tools/openapi/test_openapi_client_edge_cases.py similarity index 82% rename from test/util/test_openapi_client_edge_cases.py rename to test/components/tools/openapi/test_openapi_client_edge_cases.py index 888580d9..912fe7d5 100644 --- a/test/util/test_openapi_client_edge_cases.py +++ b/test/components/tools/openapi/test_openapi_client_edge_cases.py @@ -5,8 +5,8 @@ import pytest -from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration -from test.util.conftest import FastAPITestClient +from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from test.components.tools.openapi.conftest import FastAPITestClient class TestEdgeCases: diff --git a/test/util/test_openapi_client_error_handling.py b/test/components/tools/openapi/test_openapi_client_error_handling.py similarity index 87% rename from test/util/test_openapi_client_error_handling.py rename to test/components/tools/openapi/test_openapi_client_error_handling.py index d826e33b..5b6e8dc4 100644 --- a/test/util/test_openapi_client_error_handling.py +++ b/test/components/tools/openapi/test_openapi_client_error_handling.py @@ -8,9 +8,9 @@ import pytest from fastapi import FastAPI, HTTPException -from haystack_experimental.util.openapi import OpenAPIServiceClient, HttpClientError, \ +from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, HttpClientError, \ ClientConfiguration -from test.util.conftest import FastAPITestClient +from test.components.tools.openapi.conftest import FastAPITestClient def create_error_handling_app() -> FastAPI: diff --git a/test/util/test_openapi_client_live.py b/test/components/tools/openapi/test_openapi_client_live.py similarity index 95% rename from test/util/test_openapi_client_live.py rename to test/components/tools/openapi/test_openapi_client_live.py index ec80b83b..1ee5b9f4 100644 --- a/test/util/test_openapi_client_live.py +++ b/test/components/tools/openapi/test_openapi_client_live.py @@ -7,7 +7,7 @@ import pytest import yaml -from haystack_experimental.util.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration class TestClientLive: diff --git a/test/util/test_openapi_client_live_anthropic.py b/test/components/tools/openapi/test_openapi_client_live_anthropic.py similarity index 95% rename from test/util/test_openapi_client_live_anthropic.py rename to test/components/tools/openapi/test_openapi_client_live_anthropic.py index 62c2a853..ede1503a 100644 --- a/test/util/test_openapi_client_live_anthropic.py +++ b/test/components/tools/openapi/test_openapi_client_live_anthropic.py @@ -7,7 +7,7 @@ import anthropic import pytest -from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient class TestClientLiveAnthropic: diff --git a/test/util/test_openapi_client_live_cohere.py b/test/components/tools/openapi/test_openapi_client_live_cohere.py similarity index 96% rename from test/util/test_openapi_client_live_cohere.py rename to test/components/tools/openapi/test_openapi_client_live_cohere.py index 4a316d83..4cd87631 100644 --- a/test/util/test_openapi_client_live_cohere.py +++ b/test/components/tools/openapi/test_openapi_client_live_cohere.py @@ -6,7 +6,7 @@ import cohere import pytest -from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient # Copied from Cohere's documentation preamble = """ diff --git a/test/util/test_openapi_client_live_openai.py b/test/components/tools/openapi/test_openapi_client_live_openai.py similarity index 97% rename from test/util/test_openapi_client_live_openai.py rename to test/components/tools/openapi/test_openapi_client_live_openai.py index cf723939..4f2e2c39 100644 --- a/test/util/test_openapi_client_live_openai.py +++ b/test/components/tools/openapi/test_openapi_client_live_openai.py @@ -7,7 +7,7 @@ import pytest from openai import OpenAI -from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient class TestClientLiveOpenAPI: diff --git a/test/util/test_openapi_cohere_conversion.py b/test/components/tools/openapi/test_openapi_cohere_conversion.py similarity index 97% rename from test/util/test_openapi_cohere_conversion.py rename to test/components/tools/openapi/test_openapi_cohere_conversion.py index 5afb1673..dd84b9b5 100644 --- a/test/util/test_openapi_cohere_conversion.py +++ b/test/components/tools/openapi/test_openapi_cohere_conversion.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from haystack_experimental.util.openapi import OpenAPISpecification, cohere_converter +from haystack_experimental.components.tools.openapi.openapi import OpenAPISpecification, cohere_converter class TestOpenAPISchemaConversion: diff --git a/test/util/test_openapi_openai_conversion.py b/test/components/tools/openapi/test_openapi_openai_conversion.py similarity index 97% rename from test/util/test_openapi_openai_conversion.py rename to test/components/tools/openapi/test_openapi_openai_conversion.py index fb3afa0e..3bf7cc19 100644 --- a/test/util/test_openapi_openai_conversion.py +++ b/test/components/tools/openapi/test_openapi_openai_conversion.py @@ -4,7 +4,7 @@ import pytest -from haystack_experimental.util.openapi import openai_converter, anthropic_converter, OpenAPISpecification +from haystack_experimental.components.tools.openapi.openapi import openai_converter, anthropic_converter, OpenAPISpecification class TestOpenAPISchemaConversion: diff --git a/test/util/test_openapi_spec.py b/test/components/tools/openapi/test_openapi_spec.py similarity index 97% rename from test/util/test_openapi_spec.py rename to test/components/tools/openapi/test_openapi_spec.py index 5567b2e7..93e2a972 100644 --- a/test/util/test_openapi_spec.py +++ b/test/components/tools/openapi/test_openapi_spec.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -import json - import pytest -from haystack_experimental.util.openapi import OpenAPISpecification +from haystack_experimental.components.tools.openapi.openapi import OpenAPISpecification class TestOpenAPISpecification: From 14d38e95ae6f55ba70b259d166afdb1d1dda6fad Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 17:39:08 +0200 Subject: [PATCH 09/40] Fix linting --- .../components/tools/openapi/generator_factory.py | 10 +++++----- .../components/tools/openapi/schema_conversion.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py index 78bb429f..a0d4c36e 100644 --- a/haystack_experimental/components/tools/openapi/generator_factory.py +++ b/haystack_experimental/components/tools/openapi/generator_factory.py @@ -6,7 +6,7 @@ import re from dataclasses import dataclass from enum import Enum -from typing import Any, Optional, Tuple +from typing import Any, Dict, Optional, Tuple class LLMProvider(Enum): @@ -18,7 +18,7 @@ class LLMProvider(Enum): COHERE = "cohere" -PROVIDER_DETAILS = { +PROVIDER_DETAILS: Dict[LLMProvider, Dict[str, Any]] = { LLMProvider.OPENAI: { "class_path": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "patterns": [re.compile(r"^gpt.*")], @@ -71,19 +71,19 @@ def create_generator( """ Create ChatGenerator instance based on the model name and provider. """ + provider_enum = None if provider: if provider.lower() not in LLMProvider.__members__: raise ValueError(f"Invalid provider: {provider}") provider_enum = LLMProvider[provider.lower()] else: - provider_enum = None for prov, details in PROVIDER_DETAILS.items(): if any(pattern.match(model_name) for pattern in details["patterns"]): provider_enum = prov break - if provider_enum is None: - raise ValueError(f"Could not infer provider for model name: {model_name}") + if provider_enum is None: + raise ValueError(f"Could not infer provider for model name: {model_name}") llm_identifier = LLMIdentifier(provider=provider_enum, model_name=model_name) class_path = PROVIDER_DETAILS[llm_identifier.provider]["class_path"] diff --git a/haystack_experimental/components/tools/openapi/schema_conversion.py b/haystack_experimental/components/tools/openapi/schema_conversion.py index 5b7e24e8..b74bc1cf 100644 --- a/haystack_experimental/components/tools/openapi/schema_conversion.py +++ b/haystack_experimental/components/tools/openapi/schema_conversion.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 +def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # type: ignore[name-defined] # noqa: F821 """ Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. @@ -29,7 +29,7 @@ def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # return [{"type": "function", "function": fn} for fn in fn_definitions] -def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 +def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # type: ignore # noqa: F821 """ Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. @@ -42,7 +42,7 @@ def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: ) -def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # noqa: F821 +def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # type: ignore[name-defined] # noqa: F821 """ Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. From 406aa1b1e39d1cb26d793b443cbc6fbfafbe0e03 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 21:25:00 +0200 Subject: [PATCH 10/40] Refactoring step 6 - simplify generator factory --- .../tools/openapi/generator_factory.py | 161 ++++++++++-------- .../components/tools/openapi/openapi_tool.py | 41 +++-- .../tools/openapi/payload_extraction.py | 6 +- 3 files changed, 119 insertions(+), 89 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py index a0d4c36e..0fb992ab 100644 --- a/haystack_experimental/components/tools/openapi/generator_factory.py +++ b/haystack_experimental/components/tools/openapi/generator_factory.py @@ -5,88 +5,101 @@ import importlib import re from dataclasses import dataclass -from enum import Enum -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple -class LLMProvider(Enum): - """ - Enum for different LLM providers - """ - OPENAI = "openai" - ANTHROPIC = "anthropic" - COHERE = "cohere" - - -PROVIDER_DETAILS: Dict[LLMProvider, Dict[str, Any]] = { - LLMProvider.OPENAI: { - "class_path": "haystack.components.generators.chat.openai.OpenAIChatGenerator", - "patterns": [re.compile(r"^gpt.*")], - }, - LLMProvider.ANTHROPIC: { - "class_path": "haystack_integrations.components.generators.anthropic.AnthropicChatGenerator", - "patterns": [re.compile(r"^claude.*")], - }, - LLMProvider.COHERE: { - "class_path": "haystack_integrations.components.generators.cohere.CohereChatGenerator", - "patterns": [re.compile(r"^command-r.*")], - }, -} - - -def load_class(full_class_path: str): +@dataclass +class ChatGeneratorDescriptor: """ - Load a class from a string representation of its path e.g. "module.submodule.class_name" + Dataclass to describe a Chat Generator """ - module_path, _, class_name = full_class_path.rpartition(".") - module = importlib.import_module(module_path) - return getattr(module, class_name) - -@dataclass -class LLMIdentifier: - """ Dataclass to hold the LLM provider and model name""" - provider: LLMProvider + class_path: str + patterns: List[re.Pattern] + name: str model_name: str - def __post_init__(self): - if not isinstance(self.provider, LLMProvider): - raise ValueError(f"Invalid provider: {self.provider}") - - if not isinstance(self.model_name, str): - raise ValueError(f"Model name must be a string: {self.model_name}") - details = PROVIDER_DETAILS.get(self.provider) - if not details or not any( - pattern.match(self.model_name) for pattern in details["patterns"] - ): - raise ValueError( - f"Invalid combination of provider {self.provider} and model name {self.model_name}" - ) - - -def create_generator( - model_name: str, provider: Optional[str] = None, **model_kwargs -) -> Tuple[LLMIdentifier, Any]: +class ChatGeneratorDescriptorManager: """ - Create ChatGenerator instance based on the model name and provider. + Class to manage Chat Generator Descriptors """ - provider_enum = None - if provider: - if provider.lower() not in LLMProvider.__members__: - raise ValueError(f"Invalid provider: {provider}") - provider_enum = LLMProvider[provider.lower()] - else: - for prov, details in PROVIDER_DETAILS.items(): - if any(pattern.match(model_name) for pattern in details["patterns"]): - provider_enum = prov - break - - if provider_enum is None: - raise ValueError(f"Could not infer provider for model name: {model_name}") - llm_identifier = LLMIdentifier(provider=provider_enum, model_name=model_name) - class_path = PROVIDER_DETAILS[llm_identifier.provider]["class_path"] - return llm_identifier, load_class(class_path)( - model=llm_identifier.model_name, **model_kwargs - ) + def __init__(self): + self._descriptors: Dict[str, ChatGeneratorDescriptor] = {} + self._register_default_descriptors() + + def _register_default_descriptors(self): + """ + Register default Chat Generator Descriptors. + """ + default_descriptors = [ + ChatGeneratorDescriptor( + class_path="haystack.components.generators.chat.openai.OpenAIChatGenerator", + patterns=[re.compile(r"^gpt.*")], + name="openai", + model_name="gpt-3.5-turbo", + ), + ChatGeneratorDescriptor( + class_path="haystack_integrations.components.generators.anthropic.AnthropicChatGenerator", + patterns=[re.compile(r"^claude.*")], + name="anthropic", + model_name="claude-1", + ), + ChatGeneratorDescriptor( + class_path="haystack_integrations.components.generators.cohere.CohereChatGenerator", + patterns=[re.compile(r"^command-r.*")], + name="cohere", + model_name="command-r", + ), + ] + + for descriptor in default_descriptors: + self.register_descriptor(descriptor) + + def _load_class(self, full_class_path: str): + """ + Load a class from a string representation of its path e.g. "module.submodule.class_name" + """ + module_path, _, class_name = full_class_path.rpartition(".") + module = importlib.import_module(module_path) + return getattr(module, class_name) + + def register_descriptor(self, descriptor: ChatGeneratorDescriptor): + """ + Register a new Chat Generator Descriptor. + """ + if descriptor.name in self._descriptors: + raise ValueError(f"Descriptor {descriptor.name} already exists.") + + self._descriptors[descriptor.name] = descriptor + + def _infer_descriptor(self, model_name: str) -> Optional[ChatGeneratorDescriptor]: + """ + Infer the descriptor based on the model name. + """ + for descriptor in self._descriptors.values(): + if any(pattern.match(model_name) for pattern in descriptor.patterns): + return descriptor + return None + + def create_generator( + self, model_name: str, descriptor_name: Optional[str] = None, **model_kwargs + ) -> Tuple[ChatGeneratorDescriptor, Any]: + """ + Create ChatGenerator instance based on the model name and descriptor. + """ + if descriptor_name: + descriptor = self._descriptors.get(descriptor_name) + if not descriptor: + raise ValueError(f"Invalid descriptor name: {descriptor_name}") + else: + descriptor = self._infer_descriptor(model_name) + if not descriptor: + raise ValueError( + f"Could not infer descriptor for model name: {model_name}" + ) + + return descriptor, self._load_class(descriptor.class_path)( + model=model_name, **model_kwargs + ) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 0b2e1ac7..dc232941 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -9,7 +9,9 @@ from haystack import component, logging from haystack.dataclasses import ChatMessage, ChatRole -from haystack_experimental.components.tools.openapi.generator_factory import create_generator +from haystack_experimental.components.tools.openapi.generator_factory import ( + ChatGeneratorDescriptorManager, +) from haystack_experimental.components.tools.openapi.openapi import ( ClientConfiguration, OpenAPIServiceClient, @@ -48,16 +50,20 @@ def __init__( tool_spec: Optional[Union[str, Path]] = None, tool_credentials: Optional[Union[str, Dict[str, Any]]] = None, ): - self.llm_id, self.chat_generator = create_generator(model) - self.config_openapi = ( - ClientConfiguration( + + manager = ChatGeneratorDescriptorManager() + self.descriptor, self.chat_generator = manager.create_generator( + model_name=model + ) + self.config_openapi: Optional[ClientConfiguration] = None + self.open_api_service: Optional[OpenAPIServiceClient] = None + if tool_spec: + self.config_openapi = ClientConfiguration( openapi_spec=tool_spec, credentials=tool_credentials, - llm_provider=self.llm_id.provider.value, + llm_provider=self.descriptor.name, ) - if tool_spec - else None - ) + self.open_api_service = OpenAPIServiceClient(self.config_openapi) @component.output_types(service_response=List[ChatMessage]) def run( @@ -88,17 +94,29 @@ def run( ClientConfiguration( openapi_spec=tool_spec, credentials=tool_credentials, - llm_provider=self.llm_id.provider.value, + llm_provider=self.descriptor.name, ) if tool_spec else self.config_openapi ) - if not config_openapi: + openapi_service: Optional[OpenAPIServiceClient] = self.open_api_service + if tool_spec: + config_openapi = ClientConfiguration( + openapi_spec=tool_spec, + credentials=tool_credentials, + llm_provider=self.descriptor.name, + ) + openapi_service = OpenAPIServiceClient(config_openapi) + else: + config_openapi = self.config_openapi + + if not openapi_service or not config_openapi: raise ValueError( "OpenAPI specification not provided. Please provide an OpenAPI specification either at initialization " "or during runtime." ) + # merge fc_generator_kwargs, tools definitions comes from the OpenAPI spec, other kwargs are passed by the user fc_generator_kwargs = { "tools": config_openapi.get_tools_definitions(), @@ -110,11 +128,10 @@ def run( f"Invoking chat generator with {last_message.content} to generate function calling payload." ) fc_payload = self.chat_generator.run(messages, fc_generator_kwargs) - - openapi_service = OpenAPIServiceClient(config_openapi) try: invocation_payload = json.loads(fc_payload["replies"][0].content) logger.debug(f"Invoking tool with {invocation_payload}") + # openapi_service is never None here, ignore mypy error service_response = openapi_service.invoke(invocation_payload) except Exception as e: # pylint: disable=broad-exception-caught logger.error(f"Error invoking OpenAPI endpoint. Error: {e}") diff --git a/haystack_experimental/components/tools/openapi/payload_extraction.py b/haystack_experimental/components/tools/openapi/payload_extraction.py index 157070ae..416841bf 100644 --- a/haystack_experimental/components/tools/openapi/payload_extraction.py +++ b/haystack_experimental/components/tools/openapi/payload_extraction.py @@ -27,9 +27,9 @@ def _extract_function_invocation(payload: Any) -> Dict[str, Any]: ) return { "name": fields_and_values.get("name"), - "arguments": json.loads(arguments) - if isinstance(arguments, str) - else arguments, + "arguments": ( + json.loads(arguments) if isinstance(arguments, str) else arguments + ), } return {} From e5db2d0d6a2a67b6c52b0686299e0d82e3a4f382 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 21:45:12 +0200 Subject: [PATCH 11/40] Cosmetics --- .../components/tools/openapi/openapi_tool.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index dc232941..1f28240c 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -32,7 +32,7 @@ class OpenAPITool: from haystack.components.generators.chat.openai import OpenAIChatGenerator from haystack.dataclasses import ChatMessage - tool = OpenAPITool(llm_provider="openai", model="gpt-3.5-turbo", + tool = OpenAPITool(model="gpt-3.5-turbo", tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json", tool_credentials="") @@ -40,7 +40,7 @@ class OpenAPITool: print(results) ``` - Similarly, you can use the OpenAPITool component to use any OpenAPI service/tool by providing the OpenAPI + Similarly, you can use the OpenAPITool component to invoke **any** OpenAPI service/tool by providing the OpenAPI specification and credentials. """ @@ -125,16 +125,16 @@ def run( # generate function calling payload with the chat generator logger.debug( - f"Invoking chat generator with {last_message.content} to generate function calling payload." + "Invoking chat generator with {message} to generate function calling payload.", + message=last_message.content, ) fc_payload = self.chat_generator.run(messages, fc_generator_kwargs) try: invocation_payload = json.loads(fc_payload["replies"][0].content) - logger.debug(f"Invoking tool with {invocation_payload}") - # openapi_service is never None here, ignore mypy error + logger.debug("Invoking tool with {payload}", payload=invocation_payload) service_response = openapi_service.invoke(invocation_payload) except Exception as e: # pylint: disable=broad-exception-caught - logger.error(f"Error invoking OpenAPI endpoint. Error: {e}") + logger.error("Error invoking OpenAPI endpoint. Error: {e}", e=str(e)) service_response = {"error": str(e)} response_messages = [ChatMessage.from_user(json.dumps(service_response))] return {"service_response": response_messages} From 4b35fdf7be946618f3487a66e0182936f2f0ba08 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 21:59:13 +0200 Subject: [PATCH 12/40] Add model_kwargs to OpenAPITool init --- .../components/tools/openapi/generator_factory.py | 2 +- .../components/tools/openapi/openapi_tool.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py index 0fb992ab..28401863 100644 --- a/haystack_experimental/components/tools/openapi/generator_factory.py +++ b/haystack_experimental/components/tools/openapi/generator_factory.py @@ -101,5 +101,5 @@ def create_generator( ) return descriptor, self._load_class(descriptor.class_path)( - model=model_name, **model_kwargs + model=model_name, **(model_kwargs or {}) ) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 1f28240c..010ff51c 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -49,11 +49,19 @@ def __init__( model: str, tool_spec: Optional[Union[str, Path]] = None, tool_credentials: Optional[Union[str, Dict[str, Any]]] = None, + model_kwargs: Optional[Dict[str, Any]] = None, ): + """ + Initialize the OpenAPITool component. + :param model: Name of the chat generator model to use. + :param tool_spec: OpenAPI specification for the tool/service. + :param tool_credentials: Credentials for the tool/service. + :param model_kwargs: Additional arguments for the chat generator model. + """ manager = ChatGeneratorDescriptorManager() self.descriptor, self.chat_generator = manager.create_generator( - model_name=model + model_name=model, **(model_kwargs or {}) ) self.config_openapi: Optional[ClientConfiguration] = None self.open_api_service: Optional[OpenAPIServiceClient] = None @@ -78,8 +86,8 @@ def run( :param messages: List of ChatMessages to generate function calling payload (e.g. human instructions). :param fc_generator_kwargs: Additional arguments for the function calling payload generation process. - :param tool_spec: OpenAPI specification for the tool/service. - :param tool_credentials: Credentials for the tool/service. + :param tool_spec: OpenAPI specification for the tool/service, overrides the one provided at initialization. + :param tool_credentials: Credentials for the tool/service, overrides the one provided at initialization. :returns: a dictionary containing the service response with the following key: - `service_response`: List of ChatMessages containing the service response. """ From 7bca18f095382ad006e08fef56e37bacfde77053 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 5 Jun 2024 22:30:14 +0200 Subject: [PATCH 13/40] Fix double ClientConfiguration creation --- .../components/tools/openapi/openapi_tool.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 010ff51c..b0ea10a2 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -97,18 +97,9 @@ def run( if not last_message.content: raise ValueError("Function calling instruction message content is empty.") - # build a new ClientConfiguration if a runtime tool_spec is provided - config_openapi = ( - ClientConfiguration( - openapi_spec=tool_spec, - credentials=tool_credentials, - llm_provider=self.descriptor.name, - ) - if tool_spec - else self.config_openapi - ) - + # build a new ClientConfiguration and OpenAPIServiceClient if a runtime tool_spec is provided openapi_service: Optional[OpenAPIServiceClient] = self.open_api_service + config_openapi: Optional[ClientConfiguration] = self.config_openapi if tool_spec: config_openapi = ClientConfiguration( openapi_spec=tool_spec, @@ -116,8 +107,6 @@ def run( llm_provider=self.descriptor.name, ) openapi_service = OpenAPIServiceClient(config_openapi) - else: - config_openapi = self.config_openapi if not openapi_service or not config_openapi: raise ValueError( From 72922a0784ae09091c7f86b1de679eeb2d8d51f1 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 06:36:50 +0200 Subject: [PATCH 14/40] Update internal pydoc --- .../components/tools/openapi/openapi.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi.py b/haystack_experimental/components/tools/openapi/openapi.py index 3764d37e..f8f60bdb 100644 --- a/haystack_experimental/components/tools/openapi/openapi.py +++ b/haystack_experimental/components/tools/openapi/openapi.py @@ -401,9 +401,6 @@ def is_valid_http_url(self, url: str) -> bool: class OpenAPIServiceClient: """ A client for invoking operations on REST services defined by OpenAPI specifications. - - Together with the `ClientConfiguration`, its `ClientConfigurationBuilder`, the `OpenAPIServiceClient` - simplifies the process of (LLMs) with services defined by OpenAPI specifications. """ def __init__(self, client_config: ClientConfiguration): @@ -426,15 +423,9 @@ def invoke(self, function_payload: Any) -> Any: f"Failed to extract function invocation payload from {function_payload}" ) # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on - operation = self.client_config.openapi_spec.find_operation_by_id( - fn_invocation_payload.get("name") - ) - request = self._build_request( - operation, **fn_invocation_payload.get("arguments") - ) - self._apply_authentication( - self.client_config.get_auth_config(), operation, request - ) + operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) + request = self._build_request(operation, **fn_invocation_payload.get("arguments")) + self._apply_authentication(self.client_config.get_auth_config(), operation, request) return self.request_sender(request) def _request_sender(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: From d996cc5af3742bb70a0192ad0617c6edcb17465f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 17:31:54 +0200 Subject: [PATCH 15/40] PR feedback --- .../tools/openapi/{openapi.py => _openapi.py} | 469 +++++++++--------- ...d_extraction.py => _payload_extraction.py} | 5 + ...ma_conversion.py => _schema_conversion.py} | 35 +- .../tools/openapi/generator_factory.py | 105 ---- .../components/tools/openapi/openapi_tool.py | 53 +- test/components/tools/openapi/conftest.py | 2 +- .../tools/openapi/test_openapi_client.py | 2 +- .../tools/openapi/test_openapi_client_auth.py | 43 +- ...est_openapi_client_complex_request_body.py | 2 +- ...enapi_client_complex_request_body_mixed.py | 2 +- .../openapi/test_openapi_client_edge_cases.py | 2 +- .../test_openapi_client_error_handling.py | 2 +- .../tools/openapi/test_openapi_client_live.py | 22 +- .../test_openapi_client_live_anthropic.py | 11 +- .../test_openapi_client_live_cohere.py | 7 +- .../test_openapi_client_live_openai.py | 2 +- .../openapi/test_openapi_cohere_conversion.py | 2 +- .../openapi/test_openapi_openai_conversion.py | 2 +- .../tools/openapi/test_openapi_spec.py | 17 +- 19 files changed, 325 insertions(+), 460 deletions(-) rename haystack_experimental/components/tools/openapi/{openapi.py => _openapi.py} (50%) rename haystack_experimental/components/tools/openapi/{payload_extraction.py => _payload_extraction.py} (90%) rename haystack_experimental/components/tools/openapi/{schema_conversion.py => _schema_conversion.py} (85%) delete mode 100644 haystack_experimental/components/tools/openapi/generator_factory.py diff --git a/haystack_experimental/components/tools/openapi/openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py similarity index 50% rename from haystack_experimental/components/tools/openapi/openapi.py rename to haystack_experimental/components/tools/openapi/_openapi.py index f8f60bdb..e171c35d 100644 --- a/haystack_experimental/components/tools/openapi/openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -5,8 +5,8 @@ import json import logging import os -from base64 import b64encode from dataclasses import dataclass, field +from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Union from urllib.parse import urlparse @@ -14,10 +14,10 @@ import requests import yaml -from haystack_experimental.components.tools.openapi.payload_extraction import ( +from haystack_experimental.components.tools.openapi._payload_extraction import ( create_function_payload_extractor, ) -from haystack_experimental.components.tools.openapi.schema_conversion import ( +from haystack_experimental.components.tools.openapi._schema_conversion import ( anthropic_converter, cohere_converter, openai_converter, @@ -37,27 +37,69 @@ logger = logging.getLogger(__name__) -class AuthenticationStrategy: +class LLMProvider(Enum): """ - Represents an authentication strategy that can be applied to an HTTP request. + Enum for the supported LLM providers. """ + OPENAI = "openai" + ANTHROPIC = "anthropic" + COHERE = "cohere" - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): - """ - Apply the authentication strategy to the given request. - :param security_scheme: the security scheme from the OpenAPI spec. - :param request: the request to apply the authentication to. - """ +def is_valid_http_url(url: str) -> bool: + """ + Check if a URL is a valid HTTP/HTTPS URL. + + :param url: The URL to check. + :return: True if the URL is a valid HTTP/HTTPS URL, False otherwise. + """ + r = urlparse(url) + return all([r.scheme in ["http", "https"], r.netloc]) -@dataclass -class ApiKeyAuthentication(AuthenticationStrategy): - """API key authentication strategy.""" +def send_request(request: Dict[str, Any]) -> Dict[str, Any]: + """ + Send an HTTP request and return the response. - api_key: Optional[str] = None + :param request: The request to send. + :return: The response from the server. + """ + url = request["url"] + headers = {**request.get("headers", {})} + try: + response = requests.request( + request["method"], + url, + headers=headers, + params=request.get("params", {}), + json=request.get("json"), + auth=request.get("auth"), + timeout=10, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + logger.warning("HTTP error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except requests.exceptions.RequestException as e: + logger.warning("Request error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except Exception as e: + logger.warning("An error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"An error occurred: {e}") from e + + +# Authentication strategies +def create_api_key_auth_function(api_key: str): + """ + Create a function that applies the API key authentication strategy to a given request. - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + :param api_key: the API key to use for authentication. + :return: a function that applies the API key authentication to a request + at the schema specified location. + """ + + def apply_api_key_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]): """ Apply the API key authentication strategy to the given request. @@ -65,67 +107,18 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): :param request: the request to apply the authentication to. """ if security_scheme["in"] == "header": - request.setdefault("headers", {})[security_scheme["name"]] = self.api_key + request.setdefault("headers", {})[security_scheme["name"]] = api_key elif security_scheme["in"] == "query": - request.setdefault("params", {})[security_scheme["name"]] = self.api_key + request.setdefault("params", {})[security_scheme["name"]] = api_key elif security_scheme["in"] == "cookie": - request.setdefault("cookies", {})[security_scheme["name"]] = self.api_key + request.setdefault("cookies", {})[security_scheme["name"]] = api_key else: raise ValueError( f"Unsupported apiKey authentication location: {security_scheme['in']}, " f"must be one of 'header', 'query', or 'cookie'" ) - -@dataclass -class HTTPAuthentication(AuthenticationStrategy): - """HTTP authentication strategy.""" - - username: Optional[str] = None - password: Optional[str] = None - token: Optional[str] = None - - def __post_init__(self): - if not self.token and (not self.username or not self.password): - raise ValueError( - "For HTTP Basic Auth, both username and password must be provided. " - "For Bearer Auth, a token must be provided." - ) - - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): - """ - Apply the HTTP authentication strategy to the given request. - - :param security_scheme: the security scheme from the OpenAPI spec. - :param request: the request to apply the authentication to. - """ - if security_scheme["type"] == "http": - if security_scheme["scheme"].lower() == "basic": - if not self.username or not self.password: - raise ValueError( - "Username and password must be provided for Basic Auth." - ) - credentials = f"{self.username}:{self.password}" - encoded_credentials = b64encode(credentials.encode("utf-8")).decode( - "utf-8" - ) - request.setdefault("headers", {})[ - "Authorization" - ] = f"Basic {encoded_credentials}" - elif security_scheme["scheme"].lower() == "bearer": - if not self.token: - raise ValueError("Token must be provided for Bearer Auth.") - request.setdefault("headers", {})[ - "Authorization" - ] = f"Bearer {self.token}" - else: - raise ValueError( - f"Unsupported HTTP authentication scheme: {security_scheme['scheme']}" - ) - else: - raise ValueError( - "HTTPAuthentication strategy received a non-HTTP security scheme." - ) + return apply_api_key_auth class HttpClientError(Exception): @@ -134,7 +127,21 @@ class HttpClientError(Exception): @dataclass class Operation: - """Represents an operation in an OpenAPI specification.""" + """ + Represents an operation in an OpenAPI specification + + See https://spec.openapis.org/oas/latest.html#paths-object for details. + Path objects can contain multiple operations, each with a unique combination of path and method. + + Attributes: + path (str): Path of the operation. + method (str): HTTP method of the operation. + operation_dict (Dict[str, Any]): Operation details from OpenAPI spec. + spec_dict (Dict[str, Any]): The encompassing OpenAPI specification. + security_requirements (List[Dict[str, List[str]]]): Security requirements for the operation. + request_body (Dict[str, Any]): Request body details. + parameters (List[Dict[str, Any]]): Parameters for the operation. + """ path: str method: str @@ -161,23 +168,41 @@ def get_parameters( ) -> List[Dict[str, Any]]: """ Get the parameters for the operation. + + :param location: The location of the parameters to get. + :return: The parameters for the operation as a list of dictionaries. """ if location: return [param for param in self.parameters if param["in"] == location] return self.parameters - def get_server(self) -> str: + def get_server(self, server_index: int = 0) -> str: """ Get the servers for the operation. + + :param server_index: The index of the server to use. + :return: The server URL. + :raises ValueError: If no servers are found in the specification. """ servers = self.operation_dict.get("servers", []) or self.spec_dict.get( "servers", [] ) - return servers[0].get("url", "") # just use the first server from the list + if not servers: + raise ValueError("No servers found in the provided specification.") + if server_index >= len(servers): + raise ValueError( + f"Server index {server_index} is out of bounds. " + f"Only {len(servers)} servers found." + ) + return servers[server_index].get( + "url" + ) # just use the first server from the list class OpenAPISpecification: - """Represents an OpenAPI specification.""" + """ + Represents an OpenAPI specification. See https://spec.openapis.org/oas/latest.html for details. + """ def __init__(self, spec_dict: Dict[str, Any]): if not isinstance(spec_dict, Dict): @@ -196,18 +221,13 @@ def __init__(self, spec_dict: Dict[str, Any]): ) self.spec_dict = spec_dict - @classmethod - def from_dict(cls, spec_dict: Dict[str, Any]) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a dictionary. - """ - parser = cls(spec_dict) - return parser - @classmethod def from_str(cls, content: str) -> "OpenAPISpecification": """ Create an OpenAPISpecification instance from a string. + + :param content: The string content of the OpenAPI specification. + :return: The OpenAPISpecification instance. """ try: loaded_spec = json.loads(content) @@ -224,6 +244,9 @@ def from_str(cls, content: str) -> "OpenAPISpecification": def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": """ Create an OpenAPISpecification instance from a file. + + :param spec_file: The file path to the OpenAPI specification. + :return: The OpenAPISpecification instance. """ with open(spec_file, encoding="utf-8") as file: content = file.read() @@ -233,6 +256,9 @@ def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": def from_url(cls, url: str) -> "OpenAPISpecification": """ Create an OpenAPISpecification instance from a URL. + + :param url: The URL to fetch the OpenAPI specification from. + :return: The OpenAPISpecification instance. """ try: response = requests.get(url, timeout=10) @@ -248,23 +274,31 @@ def find_operation_by_id( self, op_id: str, method: Optional[str] = None ) -> Operation: """ - Find an operation by operationId. + Find an Operation by operationId. + + :param op_id: The operationId of the operation. + :param method: The HTTP method of the operation. + :return: The matching operation + :raises ValueError: If no operation is found with the given operationId. """ for path, path_item in self.spec_dict.get("paths", {}).items(): op: Operation = self.get_operation_item(path, path_item, method) if op_id in op.operation_dict.get("operationId", ""): return self.get_operation_item(path, path_item, method) - raise ValueError(f"No operation found with operationId {op_id}") + raise ValueError( + f"No operation found with operationId {op_id}, method {method}" + ) def get_operation_item( self, path: str, path_item: Dict[str, Any], method: Optional[str] = None ) -> Operation: """ - Get an operation item from the OpenAPI specification. + Gets a particular Operation item from the OpenAPI specification given the path and method. :param path: The path of the operation. :param path_item: The path item from the OpenAPI specification. :param method: The HTTP method of the operation. + :return: The operation """ if method: operation_dict = path_item.get(method.lower(), {}) @@ -280,11 +314,13 @@ def get_operation_item( raise ValueError( f"Multiple operations found at path {path}, method parameter is required." ) - raise ValueError(f"No operations found at path {path}.") + raise ValueError(f"No operations found at path {path} and method {method}") def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: """ Get the security schemes from the OpenAPI specification. + + :return: The security schemes as a dictionary. """ return self.spec_dict.get("components", {}).get("securitySchemes", {}) @@ -294,19 +330,15 @@ class ClientConfiguration: def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments self, - openapi_spec: Union[str, Path, Dict[str, Any]], - credentials: Optional[ - Union[str, Dict[str, Any], AuthenticationStrategy] - ] = None, + openapi_spec: Union[str, Path], + credentials: Optional[str] = None, request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, - llm_provider: Optional[str] = None, + llm_provider: Optional[LLMProvider] = None, ): # noqa: PLR0913 if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) - elif isinstance(openapi_spec, dict): - self.openapi_spec = OpenAPISpecification.from_dict(openapi_spec) elif isinstance(openapi_spec, str): - if self.is_valid_http_url(openapi_spec): + if is_valid_http_url(openapi_spec): self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) else: self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) @@ -316,40 +348,46 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments ) self.credentials = credentials - self.request_sender = request_sender - self.llm_provider = llm_provider or "openai" + self.request_sender = request_sender or send_request + self.llm_provider: LLMProvider = llm_provider or LLMProvider.OPENAI - def get_auth_config(self) -> AuthenticationStrategy: + def get_auth_function(self) -> Callable[[dict[str, Any], dict[str, Any]], Any]: """ - Get the authentication configuration. + Get the authentication function that sets a schema specified authentication to the request. + + The function takes a security scheme and a request as arguments: + `security_scheme: Dict[str, Any] - The security scheme from the OpenAPI spec.` + `request: Dict[str, Any] - The request to apply the authentication to.` + :return: The authentication function. """ - if not self.credentials: - return AuthenticationStrategy() - if isinstance(self.credentials, AuthenticationStrategy): - return self.credentials security_schemes = self.openapi_spec.get_security_schemes() + if not self.credentials: + return lambda security_scheme, request: None # No-op function if isinstance(self.credentials, str): return self._create_authentication_from_string( self.credentials, security_schemes ) - if isinstance(self.credentials, dict): - return self._create_authentication_from_dict(self.credentials) raise ValueError(f"Unsupported credentials type: {type(self.credentials)}") def get_tools_definitions(self) -> List[Dict[str, Any]]: """ Get the tools definitions used as tools LLM parameter. + + :return: The tools definitions passed to the LLM as tools parameter. """ provider_to_converter = { "anthropic": anthropic_converter, "cohere": cohere_converter, } - converter = provider_to_converter.get(self.llm_provider, openai_converter) + converter = provider_to_converter.get(self.llm_provider.value, openai_converter) return converter(self.openapi_spec) def get_payload_extractor(self): """ Get the payload extractor for the LLM provider. + + This function knows how to extract the exact function payload from the LLM generated function calling payload. + :return: The payload extractor function. """ provider_to_arguments_field_name = { "anthropic": "input", @@ -357,45 +395,102 @@ def get_payload_extractor(self): } # add more providers here # default to OpenAI "arguments" arguments_field_name = provider_to_arguments_field_name.get( - self.llm_provider, "arguments" + self.llm_provider.value, "arguments" ) return create_function_payload_extractor(arguments_field_name) def _create_authentication_from_string( self, credentials: str, security_schemes: Dict[str, Any] - ) -> AuthenticationStrategy: + ) -> Callable[[dict[str, Any], dict[str, Any]], Any]: for scheme in security_schemes.values(): if scheme["type"] == "apiKey": - return ApiKeyAuthentication(api_key=credentials) + return create_api_key_auth_function(api_key=credentials) if scheme["type"] == "http": - return HTTPAuthentication(token=credentials) + raise NotImplementedError("HTTP authentication is not yet supported.") if scheme["type"] == "oauth2": raise NotImplementedError("OAuth2 authentication is not yet supported.") raise ValueError( f"Unable to create authentication from provided credentials: {credentials}" ) - def _create_authentication_from_dict( - self, credentials: Dict[str, Any] - ) -> AuthenticationStrategy: - if "username" in credentials and "password" in credentials: - return HTTPAuthentication( - username=credentials["username"], password=credentials["password"] - ) - if "api_key" in credentials: - return ApiKeyAuthentication(api_key=credentials["api_key"]) - if "token" in credentials: - return HTTPAuthentication(token=credentials["token"]) - if "access_token" in credentials: - raise NotImplementedError("OAuth2 authentication is not yet supported.") - raise ValueError( - "Unable to create authentication from provided credentials: {credentials}" - ) - def is_valid_http_url(self, url: str) -> bool: - """Check if a URL is a valid HTTP/HTTPS URL.""" - r = urlparse(url) - return all([r.scheme in ["http", "https"], r.netloc]) +def build_request(operation: Operation, **kwargs) -> Dict[str, Any]: + """ + Build an HTTP request for the operation. + + :param operation: The operation to build the request for. + :param kwargs: The arguments to use for building the request. + :return: The HTTP request as a dictionary. + """ + path = operation.path + for parameter in operation.get_parameters("path"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + path = path.replace(f"{{{parameter['name']}}}", str(param_value)) + elif parameter.get("required", False): + raise ValueError(f"Missing required path parameter: {parameter['name']}") + url = operation.get_server() + path + # method + method = operation.method.lower() + # headers + headers = {} + for parameter in operation.get_parameters("header"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + headers[parameter["name"]] = str(param_value) + elif parameter.get("required", False): + raise ValueError(f"Missing required header parameter: {parameter['name']}") + # query params + query_params = {} + for parameter in operation.get_parameters("query"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + query_params[parameter["name"]] = param_value + elif parameter.get("required", False): + raise ValueError(f"Missing required query parameter: {parameter['name']}") + + json_payload = None + request_body = operation.request_body + if request_body: + content = request_body.get("content", {}) + if "application/json" in content: + json_payload = {**kwargs} + else: + raise NotImplementedError("Request body content type not supported") + return { + "url": url, + "method": method, + "headers": headers, + "params": query_params, + "json": json_payload, + } + + +def apply_authentication( + auth_strategy: Callable[[Dict[str, Any], Dict[str, Any]], Any], + operation: Operation, + request: Dict[str, Any], +): + """ + Apply the authentication strategy to the given request. + + :param auth_strategy: The authentication strategy to apply. + This is a function that takes a security scheme and a request as arguments (at runtime) + and applies the authentication + :param operation: The operation to apply the authentication to. + :param request: The request to apply the authentication to. + """ + security_requirements = operation.security_requirements + security_schemes = operation.spec_dict.get("components", {}).get( + "securitySchemes", {} + ) + if security_requirements: + for requirement in security_requirements: + for scheme_name in requirement: + if scheme_name in security_schemes: + security_scheme = security_schemes[scheme_name] + auth_strategy(security_scheme, request) + break class OpenAPIServiceClient: @@ -405,7 +500,7 @@ class OpenAPIServiceClient: def __init__(self, client_config: ClientConfiguration): self.client_config = client_config - self.request_sender = client_config.request_sender or self._request_sender() + self.request_sender = client_config.request_sender def invoke(self, function_payload: Any) -> Any: """ @@ -423,118 +518,12 @@ def invoke(self, function_payload: Any) -> Any: f"Failed to extract function invocation payload from {function_payload}" ) # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on - operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) - request = self._build_request(operation, **fn_invocation_payload.get("arguments")) - self._apply_authentication(self.client_config.get_auth_config(), operation, request) - return self.request_sender(request) - - def _request_sender(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: - """ - Returns a callable that sends the request using the HTTP client. - """ - - def send_request(request: Dict[str, Any]) -> Dict[str, Any]: - url = request["url"] - headers = {**request.get("headers", {})} - try: - response = requests.request( - request["method"], - url, - headers=headers, - params=request.get("params", {}), - json=request.get("json"), - auth=request.get("auth"), - timeout=10, - ) - response.raise_for_status() - return response.json() - except requests.exceptions.HTTPError as e: - logger.warning( - "HTTP error occurred: %s while sending request to %s", e, url - ) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except requests.exceptions.RequestException as e: - logger.warning( - "Request error occurred: %s while sending request to %s", e, url - ) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except Exception as e: - logger.warning( - "An error occurred: %s while sending request to %s", e, url - ) - raise HttpClientError(f"An error occurred: {e}") from e - - return send_request - - def _build_request(self, operation: Operation, **kwargs) -> Any: - # url - path = operation.path - for parameter in operation.get_parameters("path"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - path = path.replace(f"{{{parameter['name']}}}", str(param_value)) - elif parameter.get("required", False): - raise ValueError( - f"Missing required path parameter: {parameter['name']}" - ) - url = operation.get_server() + path - # method - method = operation.method.lower() - # headers - headers = {} - for parameter in operation.get_parameters("header"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - headers[parameter["name"]] = str(param_value) - elif parameter.get("required", False): - raise ValueError( - f"Missing required header parameter: {parameter['name']}" - ) - # query params - query_params = {} - for parameter in operation.get_parameters("query"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - query_params[parameter["name"]] = param_value - elif parameter.get("required", False): - raise ValueError( - f"Missing required query parameter: {parameter['name']}" - ) - - json_payload = None - request_body = operation.request_body - if request_body: - content = request_body.get("content", {}) - if "application/json" in content: - json_payload = {**kwargs} - else: - raise NotImplementedError("Request body content type not supported") - return { - "url": url, - "method": method, - "headers": headers, - "params": query_params, - "json": json_payload, - } - - def _apply_authentication( - self, - auth: AuthenticationStrategy, - operation: Operation, - request: Dict[str, Any], - ): - auth_config = auth or AuthenticationStrategy() - security_requirements = operation.security_requirements - security_schemes = operation.spec_dict.get("components", {}).get( - "securitySchemes", {} + operation = self.client_config.openapi_spec.find_operation_by_id( + fn_invocation_payload.get("name") ) - if security_requirements: - for requirement in security_requirements: - for scheme_name in requirement: - if scheme_name in security_schemes: - security_scheme = security_schemes[scheme_name] - auth_config.apply_auth(security_scheme, request) - break + request = build_request(operation, **fn_invocation_payload.get("arguments")) + apply_authentication(self.client_config.get_auth_function(), operation, request) + return self.request_sender(request) class OpenAPIClientError(Exception): diff --git a/haystack_experimental/components/tools/openapi/payload_extraction.py b/haystack_experimental/components/tools/openapi/_payload_extraction.py similarity index 90% rename from haystack_experimental/components/tools/openapi/payload_extraction.py rename to haystack_experimental/components/tools/openapi/_payload_extraction.py index 416841bf..6247c56a 100644 --- a/haystack_experimental/components/tools/openapi/payload_extraction.py +++ b/haystack_experimental/components/tools/openapi/_payload_extraction.py @@ -12,11 +12,16 @@ def create_function_payload_extractor( ) -> Callable[[Any], Dict[str, Any]]: """ Extracts invocation payload from a given LLM completion containing function invocation. + + :param arguments_field_name: The name of the field containing the function arguments. + :return: A function that extracts the function invocation details from the LLM payload. """ def _extract_function_invocation(payload: Any) -> Dict[str, Any]: """ Extract the function invocation details from the payload. + + :param payload: The LLM fc payload to extract the function invocation details from. """ fields_and_values = _search(payload, arguments_field_name) if fields_and_values: diff --git a/haystack_experimental/components/tools/openapi/schema_conversion.py b/haystack_experimental/components/tools/openapi/_schema_conversion.py similarity index 85% rename from haystack_experimental/components/tools/openapi/schema_conversion.py rename to haystack_experimental/components/tools/openapi/_schema_conversion.py index b74bc1cf..abc23c9c 100644 --- a/haystack_experimental/components/tools/openapi/schema_conversion.py +++ b/haystack_experimental/components/tools/openapi/_schema_conversion.py @@ -5,9 +5,6 @@ import logging from typing import Any, Callable, Dict, List, Optional -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 import jsonref MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 @@ -19,8 +16,9 @@ def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # """ Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. + See https://platform.openai.com/docs/guides/function-calling for more information about OpenAI's function schema. :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. + :return: A list of dictionaries, each dictionary representing an OpenAI function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) fn_definitions = _openapi_to_functions( @@ -33,8 +31,10 @@ def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: """ Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. + See https://docs.anthropic.com/en/docs/tool-use for more information about Anthropic's function schema. + :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. + :return: A list of dictionaries, each dictionary representing Anthropic function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) return _openapi_to_functions( @@ -46,8 +46,10 @@ def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # """ Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. + See https://docs.cohere.com/docs/tool-use for more information about Cohere's function schema. + :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. + :return: A list of dictionaries, each representing a Cohere style function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) return _openapi_to_functions( @@ -62,6 +64,10 @@ def _openapi_to_functions( ) -> List[Dict[str, Any]]: """ Extracts functions from the OpenAPI specification, converts them into a function schema. + + :param service_openapi_spec: The OpenAPI specification to extract functions from. + :param parameters_name: The name of the parameters field in the function schema. + :param parse_endpoint_fn: The function to parse the endpoint specification. """ # Doesn't enforce rigid spec validation because that would require a lot of dependencies @@ -92,6 +98,9 @@ def _parse_endpoint_spec_openai( ) -> Dict[str, Any]: """ Parses an OpenAPI endpoint specification for OpenAI. + + :param resolved_spec: The resolved OpenAPI specification. + :param parameters_name: The name of the parameters field in the function schema. """ if not isinstance(resolved_spec, dict): logger.warning( @@ -145,6 +154,9 @@ def _parse_property_attributes( ) -> Dict[str, Any]: """ Recursively parses the attributes of a property schema. + + :param property_schema: The property schema to parse. + :param include_attributes: The attributes to include in the parsed schema. """ include_attributes = include_attributes or ["description", "pattern", "enum"] schema_type = property_schema.get("type") @@ -172,6 +184,9 @@ def _parse_endpoint_spec_cohere( ) -> Dict[str, Any]: """ Parses an endpoint specification for Cohere. + + :param operation: The operation specification to parse. + :param ignored_param: ignored, left for compatibility with the OpenAI converter. """ function_name = operation.get("operationId") description = operation.get("description") or operation.get("summary", "") @@ -189,6 +204,9 @@ def _parse_endpoint_spec_cohere( def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: """ Parses the parameters from an operation specification. + + :param operation: The operation specification to parse. + :return: A dictionary containing the parsed parameters. """ parameters = {} for param in operation.get("parameters", []): @@ -217,6 +235,11 @@ def _parse_schema( ) -> Dict[str, Any]: # noqa: FBT001 """ Parses a schema part of an operation specification. + + :param schema: The schema to parse. + :param required: Whether the schema is required. + :param description: The description of the schema. + :return: A dictionary containing the parsed schema. """ schema_type = _get_type(schema) if schema_type == "object": diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py deleted file mode 100644 index 28401863..00000000 --- a/haystack_experimental/components/tools/openapi/generator_factory.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import importlib -import re -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple - - -@dataclass -class ChatGeneratorDescriptor: - """ - Dataclass to describe a Chat Generator - """ - - class_path: str - patterns: List[re.Pattern] - name: str - model_name: str - - -class ChatGeneratorDescriptorManager: - """ - Class to manage Chat Generator Descriptors - """ - - def __init__(self): - self._descriptors: Dict[str, ChatGeneratorDescriptor] = {} - self._register_default_descriptors() - - def _register_default_descriptors(self): - """ - Register default Chat Generator Descriptors. - """ - default_descriptors = [ - ChatGeneratorDescriptor( - class_path="haystack.components.generators.chat.openai.OpenAIChatGenerator", - patterns=[re.compile(r"^gpt.*")], - name="openai", - model_name="gpt-3.5-turbo", - ), - ChatGeneratorDescriptor( - class_path="haystack_integrations.components.generators.anthropic.AnthropicChatGenerator", - patterns=[re.compile(r"^claude.*")], - name="anthropic", - model_name="claude-1", - ), - ChatGeneratorDescriptor( - class_path="haystack_integrations.components.generators.cohere.CohereChatGenerator", - patterns=[re.compile(r"^command-r.*")], - name="cohere", - model_name="command-r", - ), - ] - - for descriptor in default_descriptors: - self.register_descriptor(descriptor) - - def _load_class(self, full_class_path: str): - """ - Load a class from a string representation of its path e.g. "module.submodule.class_name" - """ - module_path, _, class_name = full_class_path.rpartition(".") - module = importlib.import_module(module_path) - return getattr(module, class_name) - - def register_descriptor(self, descriptor: ChatGeneratorDescriptor): - """ - Register a new Chat Generator Descriptor. - """ - if descriptor.name in self._descriptors: - raise ValueError(f"Descriptor {descriptor.name} already exists.") - - self._descriptors[descriptor.name] = descriptor - - def _infer_descriptor(self, model_name: str) -> Optional[ChatGeneratorDescriptor]: - """ - Infer the descriptor based on the model name. - """ - for descriptor in self._descriptors.values(): - if any(pattern.match(model_name) for pattern in descriptor.patterns): - return descriptor - return None - - def create_generator( - self, model_name: str, descriptor_name: Optional[str] = None, **model_kwargs - ) -> Tuple[ChatGeneratorDescriptor, Any]: - """ - Create ChatGenerator instance based on the model name and descriptor. - """ - if descriptor_name: - descriptor = self._descriptors.get(descriptor_name) - if not descriptor: - raise ValueError(f"Invalid descriptor name: {descriptor_name}") - else: - descriptor = self._infer_descriptor(model_name) - if not descriptor: - raise ValueError( - f"Could not infer descriptor for model name: {model_name}" - ) - - return descriptor, self._load_class(descriptor.class_path)( - model=model_name, **(model_kwargs or {}) - ) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index b0ea10a2..cb26cc5e 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -7,13 +7,13 @@ from typing import Any, Dict, List, Optional, Union from haystack import component, logging +from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole +from haystack.lazy_imports import LazyImport -from haystack_experimental.components.tools.openapi.generator_factory import ( - ChatGeneratorDescriptorManager, -) -from haystack_experimental.components.tools.openapi.openapi import ( +from haystack_experimental.components.tools.openapi._openapi import ( ClientConfiguration, + LLMProvider, OpenAPIServiceClient, ) @@ -46,30 +46,28 @@ class OpenAPITool: def __init__( self, - model: str, + generator_api: LLMProvider, + generator_api_params: Dict[str, Any], tool_spec: Optional[Union[str, Path]] = None, - tool_credentials: Optional[Union[str, Dict[str, Any]]] = None, - model_kwargs: Optional[Dict[str, Any]] = None, + tool_credentials: Optional[str] = None, ): """ Initialize the OpenAPITool component. - :param model: Name of the chat generator model to use. + :param generator_api: The API provider for the chat generator. + :param generator_api_params: Parameters for the chat generator. :param tool_spec: OpenAPI specification for the tool/service. :param tool_credentials: Credentials for the tool/service. - :param model_kwargs: Additional arguments for the chat generator model. """ - manager = ChatGeneratorDescriptorManager() - self.descriptor, self.chat_generator = manager.create_generator( - model_name=model, **(model_kwargs or {}) - ) + self.generator_api = generator_api + self.chat_generator = self._init_generator(generator_api, generator_api_params) self.config_openapi: Optional[ClientConfiguration] = None self.open_api_service: Optional[OpenAPIServiceClient] = None if tool_spec: self.config_openapi = ClientConfiguration( openapi_spec=tool_spec, credentials=tool_credentials, - llm_provider=self.descriptor.name, + llm_provider=generator_api ) self.open_api_service = OpenAPIServiceClient(self.config_openapi) @@ -78,8 +76,8 @@ def run( self, messages: List[ChatMessage], fc_generator_kwargs: Optional[Dict[str, Any]] = None, - tool_spec: Optional[Union[str, Path, Dict[str, Any]]] = None, - tool_credentials: Optional[Union[dict, str]] = None, + tool_spec: Optional[Union[str, Path]] = None, + tool_credentials: Optional[str] = None, ) -> Dict[str, List[ChatMessage]]: """ Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator. @@ -104,7 +102,7 @@ def run( config_openapi = ClientConfiguration( openapi_spec=tool_spec, credentials=tool_credentials, - llm_provider=self.descriptor.name, + llm_provider=self.generator_api, ) openapi_service = OpenAPIServiceClient(config_openapi) @@ -134,4 +132,25 @@ def run( logger.error("Error invoking OpenAPI endpoint. Error: {e}", e=str(e)) service_response = {"error": str(e)} response_messages = [ChatMessage.from_user(json.dumps(service_response))] + return {"service_response": response_messages} + + def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict[str, Any]): + """ + Initialize the chat generator based on the specified API provider and parameters. + """ + if generator_api == LLMProvider.OPENAI: + return OpenAIChatGenerator(**generator_api_params) + if generator_api == LLMProvider.ANTHROPIC: + with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import: + anthropic_import.check() + # pylint: disable=import-outside-toplevel + from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator + return AnthropicChatGenerator(**generator_api_params) + if generator_api == LLMProvider.COHERE: + with LazyImport("Run 'pip install cohere-haystack'") as cohere_import: + cohere_import.check() + # pylint: disable=import-outside-toplevel + from haystack_integrations.components.generators.cohere import CohereChatGenerator + return CohereChatGenerator(**generator_api_params) + raise ValueError(f"Unsupported generator API: {generator_api}") diff --git a/test/components/tools/openapi/conftest.py b/test/components/tools/openapi/conftest.py index 2df4f76c..89ec74d4 100644 --- a/test/components/tools/openapi/conftest.py +++ b/test/components/tools/openapi/conftest.py @@ -10,7 +10,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from haystack_experimental.components.tools.openapi.openapi import HttpClientError +from haystack_experimental.components.tools.openapi._openapi import HttpClientError @pytest.fixture() diff --git a/test/components/tools/openapi/test_openapi_client.py b/test/components/tools/openapi/test_openapi_client.py index 4842b75e..c622642e 100644 --- a/test/components/tools/openapi/test_openapi_client.py +++ b/test/components/tools/openapi/test_openapi_client.py @@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient """ diff --git a/test/components/tools/openapi/test_openapi_client_auth.py b/test/components/tools/openapi/test_openapi_client_auth.py index de91855e..ab6205e8 100644 --- a/test/components/tools/openapi/test_openapi_client_auth.py +++ b/test/components/tools/openapi/test_openapi_client_auth.py @@ -15,8 +15,7 @@ HTTPBearer, ) -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ApiKeyAuthentication, \ - HTTPAuthentication, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient API_KEY = "secret_api_key" @@ -140,7 +139,7 @@ class TestOpenAPIAuth: def test_greet_api_key_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", request_sender=FastAPITestClient(create_greet_api_key_auth_app()), - credentials=ApiKeyAuthentication(API_KEY)) + credentials=API_KEY) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -153,26 +152,10 @@ def test_greet_api_key_auth(self, test_files_path): response = client.invoke(payload) assert response == {"greeting": "Hello, John from api_key_auth, using secret_api_key"} - def test_greet_basic_auth(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - request_sender=FastAPITestClient(create_greet_basic_auth_app()), - credentials=HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD)) - client = OpenAPIServiceClient(config) - payload = { - "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", - "function": { - "arguments": '{"name": "John"}', - "name": "greetBasicAuth", - }, - "type": "function", - } - response = client.invoke(payload) - assert response == {"greeting": "Hello, John from basic_auth, using admin"} - def test_greet_api_key_query_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", request_sender=FastAPITestClient(create_greet_api_key_query_app()), - credentials=ApiKeyAuthentication(API_KEY_QUERY)) + credentials=API_KEY_QUERY) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -189,7 +172,7 @@ def test_greet_api_key_cookie_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", request_sender=FastAPITestClient(create_greet_api_key_cookie_app()), - credentials=ApiKeyAuthentication(API_KEY_COOKIE)) + credentials=API_KEY_COOKIE) client = OpenAPIServiceClient(config) payload = { @@ -201,20 +184,4 @@ def test_greet_api_key_cookie_auth(self, test_files_path): "type": "function", } response = client.invoke(payload) - assert response == {"greeting": "Hello, John from api_key_cookie_auth, using secret_api_key_cookie"} - - def test_greet_bearer_auth(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - request_sender=FastAPITestClient(create_greet_bearer_auth_app()), - credentials=HTTPAuthentication(token=BEARER_TOKEN)) - client = OpenAPIServiceClient(config) - payload = { - "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", - "function": { - "arguments": '{"name": "John"}', - "name": "greetBearerAuth", - }, - "type": "function", - } - response = client.invoke(payload) - assert response == {"greeting": "Hello, John from bearer_auth, using secret_bearer_token"} + assert response == {"greeting": "Hello, John from api_key_cookie_auth, using secret_api_key_cookie"} \ No newline at end of file diff --git a/test/components/tools/openapi/test_openapi_client_complex_request_body.py b/test/components/tools/openapi/test_openapi_client_complex_request_body.py index e6007efb..4b4b5a20 100644 --- a/test/components/tools/openapi/test_openapi_client_complex_request_body.py +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body.py @@ -11,7 +11,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient diff --git a/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py b/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py index 33e95387..bcb5cf48 100644 --- a/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient diff --git a/test/components/tools/openapi/test_openapi_client_edge_cases.py b/test/components/tools/openapi/test_openapi_client_edge_cases.py index 912fe7d5..4dfe7a06 100644 --- a/test/components/tools/openapi/test_openapi_client_edge_cases.py +++ b/test/components/tools/openapi/test_openapi_client_edge_cases.py @@ -5,7 +5,7 @@ import pytest -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient diff --git a/test/components/tools/openapi/test_openapi_client_error_handling.py b/test/components/tools/openapi/test_openapi_client_error_handling.py index 5b6e8dc4..a1d730aa 100644 --- a/test/components/tools/openapi/test_openapi_client_error_handling.py +++ b/test/components/tools/openapi/test_openapi_client_error_handling.py @@ -8,7 +8,7 @@ import pytest from fastapi import FastAPI, HTTPException -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, HttpClientError, \ +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, HttpClientError, \ ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient diff --git a/test/components/tools/openapi/test_openapi_client_live.py b/test/components/tools/openapi/test_openapi_client_live.py index 1ee5b9f4..02ae0b74 100644 --- a/test/components/tools/openapi/test_openapi_client_live.py +++ b/test/components/tools/openapi/test_openapi_client_live.py @@ -7,7 +7,7 @@ import pytest import yaml -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration class TestClientLive: @@ -28,26 +28,6 @@ def test_serperdev(self, test_files_path): response = serper_api.invoke(payload) assert "invention" in str(response) - @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") - @pytest.mark.integration - def test_serperdev_load_spec_first(self, test_files_path): - with open(test_files_path / "yaml" / "serper.yml") as file: - loaded_spec = yaml.safe_load(file) - - # use builder with dict spec - config = ClientConfiguration(openapi_spec=loaded_spec, credentials=os.getenv("SERPERDEV_API_KEY")) - serper_api = OpenAPIServiceClient(config) - payload = { - "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", - "function": { - "arguments": '{"q": "Who was Nikola Tesla?"}', - "name": "serperdev_search", - }, - "type": "function", - } - response = serper_api.invoke(payload) - assert "invention" in str(response) - @pytest.mark.integration def test_github(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml") diff --git a/test/components/tools/openapi/test_openapi_client_live_anthropic.py b/test/components/tools/openapi/test_openapi_client_live_anthropic.py index ede1503a..91ca6334 100644 --- a/test/components/tools/openapi/test_openapi_client_live_anthropic.py +++ b/test/components/tools/openapi/test_openapi_client_live_anthropic.py @@ -7,7 +7,8 @@ import anthropic import pytest -from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient, \ + LLMProvider class TestClientLiveAnthropic: @@ -18,9 +19,9 @@ class TestClientLiveAnthropic: def test_serperdev(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", credentials=os.getenv("SERPERDEV_API_KEY"), - llm_provider="anthropic") + llm_provider=LLMProvider.ANTHROPIC) client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) - response = client.beta.tools.messages.create( + response = client.messages.create( model="claude-3-opus-20240229", max_tokens=1024, tools=config.get_tools_definitions(), @@ -41,10 +42,10 @@ def test_serperdev(self, test_files_path): @pytest.mark.integration def test_github(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml", - llm_provider="anthropic") + llm_provider=LLMProvider.ANTHROPIC) client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) - response = client.beta.tools.messages.create( + response = client.messages.create( model="claude-3-opus-20240229", max_tokens=1024, tools=config.get_tools_definitions(), diff --git a/test/components/tools/openapi/test_openapi_client_live_cohere.py b/test/components/tools/openapi/test_openapi_client_live_cohere.py index 4cd87631..891bb5fa 100644 --- a/test/components/tools/openapi/test_openapi_client_live_cohere.py +++ b/test/components/tools/openapi/test_openapi_client_live_cohere.py @@ -6,7 +6,8 @@ import cohere import pytest -from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient, \ + LLMProvider # Copied from Cohere's documentation preamble = """ @@ -30,7 +31,7 @@ class TestClientLiveCohere: def test_serperdev(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", credentials=os.getenv("SERPERDEV_API_KEY"), - llm_provider="cohere") + llm_provider=LLMProvider.COHERE) client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) response = client.chat( model="command-r", @@ -53,7 +54,7 @@ def test_serperdev(self, test_files_path): @pytest.mark.integration def test_github(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml", - llm_provider="cohere") + llm_provider=LLMProvider.COHERE) client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) response = client.chat( diff --git a/test/components/tools/openapi/test_openapi_client_live_openai.py b/test/components/tools/openapi/test_openapi_client_live_openai.py index 4f2e2c39..c95f6614 100644 --- a/test/components/tools/openapi/test_openapi_client_live_openai.py +++ b/test/components/tools/openapi/test_openapi_client_live_openai.py @@ -7,7 +7,7 @@ import pytest from openai import OpenAI -from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient class TestClientLiveOpenAPI: diff --git a/test/components/tools/openapi/test_openapi_cohere_conversion.py b/test/components/tools/openapi/test_openapi_cohere_conversion.py index dd84b9b5..5837c040 100644 --- a/test/components/tools/openapi/test_openapi_cohere_conversion.py +++ b/test/components/tools/openapi/test_openapi_cohere_conversion.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from haystack_experimental.components.tools.openapi.openapi import OpenAPISpecification, cohere_converter +from haystack_experimental.components.tools.openapi._openapi import OpenAPISpecification, cohere_converter class TestOpenAPISchemaConversion: diff --git a/test/components/tools/openapi/test_openapi_openai_conversion.py b/test/components/tools/openapi/test_openapi_openai_conversion.py index 3bf7cc19..090f9c4a 100644 --- a/test/components/tools/openapi/test_openapi_openai_conversion.py +++ b/test/components/tools/openapi/test_openapi_openai_conversion.py @@ -4,7 +4,7 @@ import pytest -from haystack_experimental.components.tools.openapi.openapi import openai_converter, anthropic_converter, OpenAPISpecification +from haystack_experimental.components.tools.openapi._openapi import openai_converter, anthropic_converter, OpenAPISpecification class TestOpenAPISchemaConversion: diff --git a/test/components/tools/openapi/test_openapi_spec.py b/test/components/tools/openapi/test_openapi_spec.py index 93e2a972..4e38de2d 100644 --- a/test/components/tools/openapi/test_openapi_spec.py +++ b/test/components/tools/openapi/test_openapi_spec.py @@ -4,26 +4,11 @@ import pytest -from haystack_experimental.components.tools.openapi.openapi import OpenAPISpecification +from haystack_experimental.components.tools.openapi._openapi import OpenAPISpecification class TestOpenAPISpecification: - # can be initialized from a dictionary - def test_initialized_from_dictionary(self): - spec_dict = { - "openapi": "3.0.0", - "info": {"title": "Test API", "version": "1.0.0"}, - "servers": [{"url": "https://api.example.com"}], - "paths": { - "/users": { - "get": {"summary": "Get all users", "responses": {"200": {"description": "Successful response"}}} - } - }, - } - openapi_spec = OpenAPISpecification.from_dict(spec_dict) - assert openapi_spec.spec_dict == spec_dict - # can be initialized from a string def test_initialized_from_string(self): content = """ From 2f011800a6bfbaffe296be18b0312e80feb4e79c Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 17:39:32 +0200 Subject: [PATCH 16/40] Remove lazy imports --- .../components/tools/openapi/openapi_tool.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index cb26cc5e..78b3939f 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -9,7 +9,6 @@ from haystack import component, logging from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole -from haystack.lazy_imports import LazyImport from haystack_experimental.components.tools.openapi._openapi import ( ClientConfiguration, @@ -141,16 +140,4 @@ def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict """ if generator_api == LLMProvider.OPENAI: return OpenAIChatGenerator(**generator_api_params) - if generator_api == LLMProvider.ANTHROPIC: - with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import: - anthropic_import.check() - # pylint: disable=import-outside-toplevel - from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator - return AnthropicChatGenerator(**generator_api_params) - if generator_api == LLMProvider.COHERE: - with LazyImport("Run 'pip install cohere-haystack'") as cohere_import: - cohere_import.check() - # pylint: disable=import-outside-toplevel - from haystack_integrations.components.generators.cohere import CohereChatGenerator - return CohereChatGenerator(**generator_api_params) raise ValueError(f"Unsupported generator API: {generator_api}") From 4d90cee2b4187144be9af90045864c6f16064059 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 17:46:56 +0200 Subject: [PATCH 17/40] Typing fixes --- haystack_experimental/components/tools/openapi/_openapi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py index e171c35d..990470cd 100644 --- a/haystack_experimental/components/tools/openapi/_openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -351,7 +351,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments self.request_sender = request_sender or send_request self.llm_provider: LLMProvider = llm_provider or LLMProvider.OPENAI - def get_auth_function(self) -> Callable[[dict[str, Any], dict[str, Any]], Any]: + def get_auth_function(self) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]: """ Get the authentication function that sets a schema specified authentication to the request. @@ -401,7 +401,7 @@ def get_payload_extractor(self): def _create_authentication_from_string( self, credentials: str, security_schemes: Dict[str, Any] - ) -> Callable[[dict[str, Any], dict[str, Any]], Any]: + ) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]: for scheme in security_schemes.values(): if scheme["type"] == "apiKey": return create_api_key_auth_function(api_key=credentials) From e57ac8a9e6422659354de3ac7e01991b990b3c4a Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 17:57:25 +0200 Subject: [PATCH 18/40] Add lazy imports --- .../components/tools/openapi/openapi_tool.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 78b3939f..bcb504cc 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -9,6 +9,7 @@ from haystack import component, logging from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole +from haystack.lazy_imports import LazyImport from haystack_experimental.components.tools.openapi._openapi import ( ClientConfiguration, @@ -16,6 +17,14 @@ OpenAPIServiceClient, ) +with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import: + # pylint: disable=import-error + from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator + +with LazyImport("Run 'pip install cohere-haystack'") as cohere_import: + # pylint: disable=import-error + from haystack_integrations.components.generators.cohere import CohereChatGenerator + logger = logging.getLogger(__name__) @@ -140,4 +149,10 @@ def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict """ if generator_api == LLMProvider.OPENAI: return OpenAIChatGenerator(**generator_api_params) + if generator_api == LLMProvider.COHERE: + cohere_import.check() + return CohereChatGenerator(**generator_api_params) + if generator_api == LLMProvider.ANTHROPIC: + anthropic_import.check() + return AnthropicChatGenerator(**generator_api_params) raise ValueError(f"Unsupported generator API: {generator_api}") From 1b87f825cd1711ad7de0f08a5bdef62314948165 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 18:27:52 +0200 Subject: [PATCH 19/40] Expose LLMProvider --- haystack_experimental/components/tools/openapi/__init__.py | 3 ++- haystack_experimental/components/tools/openapi/openapi_tool.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/__init__.py b/haystack_experimental/components/tools/openapi/__init__.py index c109a7d3..f8774880 100644 --- a/haystack_experimental/components/tools/openapi/__init__.py +++ b/haystack_experimental/components/tools/openapi/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from haystack_experimental.components.tools.openapi._openapi import LLMProvider from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool -__all__ = ["OpenAPITool"] +__all__ = ["LLMProvider", "OpenAPITool"] diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index bcb504cc..ea9f4787 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -10,10 +10,10 @@ from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole from haystack.lazy_imports import LazyImport +from haystack_experimental.components.tools.openapi import LLMProvider from haystack_experimental.components.tools.openapi._openapi import ( ClientConfiguration, - LLMProvider, OpenAPIServiceClient, ) From 7e0948023fb771249b1bee8405419053af53ce20 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 21:08:12 +0200 Subject: [PATCH 20/40] Avoid circular deps --- .../components/tools/openapi/__init__.py | 2 +- .../components/tools/openapi/_openapi.py | 11 +---------- .../components/tools/openapi/openapi_tool.py | 2 +- .../components/tools/openapi/types.py | 10 ++++++++++ 4 files changed, 13 insertions(+), 12 deletions(-) create mode 100644 haystack_experimental/components/tools/openapi/types.py diff --git a/haystack_experimental/components/tools/openapi/__init__.py b/haystack_experimental/components/tools/openapi/__init__.py index f8774880..6867e62b 100644 --- a/haystack_experimental/components/tools/openapi/__init__.py +++ b/haystack_experimental/components/tools/openapi/__init__.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from haystack_experimental.components.tools.openapi._openapi import LLMProvider from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool +from haystack_experimental.components.tools.openapi.types import LLMProvider __all__ = ["LLMProvider", "OpenAPITool"] diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py index 990470cd..ed876de2 100644 --- a/haystack_experimental/components/tools/openapi/_openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -6,7 +6,6 @@ import logging import os from dataclasses import dataclass, field -from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Union from urllib.parse import urlparse @@ -22,6 +21,7 @@ cohere_converter, openai_converter, ) +from haystack_experimental.components.tools.openapi.types import LLMProvider VALID_HTTP_METHODS = [ "get", @@ -37,15 +37,6 @@ logger = logging.getLogger(__name__) -class LLMProvider(Enum): - """ - Enum for the supported LLM providers. - """ - OPENAI = "openai" - ANTHROPIC = "anthropic" - COHERE = "cohere" - - def is_valid_http_url(url: str) -> bool: """ Check if a URL is a valid HTTP/HTTPS URL. diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index ea9f4787..de5062d3 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -10,12 +10,12 @@ from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole from haystack.lazy_imports import LazyImport -from haystack_experimental.components.tools.openapi import LLMProvider from haystack_experimental.components.tools.openapi._openapi import ( ClientConfiguration, OpenAPIServiceClient, ) +from haystack_experimental.components.tools.openapi.types import LLMProvider with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import: # pylint: disable=import-error diff --git a/haystack_experimental/components/tools/openapi/types.py b/haystack_experimental/components/tools/openapi/types.py new file mode 100644 index 00000000..ed2a2c99 --- /dev/null +++ b/haystack_experimental/components/tools/openapi/types.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class LLMProvider(Enum): + """ + Enum for the supported LLM providers. + """ + OPENAI = "openai" + ANTHROPIC = "anthropic" + COHERE = "cohere" From a7fcadc5f87d04aee1ebc234f5f81cfd0b23ce1e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 21:14:02 +0200 Subject: [PATCH 21/40] Add header for types.py --- haystack_experimental/components/tools/openapi/types.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/haystack_experimental/components/tools/openapi/types.py b/haystack_experimental/components/tools/openapi/types.py index ed2a2c99..542fea20 100644 --- a/haystack_experimental/components/tools/openapi/types.py +++ b/haystack_experimental/components/tools/openapi/types.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum From 190f113ed0b76bf8b3956288a7de7fae7ec3d571 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 7 Jun 2024 05:48:03 +0200 Subject: [PATCH 22/40] Improve pydoc --- .../components/tools/openapi/openapi_tool.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index de5062d3..17c98da4 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -31,16 +31,16 @@ @component class OpenAPITool: """ - The OpenAPITool calls an OpenAPI service using payloads generated by the chat generator from human instructions. + The OpenAPITool calls a RESTful endpoint of an OpenAPI service using payloads generated from human instructions. Here is an example of how to use the OpenAPITool component to scrape a URL using the FireCrawl API: ```python - from haystack_experimental.components.tools import OpenAPITool - from haystack.components.generators.chat.openai import OpenAIChatGenerator from haystack.dataclasses import ChatMessage + from haystack_experimental.components.tools.openapi import OpenAPITool, LLMProvider - tool = OpenAPITool(model="gpt-3.5-turbo", + tool = OpenAPITool(generator_api=LLMProvider.OPENAI, + generator_api_params={"model":"gpt-3.5-turbo"}, tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json", tool_credentials="") @@ -63,8 +63,9 @@ def __init__( Initialize the OpenAPITool component. :param generator_api: The API provider for the chat generator. - :param generator_api_params: Parameters for the chat generator. - :param tool_spec: OpenAPI specification for the tool/service. + :param generator_api_params: Parameters to pass for the chat generator creation. + :param tool_spec: OpenAPI specification for the tool/service. This can be a URL, a local file path, or + an OpenAPI service specification provided as a string. :param tool_credentials: Credentials for the tool/service. """ self.generator_api = generator_api @@ -95,7 +96,10 @@ def run( :param tool_spec: OpenAPI specification for the tool/service, overrides the one provided at initialization. :param tool_credentials: Credentials for the tool/service, overrides the one provided at initialization. :returns: a dictionary containing the service response with the following key: - - `service_response`: List of ChatMessages containing the service response. + - `service_response`: List of ChatMessages containing the service response. ChatMessages are generated + based on the response from the OpenAPI service/tool and contains the JSON response from the service. + If there is an error during the invocation, the response will be a ChatMessage with the error message under + the `error` key. """ last_message = messages[-1] if not last_message.is_from(ChatRole.USER): From d21bcadd19418cc92988b54bb328f17a028635fd Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 17 Jun 2024 13:59:05 +0200 Subject: [PATCH 23/40] Add back in http bearer auth --- .../components/tools/openapi/_openapi.py | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py index ed876de2..edd640e5 100644 --- a/haystack_experimental/components/tools/openapi/_openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -90,7 +90,7 @@ def create_api_key_auth_function(api_key: str): at the schema specified location. """ - def apply_api_key_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]): + def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]): """ Apply the API key authentication strategy to the given request. @@ -109,7 +109,43 @@ def apply_api_key_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) f"must be one of 'header', 'query', or 'cookie'" ) - return apply_api_key_auth + return apply_auth + + +def create_http_auth_function(token: str): + """ + Create a function that applies the http authentication strategy to a given request. + + :param token: the authentication token to use. + :return: a function that applies the API key authentication to a request + at the schema specified location. + """ + + def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the HTTP authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + if security_scheme["type"] == "http": + # support bearer http auth, no basic support yet + if security_scheme["scheme"].lower() == "bearer": + if not token: + raise ValueError("Token must be provided for Bearer Auth.") + request.setdefault("headers", {})[ + "Authorization" + ] = f"Bearer {token}" + else: + raise ValueError( + f"Unsupported HTTP authentication scheme: {security_scheme['scheme']}" + ) + else: + raise ValueError( + "HTTPAuthentication strategy received a non-HTTP security scheme." + ) + + return apply_auth class HttpClientError(Exception): @@ -397,7 +433,7 @@ def _create_authentication_from_string( if scheme["type"] == "apiKey": return create_api_key_auth_function(api_key=credentials) if scheme["type"] == "http": - raise NotImplementedError("HTTP authentication is not yet supported.") + return create_http_auth_function(token=credentials) if scheme["type"] == "oauth2": raise NotImplementedError("OAuth2 authentication is not yet supported.") raise ValueError( From e264ec20c024f8717a4d8d90ef34558a60bb2507 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 11:21:46 +0200 Subject: [PATCH 24/40] Add firecrawl openapi conversion tests --- .../tools/openapi/_schema_conversion.py | 2 +- .../openapi/test_openapi_cohere_conversion.py | 87 +- .../openapi/test_openapi_openai_conversion.py | 6 +- .../json/firecrawl_openapi_spec.json | 881 ++++++++++++++++++ 4 files changed, 973 insertions(+), 3 deletions(-) create mode 100644 test/test_files/json/firecrawl_openapi_spec.json diff --git a/haystack_experimental/components/tools/openapi/_schema_conversion.py b/haystack_experimental/components/tools/openapi/_schema_conversion.py index abc23c9c..ca967b34 100644 --- a/haystack_experimental/components/tools/openapi/_schema_conversion.py +++ b/haystack_experimental/components/tools/openapi/_schema_conversion.py @@ -248,7 +248,7 @@ def _parse_schema( nested_parameters = { name: _parse_schema( schema=prop_schema, - required=bool(name in schema.get("required", False)), + required=bool(name in schema.get("required", [])), description=prop_schema.get("description", ""), ) for name, prop_schema in properties.items() diff --git a/test/components/tools/openapi/test_openapi_cohere_conversion.py b/test/components/tools/openapi/test_openapi_cohere_conversion.py index 5837c040..ac8f42b2 100644 --- a/test/components/tools/openapi/test_openapi_cohere_conversion.py +++ b/test/components/tools/openapi/test_openapi_cohere_conversion.py @@ -16,7 +16,92 @@ def test_serperdev(self, test_files_path): function = functions[0] assert function["name"] == "serperdev_search" assert function["description"] == "Search the web with Google" - assert function["parameter_definitions"] == {"q": {"description": "", "type": "str", "required": True}} + assert function["parameter_definitions"] == { + "q": {"description": "", "type": "str", "required": True} + } + + def test_firecrawler(self, test_files_path): + spec = OpenAPISpecification.from_file( + test_files_path / "json" / "firecrawl_openapi_spec.json" + ) + functions = cohere_converter(schema=spec) + assert functions + assert len(functions) == 5 + function = functions[0] + assert function["name"] == "scrapeAndExtractFromUrl" + assert ( + function["description"] + == "Scrape a single URL and optionally extract information using an LLM" + ) + assert function["parameter_definitions"] == { + "url": {"type": "str", "description": "The URL to scrape", "required": True}, + "pageOptions": { + "type": "object", + "description": "", + "required": False, + "properties": { + "onlyMainContent": { + "type": "bool", + "description": "Only return the main content of the page excluding headers, navs, footers, etc.", + "required": False, + }, + "includeHtml": { + "type": "bool", + "description": "Include the raw HTML content of the page. Will output a html key in the response.", + "required": False, + }, + "screenshot": { + "type": "bool", + "description": "Include a screenshot of the top of the page that you are scraping.", + "required": False, + }, + "waitFor": { + "type": "int", + "description": "Wait x amount of milliseconds for the page to load to fetch content", + "required": False, + }, + "removeTags": { + "type": "list", + "description": "Tags, classes and ids to remove from the page. Use comma separated values. Example: 'script, .ad, #footer'", + "required": False, + }, + "headers": { + "type": "object", + "description": "Headers to send with the request. Can be used to send cookies, user-agent, etc.", + "properties": {}, + "required": False, + }, + }, + }, + "extractorOptions": { + "type": "object", + "description": "Options for LLM-based extraction of structured information from the page content", + "required": False, + "properties": { + "mode": { + "type": "str", + "description": "The extraction mode to use, currently supports 'llm-extraction'", + "required": False, + }, + "extractionPrompt": { + "type": "str", + "description": "A prompt describing what information to extract from the page", + "required": False, + }, + "extractionSchema": { + "type": "object", + "description": "The schema for the data to be extracted", + "properties": {}, + "required": False, + }, + }, + }, + "timeout": { + "type": "int", + "description": "Timeout in milliseconds for the request", + "required": False, + }, + } def test_github(self, test_files_path): spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "github_compare.yml") diff --git a/test/components/tools/openapi/test_openapi_openai_conversion.py b/test/components/tools/openapi/test_openapi_openai_conversion.py index 090f9c4a..9e7285dc 100644 --- a/test/components/tools/openapi/test_openapi_openai_conversion.py +++ b/test/components/tools/openapi/test_openapi_openai_conversion.py @@ -4,7 +4,11 @@ import pytest -from haystack_experimental.components.tools.openapi._openapi import openai_converter, anthropic_converter, OpenAPISpecification +from haystack_experimental.components.tools.openapi._openapi import ( + openai_converter, + anthropic_converter, + OpenAPISpecification, +) class TestOpenAPISchemaConversion: diff --git a/test/test_files/json/firecrawl_openapi_spec.json b/test/test_files/json/firecrawl_openapi_spec.json new file mode 100644 index 00000000..e5868571 --- /dev/null +++ b/test/test_files/json/firecrawl_openapi_spec.json @@ -0,0 +1,881 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Firecrawl API", + "version": "1.0.0", + "description": "API for interacting with Firecrawl services to perform web scraping and crawling tasks.", + "contact": { + "name": "Firecrawl Support", + "url": "https://firecrawl.dev/support", + "email": "support@firecrawl.dev" + } + }, + "servers": [ + { + "url": "https://api.firecrawl.dev/v0" + } + ], + "paths": { + "/scrape": { + "post": { + "summary": "Scrape a single URL and optionally extract information using an LLM", + "operationId": "scrapeAndExtractFromUrl", + "tags": ["Scraping"], + "security": [ + { + "bearerAuth": [] + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "description": "The URL to scrape" + }, + "pageOptions": { + "type": "object", + "properties": { + "onlyMainContent": { + "type": "boolean", + "description": "Only return the main content of the page excluding headers, navs, footers, etc.", + "default": false + }, + "includeHtml": { + "type": "boolean", + "description": "Include the raw HTML content of the page. Will output a html key in the response.", + "default": false + }, + "screenshot": { + "type": "boolean", + "description": "Include a screenshot of the top of the page that you are scraping.", + "default": false + }, + "waitFor": { + "type": "integer", + "description": "Wait x amount of milliseconds for the page to load to fetch content", + "default": 0 + }, + "removeTags": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Tags, classes and ids to remove from the page. Use comma separated values. Example: 'script, .ad, #footer'" + }, + "headers": { + "type": "object", + "description": "Headers to send with the request. Can be used to send cookies, user-agent, etc." + } + } + }, + "extractorOptions": { + "type": "object", + "description": "Options for LLM-based extraction of structured information from the page content", + "properties": { + "mode": { + "type": "string", + "enum": ["llm-extraction"], + "description": "The extraction mode to use, currently supports 'llm-extraction'" + }, + "extractionPrompt": { + "type": "string", + "description": "A prompt describing what information to extract from the page" + }, + "extractionSchema": { + "type": "object", + "additionalProperties": true, + "description": "The schema for the data to be extracted", + "required": [ + "company_mission", + "supports_sso", + "is_open_source" + ] + } + } + }, + "timeout": { + "type": "integer", + "description": "Timeout in milliseconds for the request", + "default": 30000 + } + }, + "required": ["url"] + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ScrapeResponse" + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + }, + "/crawl": { + "post": { + "summary": "Crawl multiple URLs based on options", + "operationId": "crawlUrls", + "tags": ["Crawling"], + "security": [ + { + "bearerAuth": [] + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "description": "The base URL to start crawling from" + }, + "crawlerOptions": { + "type": "object", + "properties": { + "includes": { + "type": "array", + "items": { + "type": "string" + }, + "description": "URL patterns to include" + }, + "excludes": { + "type": "array", + "items": { + "type": "string" + }, + "description": "URL patterns to exclude" + }, + "generateImgAltText": { + "type": "boolean", + "description": "Generate alt text for images using LLMs (must have a paid plan)", + "default": false + }, + "returnOnlyUrls": { + "type": "boolean", + "description": "If true, returns only the URLs as a list on the crawl status. Attention: the return response will be a list of URLs inside the data, not a list of documents.", + "default": false + }, + "maxDepth": { + "type": "integer", + "description": "Maximum depth to crawl. Depth 1 is the base URL, depth 2 is the base URL and its direct children, and so on." + }, + "mode": { + "type": "string", + "enum": ["default", "fast"], + "description": "The crawling mode to use. Fast mode crawls 4x faster websites without sitemap, but may not be as accurate and shouldn't be used in heavy js-rendered websites.", + "default": "default" + }, + "ignoreSitemap": { + "type": "boolean", + "description": "Ignore the website sitemap when crawling", + "default": false + }, + "limit": { + "type": "integer", + "description": "Maximum number of pages to crawl", + "default": 10000 + }, + "allowBackwardCrawling": { + "type": "boolean", + "description": "Allow backward crawling (crawl from the base URL to the previous URLs)", + "default": false + } + } + }, + "pageOptions": { + "type": "object", + "properties": { + "onlyMainContent": { + "type": "boolean", + "description": "Only return the main content of the page excluding headers, navs, footers, etc.", + "default": false + }, + "includeHtml": { + "type": "boolean", + "description": "Include the raw HTML content of the page. Will output a html key in the response.", + "default": false + }, + "screenshot": { + "type": "boolean", + "description": "Include a screenshot of the top of the page that you are scraping.", + "default": false + }, + "headers": { + "type": "object", + "description": "Headers to send with the request when scraping. Can be used to send cookies, user-agent, etc." + }, + "removeTags": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Tags, classes and ids to remove from the page. Use comma separated values. Example: 'script, .ad, #footer'" + }, + "replaceAllPathsWithAbsolutePaths": { + "type": "boolean", + "description": "Replace all relative paths with absolute paths for images and links", + "default": false + } + } + } + }, + "required": ["url"] + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CrawlResponse" + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + }, + "/search": { + "post": { + "summary": "Search for a keyword in Google, returns top page results with markdown content for each page", + "operationId": "searchGoogle", + "tags": ["Search"], + "security": [ + { + "bearerAuth": [] + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "format": "uri", + "description": "The query to search for" + }, + "pageOptions": { + "type": "object", + "properties": { + "onlyMainContent": { + "type": "boolean", + "description": "Only return the main content of the page excluding headers, navs, footers, etc.", + "default": false + }, + "fetchPageContent": { + "type": "boolean", + "description": "Fetch the content of each page. If false, defaults to a basic fast serp API.", + "default": true + }, + "includeHtml": { + "type": "boolean", + "description": "Include the raw HTML content of the page. Will output a html key in the response.", + "default": false + } + } + }, + "searchOptions": { + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "Maximum number of results. Max is 20 during beta." + } + } + } + }, + "required": ["query"] + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SearchResponse" + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + }, + "/crawl/status/{jobId}": { + "get": { + "tags": ["Crawl"], + "summary": "Get the status of a crawl job", + "operationId": "getCrawlStatus", + "security": [ + { + "bearerAuth": [] + } + ], + "parameters": [ + { + "name": "jobId", + "in": "path", + "description": "ID of the crawl job", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "status": { + "type": "string", + "description": "Status of the job (completed, active, failed, paused)" + }, + "current": { + "type": "integer", + "description": "Current page number" + }, + "current_url": { + "type": "string", + "description": "Current URL being scraped" + }, + "current_step": { + "type": "string", + "description": "Current step in the process" + }, + "total": { + "type": "integer", + "description": "Total number of pages" + }, + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CrawlStatusResponseObj" + }, + "description": "Data returned from the job (null when it is in progress)" + }, + "partial_data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/CrawlStatusResponseObj" + }, + "description": "Partial documents returned as it is being crawled (streaming). **This feature is currently in alpha - expect breaking changes** When a page is ready, it will append to the partial_data array, so there is no need to wait for the entire website to be crawled. There is a max of 50 items in the array response. The oldest item (top of the array) will be removed when the new item is added to the array." + } + } + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + }, + "/crawl/cancel/{jobId}": { + "delete": { + "tags": ["Crawl"], + "summary": "Cancel a crawl job", + "operationId": "cancelCrawlJob", + "security": [ + { + "bearerAuth": [] + } + ], + "parameters": [ + { + "name": "jobId", + "in": "path", + "description": "ID of the crawl job", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "status": { + "type": "string", + "description": "Returns cancelled." + } + } + } + } + } + }, + "402": { + "description": "Payment required" + }, + "429": { + "description": "Too many requests" + }, + "500": { + "description": "Server error" + } + } + } + } + }, + "components": { + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer" + } + }, + "schemas": { + "ScrapeResponse": { + "type": "object", + "properties": { + "success": { + "type": "boolean" + }, + "data": { + "type": "object", + "properties": { + "markdown": { + "type": "string" + }, + "content": { + "type": "string" + }, + "html": { + "type": "string", + "nullable": true, + "description": "Raw HTML content of the page if `includeHtml` is true" + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "language": { + "type": "string", + "nullable": true + }, + "keywords": { + "type": "string", + "nullable": true + }, + "robots": { + "type": "string", + "nullable": true + }, + "ogTitle": { + "type": "string", + "nullable": true + }, + "ogDescription": { + "type": "string", + "nullable": true + }, + "ogUrl": { + "type": "string", + "format": "uri", + "nullable": true + }, + "ogImage": { + "type": "string", + "nullable": true + }, + "ogAudio": { + "type": "string", + "nullable": true + }, + "ogDeterminer": { + "type": "string", + "nullable": true + }, + "ogLocale": { + "type": "string", + "nullable": true + }, + "ogLocaleAlternate": { + "type": "array", + "items": { + "type": "string" + }, + "nullable": true + }, + "ogSiteName": { + "type": "string", + "nullable": true + }, + "ogVideo": { + "type": "string", + "nullable": true + }, + "dctermsCreated": { + "type": "string", + "nullable": true + }, + "dcDateCreated": { + "type": "string", + "nullable": true + }, + "dcDate": { + "type": "string", + "nullable": true + }, + "dctermsType": { + "type": "string", + "nullable": true + }, + "dcType": { + "type": "string", + "nullable": true + }, + "dctermsAudience": { + "type": "string", + "nullable": true + }, + "dctermsSubject": { + "type": "string", + "nullable": true + }, + "dcSubject": { + "type": "string", + "nullable": true + }, + "dcDescription": { + "type": "string", + "nullable": true + }, + "dctermsKeywords": { + "type": "string", + "nullable": true + }, + "modifiedTime": { + "type": "string", + "nullable": true + }, + "publishedTime": { + "type": "string", + "nullable": true + }, + "articleTag": { + "type": "string", + "nullable": true + }, + "articleSection": { + "type": "string", + "nullable": true + }, + "sourceURL": { + "type": "string", + "format": "uri" + }, + "pageStatusCode": { + "type": "integer", + "description": "The status code of the page" + }, + "pageError": { + "type": "string", + "nullable": true, + "description": "The error message of the page" + } + } + }, + "llm_extraction": { + "type": "object", + "description": "Displayed when using LLM Extraction. Extracted data from the page following the schema defined.", + "nullable": true + }, + "warning": { + "type": "string", + "nullable": true, + "description": "Can be displayed when using LLM Extraction. Warning message will let you know any issues with the extraction." + } + } + } + } + }, + "CrawlStatusResponseObj": { + "type": "object", + "properties": { + "markdown": { + "type": "string" + }, + "content": { + "type": "string" + }, + "html": { + "type": "string", + "nullable": true, + "description": "Raw HTML content of the page if `includeHtml` is true" + }, + "index": { + "type": "integer", + "description": "The number of the page that was crawled. This is useful for `partial_data` so you know which page the data is from." + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "language": { + "type": "string", + "nullable": true + }, + "keywords": { + "type": "string", + "nullable": true + }, + "robots": { + "type": "string", + "nullable": true + }, + "ogTitle": { + "type": "string", + "nullable": true + }, + "ogDescription": { + "type": "string", + "nullable": true + }, + "ogUrl": { + "type": "string", + "format": "uri", + "nullable": true + }, + "ogImage": { + "type": "string", + "nullable": true + }, + "ogAudio": { + "type": "string", + "nullable": true + }, + "ogDeterminer": { + "type": "string", + "nullable": true + }, + "ogLocale": { + "type": "string", + "nullable": true + }, + "ogLocaleAlternate": { + "type": "array", + "items": { + "type": "string" + }, + "nullable": true + }, + "ogSiteName": { + "type": "string", + "nullable": true + }, + "ogVideo": { + "type": "string", + "nullable": true + }, + "dctermsCreated": { + "type": "string", + "nullable": true + }, + "dcDateCreated": { + "type": "string", + "nullable": true + }, + "dcDate": { + "type": "string", + "nullable": true + }, + "dctermsType": { + "type": "string", + "nullable": true + }, + "dcType": { + "type": "string", + "nullable": true + }, + "dctermsAudience": { + "type": "string", + "nullable": true + }, + "dctermsSubject": { + "type": "string", + "nullable": true + }, + "dcSubject": { + "type": "string", + "nullable": true + }, + "dcDescription": { + "type": "string", + "nullable": true + }, + "dctermsKeywords": { + "type": "string", + "nullable": true + }, + "modifiedTime": { + "type": "string", + "nullable": true + }, + "publishedTime": { + "type": "string", + "nullable": true + }, + "articleTag": { + "type": "string", + "nullable": true + }, + "articleSection": { + "type": "string", + "nullable": true + }, + "sourceURL": { + "type": "string", + "format": "uri" + }, + "pageStatusCode": { + "type": "integer", + "description": "The status code of the page" + }, + "pageError": { + "type": "string", + "nullable": true, + "description": "The error message of the page" + } + } + } + } + }, + "SearchResponse": { + "type": "object", + "properties": { + "success": { + "type": "boolean" + }, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "url": { + "type": "string" + }, + "markdown": { + "type": "string" + }, + "content": { + "type": "string" + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string" + }, + "description": { + "type": "string" + }, + "language": { + "type": "string", + "nullable": true + }, + "sourceURL": { + "type": "string", + "format": "uri" + } + } + } + } + } + } + } + }, + "CrawlResponse": { + "type": "object", + "properties": { + "jobId": { + "type": "string" + } + } + } + } + }, + "security": [ + { + "bearerAuth": [] + } + ] +} \ No newline at end of file From 57bb34485512fda85685ce91dd7340542e3f2bc8 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 12:34:10 +0200 Subject: [PATCH 25/40] PR feedback --- .../components/tools/openapi/_openapi.py | 311 ++++-------------- .../tools/openapi/_schema_conversion.py | 22 +- .../components/tools/openapi/types.py | 230 ++++++++++++- .../tools/openapi/test_openapi_spec.py | 2 +- 4 files changed, 303 insertions(+), 262 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py index edd640e5..d99dea86 100644 --- a/haystack_experimental/components/tools/openapi/_openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -2,16 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 -import json import logging import os -from dataclasses import dataclass, field +from collections import defaultdict from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urlparse import requests -import yaml from haystack_experimental.components.tools.openapi._payload_extraction import ( create_function_payload_extractor, @@ -21,18 +19,8 @@ cohere_converter, openai_converter, ) -from haystack_experimental.components.tools.openapi.types import LLMProvider - -VALID_HTTP_METHODS = [ - "get", - "put", - "post", - "delete", - "options", - "head", - "patch", - "trace", -] +from haystack_experimental.components.tools.openapi.types import LLMProvider, OpenAPISpecification, Operation + MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 logger = logging.getLogger(__name__) @@ -42,7 +30,7 @@ def is_valid_http_url(url: str) -> bool: Check if a URL is a valid HTTP/HTTPS URL. :param url: The URL to check. - :return: True if the URL is a valid HTTP/HTTPS URL, False otherwise. + :returns: True if the URL is a valid HTTP/HTTPS URL, False otherwise. """ r = urlparse(url) return all([r.scheme in ["http", "https"], r.netloc]) @@ -53,7 +41,7 @@ def send_request(request: Dict[str, Any]) -> Dict[str, Any]: Send an HTTP request and return the response. :param request: The request to send. - :return: The response from the server. + :returns: The response from the server. """ url = request["url"] headers = {**request.get("headers", {})} @@ -86,7 +74,7 @@ def create_api_key_auth_function(api_key: str): Create a function that applies the API key authentication strategy to a given request. :param api_key: the API key to use for authentication. - :return: a function that applies the API key authentication to a request + :returns: a function that applies the API key authentication to a request at the schema specified location. """ @@ -117,7 +105,7 @@ def create_http_auth_function(token: str): Create a function that applies the http authentication strategy to a given request. :param token: the authentication token to use. - :return: a function that applies the API key authentication to a request + :returns: a function that applies the API key authentication to a request at the schema specified location. """ @@ -152,206 +140,6 @@ class HttpClientError(Exception): """Exception raised for errors in the HTTP client.""" -@dataclass -class Operation: - """ - Represents an operation in an OpenAPI specification - - See https://spec.openapis.org/oas/latest.html#paths-object for details. - Path objects can contain multiple operations, each with a unique combination of path and method. - - Attributes: - path (str): Path of the operation. - method (str): HTTP method of the operation. - operation_dict (Dict[str, Any]): Operation details from OpenAPI spec. - spec_dict (Dict[str, Any]): The encompassing OpenAPI specification. - security_requirements (List[Dict[str, List[str]]]): Security requirements for the operation. - request_body (Dict[str, Any]): Request body details. - parameters (List[Dict[str, Any]]): Parameters for the operation. - """ - - path: str - method: str - operation_dict: Dict[str, Any] - spec_dict: Dict[str, Any] - security_requirements: List[Dict[str, List[str]]] = field(init=False) - request_body: Dict[str, Any] = field(init=False) - parameters: List[Dict[str, Any]] = field(init=False) - - def __post_init__(self): - if self.method.lower() not in VALID_HTTP_METHODS: - raise ValueError(f"Invalid HTTP method: {self.method}") - self.method = self.method.lower() - self.security_requirements = self.operation_dict.get( - "security", [] - ) or self.spec_dict.get("security", []) - self.request_body = self.operation_dict.get("requestBody", {}) - self.parameters = self.operation_dict.get( - "parameters", [] - ) + self.spec_dict.get("paths", {}).get(self.path, {}).get("parameters", []) - - def get_parameters( - self, location: Optional[Literal["header", "query", "path"]] = None - ) -> List[Dict[str, Any]]: - """ - Get the parameters for the operation. - - :param location: The location of the parameters to get. - :return: The parameters for the operation as a list of dictionaries. - """ - if location: - return [param for param in self.parameters if param["in"] == location] - return self.parameters - - def get_server(self, server_index: int = 0) -> str: - """ - Get the servers for the operation. - - :param server_index: The index of the server to use. - :return: The server URL. - :raises ValueError: If no servers are found in the specification. - """ - servers = self.operation_dict.get("servers", []) or self.spec_dict.get( - "servers", [] - ) - if not servers: - raise ValueError("No servers found in the provided specification.") - if server_index >= len(servers): - raise ValueError( - f"Server index {server_index} is out of bounds. " - f"Only {len(servers)} servers found." - ) - return servers[server_index].get( - "url" - ) # just use the first server from the list - - -class OpenAPISpecification: - """ - Represents an OpenAPI specification. See https://spec.openapis.org/oas/latest.html for details. - """ - - def __init__(self, spec_dict: Dict[str, Any]): - if not isinstance(spec_dict, Dict): - raise ValueError( - f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}" - ) - # just a crude sanity check, by no means a full validation - if ( - "openapi" not in spec_dict - or "paths" not in spec_dict - or "servers" not in spec_dict - ): - raise ValueError( - "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.", - spec_dict, - ) - self.spec_dict = spec_dict - - @classmethod - def from_str(cls, content: str) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a string. - - :param content: The string content of the OpenAPI specification. - :return: The OpenAPISpecification instance. - """ - try: - loaded_spec = json.loads(content) - except json.JSONDecodeError: - try: - loaded_spec = yaml.safe_load(content) - except yaml.YAMLError as e: - raise ValueError( - "Content cannot be decoded as JSON or YAML: " + str(e) - ) from e - return cls(loaded_spec) - - @classmethod - def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a file. - - :param spec_file: The file path to the OpenAPI specification. - :return: The OpenAPISpecification instance. - """ - with open(spec_file, encoding="utf-8") as file: - content = file.read() - return cls.from_str(content) - - @classmethod - def from_url(cls, url: str) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a URL. - - :param url: The URL to fetch the OpenAPI specification from. - :return: The OpenAPISpecification instance. - """ - try: - response = requests.get(url, timeout=10) - response.raise_for_status() - content = response.text - except requests.RequestException as e: - raise ConnectionError( - f"Failed to fetch the specification from URL: {url}. {e!s}" - ) from e - return cls.from_str(content) - - def find_operation_by_id( - self, op_id: str, method: Optional[str] = None - ) -> Operation: - """ - Find an Operation by operationId. - - :param op_id: The operationId of the operation. - :param method: The HTTP method of the operation. - :return: The matching operation - :raises ValueError: If no operation is found with the given operationId. - """ - for path, path_item in self.spec_dict.get("paths", {}).items(): - op: Operation = self.get_operation_item(path, path_item, method) - if op_id in op.operation_dict.get("operationId", ""): - return self.get_operation_item(path, path_item, method) - raise ValueError( - f"No operation found with operationId {op_id}, method {method}" - ) - - def get_operation_item( - self, path: str, path_item: Dict[str, Any], method: Optional[str] = None - ) -> Operation: - """ - Gets a particular Operation item from the OpenAPI specification given the path and method. - - :param path: The path of the operation. - :param path_item: The path item from the OpenAPI specification. - :param method: The HTTP method of the operation. - :return: The operation - """ - if method: - operation_dict = path_item.get(method.lower(), {}) - if not operation_dict: - raise ValueError( - f"No operation found for method {method} at path {path}" - ) - return Operation(path, method.lower(), operation_dict, self.spec_dict) - if len(path_item) == 1: - method, operation_dict = next(iter(path_item.items())) - return Operation(path, method, operation_dict, self.spec_dict) - if len(path_item) > 1: - raise ValueError( - f"Multiple operations found at path {path}, method parameter is required." - ) - raise ValueError(f"No operations found at path {path} and method {method}") - - def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: - """ - Get the security schemes from the OpenAPI specification. - - :return: The security schemes as a dictionary. - """ - return self.spec_dict.get("components", {}).get("securitySchemes", {}) - - class ClientConfiguration: """Configuration for the OpenAPI client.""" @@ -362,13 +150,22 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, llm_provider: Optional[LLMProvider] = None, ): # noqa: PLR0913 + """ + Initialize a ClientConfiguration instance. + + :param openapi_spec: The OpenAPI specification as a file path, URL, or dictionary. + :param credentials: The credentials to use for authentication. + :param request_sender: The function to use for sending requests. + :param llm_provider: The LLM provider to use for generating tools definitions. + :raises ValueError: If the OpenAPI specification format is invalid. + """ if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) elif isinstance(openapi_spec, str): if is_valid_http_url(openapi_spec): self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) else: - self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) + self.openapi_spec = OpenAPISpecification._from_str(openapi_spec) else: raise ValueError( "Invalid OpenAPI specification format. Expected file path or dictionary." @@ -385,7 +182,8 @@ def get_auth_function(self) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]: The function takes a security scheme and a request as arguments: `security_scheme: Dict[str, Any] - The security scheme from the OpenAPI spec.` `request: Dict[str, Any] - The request to apply the authentication to.` - :return: The authentication function. + :returns: The authentication function. + :raises ValueError: If the credentials type is not supported. """ security_schemes = self.openapi_spec.get_security_schemes() if not self.credentials: @@ -400,42 +198,46 @@ def get_tools_definitions(self) -> List[Dict[str, Any]]: """ Get the tools definitions used as tools LLM parameter. - :return: The tools definitions passed to the LLM as tools parameter. + :returns: The tools definitions passed to the LLM as tools parameter. """ - provider_to_converter = { - "anthropic": anthropic_converter, - "cohere": cohere_converter, - } - converter = provider_to_converter.get(self.llm_provider.value, openai_converter) + provider_to_converter = defaultdict( + lambda: openai_converter, + { + LLMProvider.ANTHROPIC.value: anthropic_converter, + LLMProvider.COHERE.value: cohere_converter, + } + ) + converter = provider_to_converter[self.llm_provider.value] return converter(self.openapi_spec) - def get_payload_extractor(self): + def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: """ Get the payload extractor for the LLM provider. This function knows how to extract the exact function payload from the LLM generated function calling payload. - :return: The payload extractor function. + :returns: The payload extractor function. """ - provider_to_arguments_field_name = { - "anthropic": "input", - "cohere": "parameters", - } # add more providers here - # default to OpenAI "arguments" - arguments_field_name = provider_to_arguments_field_name.get( - self.llm_provider.value, "arguments" + provider_to_arguments_field_name = defaultdict( + lambda: "arguments", + { + LLMProvider.ANTHROPIC.value: "input", + LLMProvider.COHERE.value: "parameters", + } ) + arguments_field_name = provider_to_arguments_field_name[self.llm_provider.value] return create_function_payload_extractor(arguments_field_name) def _create_authentication_from_string( - self, credentials: str, security_schemes: Dict[str, Any] + self, credentials: str, security_schemes: Dict[str, Any] ) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]: for scheme in security_schemes.values(): if scheme["type"] == "apiKey": return create_api_key_auth_function(api_key=credentials) if scheme["type"] == "http": return create_http_auth_function(token=credentials) - if scheme["type"] == "oauth2": - raise NotImplementedError("OAuth2 authentication is not yet supported.") + raise ValueError( + f"Unsupported authentication type '{scheme['type']}' provided." + ) raise ValueError( f"Unable to create authentication from provided credentials: {credentials}" ) @@ -447,7 +249,9 @@ def build_request(operation: Operation, **kwargs) -> Dict[str, Any]: :param operation: The operation to build the request for. :param kwargs: The arguments to use for building the request. - :return: The HTTP request as a dictionary. + :returns: The HTTP request as a dictionary. + :raises ValueError: If a required parameter is missing. + :raises NotImplementedError: If the request body content type is not supported. We only support JSON payloads. """ path = operation.path for parameter in operation.get_parameters("path"): @@ -527,7 +331,6 @@ class OpenAPIServiceClient: def __init__(self, client_config: ClientConfiguration): self.client_config = client_config - self.request_sender = client_config.request_sender def invoke(self, function_payload: Any) -> Any: """ @@ -538,19 +341,25 @@ def invoke(self, function_payload: Any) -> Any: :raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload. :raises HttpClientError: If an error occurs while sending the request and receiving the response. """ - fn_extractor = self.client_config.get_payload_extractor() - fn_invocation_payload = fn_extractor(function_payload) - if not fn_invocation_payload: + fn_invocation_payload = {} + try: + fn_extractor = self.client_config.get_payload_extractor() + fn_invocation_payload = fn_extractor(function_payload) + except Exception as e: raise OpenAPIClientError( - f"Failed to extract function invocation payload from {function_payload}" + f"Error extracting function invocation payload: {str(e)}" + ) from e + + if "name" not in fn_invocation_payload or "arguments" not in fn_invocation_payload: + raise OpenAPIClientError( + f"Function invocation payload does not contain 'name' or 'arguments' keys: {fn_invocation_payload}, " + f"the payload extraction function may be incorrect." ) # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on - operation = self.client_config.openapi_spec.find_operation_by_id( - fn_invocation_payload.get("name") - ) - request = build_request(operation, **fn_invocation_payload.get("arguments")) + operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload["name"]) + request = build_request(operation, **fn_invocation_payload["arguments"]) apply_authentication(self.client_config.get_auth_function(), operation, request) - return self.request_sender(request) + return self.client_config.request_sender(request) class OpenAPIClientError(Exception): diff --git a/haystack_experimental/components/tools/openapi/_schema_conversion.py b/haystack_experimental/components/tools/openapi/_schema_conversion.py index ca967b34..1ed05152 100644 --- a/haystack_experimental/components/tools/openapi/_schema_conversion.py +++ b/haystack_experimental/components/tools/openapi/_schema_conversion.py @@ -7,18 +7,20 @@ import jsonref +from haystack_experimental.components.tools.openapi.types import OpenAPISpecification + MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 logger = logging.getLogger(__name__) -def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # type: ignore[name-defined] # noqa: F821 +def openai_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: """ Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. See https://platform.openai.com/docs/guides/function-calling for more information about OpenAI's function schema. :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each dictionary representing an OpenAI function definition. + :returns: A list of dictionaries, each dictionary representing an OpenAI function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) fn_definitions = _openapi_to_functions( @@ -27,14 +29,14 @@ def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # return [{"type": "function", "function": fn} for fn in fn_definitions] -def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # type: ignore # noqa: F821 +def anthropic_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: """ Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. See https://docs.anthropic.com/en/docs/tool-use for more information about Anthropic's function schema. :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each dictionary representing Anthropic function definition. + :returns: A list of dictionaries, each dictionary representing Anthropic function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) return _openapi_to_functions( @@ -42,14 +44,14 @@ def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: ) -def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # type: ignore[name-defined] # noqa: F821 +def cohere_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: """ Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. See https://docs.cohere.com/docs/tool-use for more information about Cohere's function schema. :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a Cohere style function definition. + :returns: A list of dictionaries, each representing a Cohere style function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) return _openapi_to_functions( @@ -68,6 +70,7 @@ def _openapi_to_functions( :param service_openapi_spec: The OpenAPI specification to extract functions from. :param parameters_name: The name of the parameters field in the function schema. :param parse_endpoint_fn: The function to parse the endpoint specification. + :returns: A list of dictionaries, each dictionary representing a function schema. """ # Doesn't enforce rigid spec validation because that would require a lot of dependencies @@ -101,6 +104,7 @@ def _parse_endpoint_spec_openai( :param resolved_spec: The resolved OpenAPI specification. :param parameters_name: The name of the parameters field in the function schema. + :returns: A dictionary containing the parsed function schema. """ if not isinstance(resolved_spec, dict): logger.warning( @@ -157,6 +161,7 @@ def _parse_property_attributes( :param property_schema: The property schema to parse. :param include_attributes: The attributes to include in the parsed schema. + :returns: A dictionary containing the parsed property schema. """ include_attributes = include_attributes or ["description", "pattern", "enum"] schema_type = property_schema.get("type") @@ -187,6 +192,7 @@ def _parse_endpoint_spec_cohere( :param operation: The operation specification to parse. :param ignored_param: ignored, left for compatibility with the OpenAI converter. + :returns: A dictionary containing the parsed function schema. """ function_name = operation.get("operationId") description = operation.get("description") or operation.get("summary", "") @@ -206,7 +212,7 @@ def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: Parses the parameters from an operation specification. :param operation: The operation specification to parse. - :return: A dictionary containing the parsed parameters. + :returns: A dictionary containing the parsed parameters. """ parameters = {} for param in operation.get("parameters", []): @@ -239,7 +245,7 @@ def _parse_schema( :param schema: The schema to parse. :param required: Whether the schema is required. :param description: The description of the schema. - :return: A dictionary containing the parsed schema. + :returns: A dictionary containing the parsed schema. """ schema_type = _get_type(schema) if schema_type == "object": diff --git a/haystack_experimental/components/tools/openapi/types.py b/haystack_experimental/components/tools/openapi/types.py index 542fea20..863f204e 100644 --- a/haystack_experimental/components/tools/openapi/types.py +++ b/haystack_experimental/components/tools/openapi/types.py @@ -1,14 +1,240 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 - +import json +from dataclasses import dataclass, field from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional, Union + +import requests +import yaml + +VALID_HTTP_METHODS = [ + "get", + "put", + "post", + "delete", + "options", + "head", + "patch", + "trace", +] class LLMProvider(Enum): """ - Enum for the supported LLM providers. + LLM providers supported by `OpenAPITool`. """ OPENAI = "openai" ANTHROPIC = "anthropic" COHERE = "cohere" + + +@dataclass +class Operation: + """ + Represents an operation in an OpenAPI specification + + See https://spec.openapis.org/oas/latest.html#paths-object for details. + Path objects can contain multiple operations, each with a unique combination of path and method. + + :param path: Path of the operation. + :param method: HTTP method of the operation. + :param operation_dict: Operation details from OpenAPI spec + :param spec_dict: The encompassing OpenAPI specification. + :param security_requirements: A list of security requirements for the operation. + :param request_body: Request body details. + :param parameters: Parameters for the operation. + """ + + path: str + method: str + operation_dict: Dict[str, Any] + spec_dict: Dict[str, Any] + security_requirements: List[Dict[str, List[str]]] = field(init=False) + request_body: Dict[str, Any] = field(init=False) + parameters: List[Dict[str, Any]] = field(init=False) + + def __post_init__(self): + if self.method.lower() not in VALID_HTTP_METHODS: + raise ValueError(f"Invalid HTTP method: {self.method}") + self.method = self.method.lower() + self.security_requirements = self.operation_dict.get( + "security", [] + ) or self.spec_dict.get("security", []) + self.request_body = self.operation_dict.get("requestBody", {}) + self.parameters = self.operation_dict.get( + "parameters", [] + ) + self.spec_dict.get("paths", {}).get(self.path, {}).get("parameters", []) + + def get_parameters( + self, location: Optional[Literal["header", "query", "path"]] = None + ) -> List[Dict[str, Any]]: + """ + Get the parameters for the operation. + + :param location: The location of the parameters to get. + :returns: The parameters for the operation as a list of dictionaries. + """ + if location: + return [param for param in self.parameters if param["in"] == location] + return self.parameters + + def get_server(self, server_index: int = 0) -> str: + """ + Get the servers for the operation. + + :param server_index: The index of the server to use. + :returns: The server URL. + :raises ValueError: If no servers are found in the specification. + """ + servers = self.operation_dict.get("servers", []) or self.spec_dict.get( + "servers", [] + ) + if not servers: + raise ValueError("No servers found in the provided specification.") + if not 0 <= server_index < len(servers): + raise ValueError( + f"Server index {server_index} is out of bounds. " + f"Only {len(servers)} servers found." + ) + return servers[server_index].get( + "url" + ) # just use the first server from the list + + +class OpenAPISpecification: + """ + Represents an OpenAPI specification. See https://spec.openapis.org/oas/latest.html for details. + """ + + def __init__(self, spec_dict: Dict[str, Any]): + """ + Initialize an OpenAPISpecification instance. + + :param spec_dict: The OpenAPI specification as a dictionary. + """ + if not isinstance(spec_dict, Dict): + raise ValueError( + f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}" + ) + # just a crude sanity check, by no means a full validation + if ( + "openapi" not in spec_dict + or "paths" not in spec_dict + or "servers" not in spec_dict + ): + raise ValueError( + "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.", + spec_dict, + ) + self.spec_dict = spec_dict + + @classmethod + def _from_str(cls, content: str) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a string. + + :param content: The string content of the OpenAPI specification. + :returns: The OpenAPISpecification instance. + :raises ValueError: If the content cannot be decoded as JSON or YAML. + """ + try: + loaded_spec = json.loads(content) + except json.JSONDecodeError: + try: + loaded_spec = yaml.safe_load(content) + except yaml.YAMLError as e: + raise ValueError( + "Content cannot be decoded as JSON or YAML: " + str(e) + ) from e + return cls(loaded_spec) + + @classmethod + def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a file. + + :param spec_file: The file path to the OpenAPI specification. + :returns: The OpenAPISpecification instance. + :raises FileNotFoundError: If the specified file does not exist. + :raises IOError: If an I/O error occurs while reading the file. + :raises ValueError: If the file content cannot be decoded as JSON or YAML. + """ + with open(spec_file, encoding="utf-8") as file: + content = file.read() + return cls._from_str(content) + + @classmethod + def from_url(cls, url: str) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a URL. + + :param url: The URL to fetch the OpenAPI specification from. + :returns: The OpenAPISpecification instance. + :raises ConnectionError: If fetching the specification from the URL fails. + """ + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + content = response.text + except requests.RequestException as e: + raise ConnectionError( + f"Failed to fetch the specification from URL: {url}. {e!s}" + ) from e + return cls._from_str(content) + + def find_operation_by_id( + self, op_id: str, method: Optional[str] = None + ) -> Operation: + """ + Find an Operation by operationId. + + :param op_id: The operationId of the operation. + :param method: The HTTP method of the operation. + :returns: The matching operation + :raises ValueError: If no operation is found with the given operationId. + """ + for path, path_item in self.spec_dict.get("paths", {}).items(): + op: Operation = self.get_operation_item(path, path_item, method) + if op_id in op.operation_dict.get("operationId", ""): + return self.get_operation_item(path, path_item, method) + raise ValueError( + f"No operation found with operationId {op_id}, method {method}" + ) + + def get_operation_item( + self, path: str, path_item: Dict[str, Any], method: Optional[str] = None + ) -> Operation: + """ + Gets a particular Operation item from the OpenAPI specification given the path and method. + + :param path: The path of the operation. + :param path_item: The path item from the OpenAPI specification. + :param method: The HTTP method of the operation. + :returns: The operation + """ + if method: + operation_dict = path_item.get(method.lower(), {}) + if not operation_dict: + raise ValueError( + f"No operation found for method {method} at path {path}" + ) + return Operation(path, method.lower(), operation_dict, self.spec_dict) + if len(path_item) == 1: + method, operation_dict = next(iter(path_item.items())) + return Operation(path, method, operation_dict, self.spec_dict) + if len(path_item) > 1: + raise ValueError( + f"Multiple operations found at path {path}, method parameter is required." + ) + raise ValueError(f"No operations found at path {path} and method {method}") + + def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: + """ + Get the security schemes from the OpenAPI specification. + + :returns: The security schemes as a dictionary. + """ + return self.spec_dict.get("components", {}).get("securitySchemes", {}) diff --git a/test/components/tools/openapi/test_openapi_spec.py b/test/components/tools/openapi/test_openapi_spec.py index 4e38de2d..fce2bb98 100644 --- a/test/components/tools/openapi/test_openapi_spec.py +++ b/test/components/tools/openapi/test_openapi_spec.py @@ -26,7 +26,7 @@ def test_initialized_from_string(self): '200': description: Successful response """ - openapi_spec = OpenAPISpecification.from_str(content) + openapi_spec = OpenAPISpecification._from_str(content) assert openapi_spec.spec_dict == { "openapi": "3.0.0", "info": {"title": "Test API", "version": "1.0.0"}, From df71f3a287a0e7a73c591aa7239e03c995d1fc04 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 12:36:01 +0200 Subject: [PATCH 26/40] Fix header --- haystack_experimental/components/tools/openapi/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/haystack_experimental/components/tools/openapi/types.py b/haystack_experimental/components/tools/openapi/types.py index 863f204e..af04776e 100644 --- a/haystack_experimental/components/tools/openapi/types.py +++ b/haystack_experimental/components/tools/openapi/types.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import json from dataclasses import dataclass, field from enum import Enum From bfdcf94bfdcdb3000dcfd0e3b07326f76f808a38 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 14:34:50 +0200 Subject: [PATCH 27/40] PR feedback - details --- .../components/tools/__init__.py | 2 +- .../components/tools/openapi/_openapi.py | 22 +++++++++---------- .../components/tools/openapi/types.py | 6 ++--- .../tools/openapi/test_openapi_spec.py | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/haystack_experimental/components/tools/__init__.py b/haystack_experimental/components/tools/__init__.py index be1a6773..65434145 100644 --- a/haystack_experimental/components/tools/__init__.py +++ b/haystack_experimental/components/tools/__init__.py @@ -4,4 +4,4 @@ from .openai.function_caller import OpenAIFunctionCaller -_all_ = ["OpenAIFunctionCaller"] \ No newline at end of file +_all_ = ["OpenAIFunctionCaller"] diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py index d99dea86..34884ce9 100644 --- a/haystack_experimental/components/tools/openapi/_openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -69,7 +69,7 @@ def send_request(request: Dict[str, Any]) -> Dict[str, Any]: # Authentication strategies -def create_api_key_auth_function(api_key: str): +def create_api_key_auth_function(api_key: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]: """ Create a function that applies the API key authentication strategy to a given request. @@ -78,7 +78,7 @@ def create_api_key_auth_function(api_key: str): at the schema specified location. """ - def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]): + def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None: """ Apply the API key authentication strategy to the given request. @@ -100,7 +100,7 @@ def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]): return apply_auth -def create_http_auth_function(token: str): +def create_http_auth_function(token: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]: """ Create a function that applies the http authentication strategy to a given request. @@ -109,7 +109,7 @@ def create_http_auth_function(token: str): at the schema specified location. """ - def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]): + def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None: """ Apply the HTTP authentication strategy to the given request. @@ -165,7 +165,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments if is_valid_http_url(openapi_spec): self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) else: - self.openapi_spec = OpenAPISpecification._from_str(openapi_spec) + self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) else: raise ValueError( "Invalid OpenAPI specification format. Expected file path or dictionary." @@ -203,11 +203,11 @@ def get_tools_definitions(self) -> List[Dict[str, Any]]: provider_to_converter = defaultdict( lambda: openai_converter, { - LLMProvider.ANTHROPIC.value: anthropic_converter, - LLMProvider.COHERE.value: cohere_converter, + LLMProvider.ANTHROPIC: anthropic_converter, + LLMProvider.COHERE: cohere_converter, } ) - converter = provider_to_converter[self.llm_provider.value] + converter = provider_to_converter[self.llm_provider] return converter(self.openapi_spec) def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: @@ -220,11 +220,11 @@ def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: provider_to_arguments_field_name = defaultdict( lambda: "arguments", { - LLMProvider.ANTHROPIC.value: "input", - LLMProvider.COHERE.value: "parameters", + LLMProvider.ANTHROPIC: "input", + LLMProvider.COHERE: "parameters", } ) - arguments_field_name = provider_to_arguments_field_name[self.llm_provider.value] + arguments_field_name = provider_to_arguments_field_name[self.llm_provider] return create_function_payload_extractor(arguments_field_name) def _create_authentication_from_string( diff --git a/haystack_experimental/components/tools/openapi/types.py b/haystack_experimental/components/tools/openapi/types.py index af04776e..00f3c2aa 100644 --- a/haystack_experimental/components/tools/openapi/types.py +++ b/haystack_experimental/components/tools/openapi/types.py @@ -133,7 +133,7 @@ def __init__(self, spec_dict: Dict[str, Any]): self.spec_dict = spec_dict @classmethod - def _from_str(cls, content: str) -> "OpenAPISpecification": + def from_str(cls, content: str) -> "OpenAPISpecification": """ Create an OpenAPISpecification instance from a string. @@ -165,7 +165,7 @@ def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": """ with open(spec_file, encoding="utf-8") as file: content = file.read() - return cls._from_str(content) + return cls.from_str(content) @classmethod def from_url(cls, url: str) -> "OpenAPISpecification": @@ -184,7 +184,7 @@ def from_url(cls, url: str) -> "OpenAPISpecification": raise ConnectionError( f"Failed to fetch the specification from URL: {url}. {e!s}" ) from e - return cls._from_str(content) + return cls.from_str(content) def find_operation_by_id( self, op_id: str, method: Optional[str] = None diff --git a/test/components/tools/openapi/test_openapi_spec.py b/test/components/tools/openapi/test_openapi_spec.py index fce2bb98..4e38de2d 100644 --- a/test/components/tools/openapi/test_openapi_spec.py +++ b/test/components/tools/openapi/test_openapi_spec.py @@ -26,7 +26,7 @@ def test_initialized_from_string(self): '200': description: Successful response """ - openapi_spec = OpenAPISpecification._from_str(content) + openapi_spec = OpenAPISpecification.from_str(content) assert openapi_spec.spec_dict == { "openapi": "3.0.0", "info": {"title": "Test API", "version": "1.0.0"}, From c6cca91919713cab9d9aa1bbe529b800dadf31d6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 15:13:49 +0200 Subject: [PATCH 28/40] Lift up OpenAPISpecification --- .../components/tools/openapi/_openapi.py | 35 +++---------------- .../components/tools/openapi/openapi_tool.py | 20 +++++++++-- test/components/tools/openapi/conftest.py | 20 +++++++++-- .../tools/openapi/test_openapi_client.py | 8 ++--- .../tools/openapi/test_openapi_client_auth.py | 8 ++--- ...est_openapi_client_complex_request_body.py | 4 +-- ...enapi_client_complex_request_body_mixed.py | 4 +-- .../openapi/test_openapi_client_edge_cases.py | 4 +-- .../test_openapi_client_error_handling.py | 4 +-- .../tools/openapi/test_openapi_client_live.py | 5 +-- .../test_openapi_client_live_anthropic.py | 5 +-- .../test_openapi_client_live_cohere.py | 5 +-- .../test_openapi_client_live_openai.py | 7 ++-- 13 files changed, 69 insertions(+), 60 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py index 34884ce9..a19c9c58 100644 --- a/haystack_experimental/components/tools/openapi/_openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -3,11 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import logging -import os from collections import defaultdict -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union -from urllib.parse import urlparse +from typing import Any, Callable, Dict, List, Optional import requests @@ -25,17 +22,6 @@ logger = logging.getLogger(__name__) -def is_valid_http_url(url: str) -> bool: - """ - Check if a URL is a valid HTTP/HTTPS URL. - - :param url: The URL to check. - :returns: True if the URL is a valid HTTP/HTTPS URL, False otherwise. - """ - r = urlparse(url) - return all([r.scheme in ["http", "https"], r.netloc]) - - def send_request(request: Dict[str, Any]) -> Dict[str, Any]: """ Send an HTTP request and return the response. @@ -143,9 +129,9 @@ class HttpClientError(Exception): class ClientConfiguration: """Configuration for the OpenAPI client.""" - def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments + def __init__( self, - openapi_spec: Union[str, Path], + openapi_spec: OpenAPISpecification, credentials: Optional[str] = None, request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, llm_provider: Optional[LLMProvider] = None, @@ -153,24 +139,13 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments """ Initialize a ClientConfiguration instance. - :param openapi_spec: The OpenAPI specification as a file path, URL, or dictionary. + :param openapi_spec: The OpenAPI specification to use for the client. :param credentials: The credentials to use for authentication. :param request_sender: The function to use for sending requests. :param llm_provider: The LLM provider to use for generating tools definitions. :raises ValueError: If the OpenAPI specification format is invalid. """ - if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): - self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) - elif isinstance(openapi_spec, str): - if is_valid_http_url(openapi_spec): - self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) - else: - self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) - else: - raise ValueError( - "Invalid OpenAPI specification format. Expected file path or dictionary." - ) - + self.openapi_spec = openapi_spec self.credentials = credentials self.request_sender = request_sender or send_request self.llm_provider: LLMProvider = llm_provider or LLMProvider.OPENAI diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 17c98da4..90b372c6 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import json +import os from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -10,12 +11,13 @@ from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole from haystack.lazy_imports import LazyImport +from haystack.utils.url_validation import is_valid_http_url from haystack_experimental.components.tools.openapi._openapi import ( ClientConfiguration, OpenAPIServiceClient, ) -from haystack_experimental.components.tools.openapi.types import LLMProvider +from haystack_experimental.components.tools.openapi.types import LLMProvider, OpenAPISpecification with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import: # pylint: disable=import-error @@ -74,7 +76,7 @@ def __init__( self.open_api_service: Optional[OpenAPIServiceClient] = None if tool_spec: self.config_openapi = ClientConfiguration( - openapi_spec=tool_spec, + openapi_spec=self._create_openapi_spec(tool_spec), credentials=tool_credentials, llm_provider=generator_api ) @@ -112,7 +114,7 @@ def run( config_openapi: Optional[ClientConfiguration] = self.config_openapi if tool_spec: config_openapi = ClientConfiguration( - openapi_spec=tool_spec, + openapi_spec=self._create_openapi_spec(tool_spec), credentials=tool_credentials, llm_provider=self.generator_api, ) @@ -160,3 +162,15 @@ def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict anthropic_import.check() return AnthropicChatGenerator(**generator_api_params) raise ValueError(f"Unsupported generator API: {generator_api}") + + def _create_openapi_spec(self, openapi_spec: Union[Path, str]) -> OpenAPISpecification: + if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): + return OpenAPISpecification.from_file(openapi_spec) + if isinstance(openapi_spec, str): + if is_valid_http_url(openapi_spec): + return OpenAPISpecification.from_url(openapi_spec) + return OpenAPISpecification.from_str(openapi_spec) + + raise ValueError( + "Invalid OpenAPI specification format. Expected file path or dictionary." + ) diff --git a/test/components/tools/openapi/conftest.py b/test/components/tools/openapi/conftest.py index 89ec74d4..224a29f0 100644 --- a/test/components/tools/openapi/conftest.py +++ b/test/components/tools/openapi/conftest.py @@ -1,16 +1,18 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 - - +import os from pathlib import Path +from typing import Union from urllib.parse import urlparse import pytest from fastapi import FastAPI from fastapi.testclient import TestClient +from haystack.utils.url_validation import is_valid_http_url from haystack_experimental.components.tools.openapi._openapi import HttpClientError +from haystack_experimental.components.tools.openapi.types import OpenAPISpecification @pytest.fixture() @@ -18,6 +20,20 @@ def test_files_path(): return Path(__file__).parent.parent.parent.parent / "test_files" +def create_openapi_spec(openapi_spec: Union[Path, str]) -> OpenAPISpecification: + if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): + return OpenAPISpecification.from_file(openapi_spec) + elif isinstance(openapi_spec, str): + if is_valid_http_url(openapi_spec): + return OpenAPISpecification.from_url(openapi_spec) + else: + return OpenAPISpecification.from_str(openapi_spec) + else: + raise ValueError( + "Invalid OpenAPI specification format. Expected file path or dictionary." + ) + + class FastAPITestClient: def __init__(self, app: FastAPI): diff --git a/test/components/tools/openapi/test_openapi_client.py b/test/components/tools/openapi/test_openapi_client.py index c622642e..ad745125 100644 --- a/test/components/tools/openapi/test_openapi_client.py +++ b/test/components/tools/openapi/test_openapi_client.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration -from test.components.tools.openapi.conftest import FastAPITestClient +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec """ Tests OpenAPIServiceClient with three FastAPI apps for different parameter types: @@ -72,7 +72,7 @@ def greet_request_body(body: GreetBody): class TestOpenAPI: def test_greet_mix_params_body(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), request_sender=FastAPITestClient(create_greet_mix_params_body_app())) client = OpenAPIServiceClient(config) payload = { @@ -87,7 +87,7 @@ def test_greet_mix_params_body(self, test_files_path): assert response == {"greeting": "Bonjour, John from mix_params_body!"} def test_greet_params_only(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), request_sender=FastAPITestClient(create_greet_params_only_app())) client = OpenAPIServiceClient(config) payload = { @@ -102,7 +102,7 @@ def test_greet_params_only(self, test_files_path): assert response == {"greeting": "Hello, John from params_only!"} def test_greet_request_body_only(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), request_sender=FastAPITestClient(create_greet_request_body_only_app())) client = OpenAPIServiceClient(config) payload = { diff --git a/test/components/tools/openapi/test_openapi_client_auth.py b/test/components/tools/openapi/test_openapi_client_auth.py index ab6205e8..6a91bcdd 100644 --- a/test/components/tools/openapi/test_openapi_client_auth.py +++ b/test/components/tools/openapi/test_openapi_client_auth.py @@ -16,7 +16,7 @@ ) from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration -from test.components.tools.openapi.conftest import FastAPITestClient +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec API_KEY = "secret_api_key" BASIC_AUTH_USERNAME = "admin" @@ -137,7 +137,7 @@ def greet_oauth(name: str, token: HTTPAuthorizationCredentials = Depends(oauth_a class TestOpenAPIAuth: def test_greet_api_key_auth(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), request_sender=FastAPITestClient(create_greet_api_key_auth_app()), credentials=API_KEY) client = OpenAPIServiceClient(config) @@ -153,7 +153,7 @@ def test_greet_api_key_auth(self, test_files_path): assert response == {"greeting": "Hello, John from api_key_auth, using secret_api_key"} def test_greet_api_key_query_auth(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), request_sender=FastAPITestClient(create_greet_api_key_query_app()), credentials=API_KEY_QUERY) client = OpenAPIServiceClient(config) @@ -170,7 +170,7 @@ def test_greet_api_key_query_auth(self, test_files_path): def test_greet_api_key_cookie_auth(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml"), request_sender=FastAPITestClient(create_greet_api_key_cookie_app()), credentials=API_KEY_COOKIE) diff --git a/test/components/tools/openapi/test_openapi_client_complex_request_body.py b/test/components/tools/openapi/test_openapi_client_complex_request_body.py index 4b4b5a20..8ce2273e 100644 --- a/test/components/tools/openapi/test_openapi_client_complex_request_body.py +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body.py @@ -12,7 +12,7 @@ from pydantic import BaseModel from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration -from test.components.tools.openapi.conftest import FastAPITestClient +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec class Customer(BaseModel): @@ -58,7 +58,7 @@ class TestComplexRequestBody: def test_create_order(self, spec_file_path, test_files_path): path_element = "yaml" if spec_file_path.endswith(".yml") else "json" - config = ClientConfiguration(openapi_spec=test_files_path / path_element / spec_file_path, + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / path_element / spec_file_path), request_sender=FastAPITestClient(create_order_app())) client = OpenAPIServiceClient(config) diff --git a/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py b/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py index bcb5cf48..9624f7d7 100644 --- a/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py @@ -10,7 +10,7 @@ from pydantic import BaseModel from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration -from test.components.tools.openapi.conftest import FastAPITestClient +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec class Identification(BaseModel): @@ -56,7 +56,7 @@ def process_payment(payment: PaymentRequest): class TestPaymentProcess: def test_process_payment(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "json" / "complex_types_openapi_service.json", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "json" / "complex_types_openapi_service.json"), request_sender=FastAPITestClient(create_payment_app())) client = OpenAPIServiceClient(config) diff --git a/test/components/tools/openapi/test_openapi_client_edge_cases.py b/test/components/tools/openapi/test_openapi_client_edge_cases.py index 4dfe7a06..f6272baa 100644 --- a/test/components/tools/openapi/test_openapi_client_edge_cases.py +++ b/test/components/tools/openapi/test_openapi_client_edge_cases.py @@ -6,13 +6,13 @@ import pytest from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration -from test.components.tools.openapi.conftest import FastAPITestClient +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec class TestEdgeCases: def test_missing_operation_id(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_edge_cases.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml"), request_sender=FastAPITestClient(None)) client = OpenAPIServiceClient(config) diff --git a/test/components/tools/openapi/test_openapi_client_error_handling.py b/test/components/tools/openapi/test_openapi_client_error_handling.py index a1d730aa..c399c68d 100644 --- a/test/components/tools/openapi/test_openapi_client_error_handling.py +++ b/test/components/tools/openapi/test_openapi_client_error_handling.py @@ -10,7 +10,7 @@ from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, HttpClientError, \ ClientConfiguration -from test.components.tools.openapi.conftest import FastAPITestClient +from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec def create_error_handling_app() -> FastAPI: @@ -26,7 +26,7 @@ def raise_http_error(status_code: int): class TestErrorHandling: @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500]) def test_http_error_handling(self, test_files_path, status_code): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_error_handling.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_error_handling.yml"), request_sender=FastAPITestClient(create_error_handling_app())) client = OpenAPIServiceClient(config) json_error = {"status_code": status_code} diff --git a/test/components/tools/openapi/test_openapi_client_live.py b/test/components/tools/openapi/test_openapi_client_live.py index 02ae0b74..a248502b 100644 --- a/test/components/tools/openapi/test_openapi_client_live.py +++ b/test/components/tools/openapi/test_openapi_client_live.py @@ -8,6 +8,7 @@ import pytest import yaml from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration +from test.components.tools.openapi.conftest import create_openapi_spec class TestClientLive: @@ -15,7 +16,7 @@ class TestClientLive: @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") @pytest.mark.integration def test_serperdev(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", credentials=os.getenv("SERPERDEV_API_KEY")) + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "serper.yml"), credentials=os.getenv("SERPERDEV_API_KEY")) serper_api = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -30,7 +31,7 @@ def test_serperdev(self, test_files_path): @pytest.mark.integration def test_github(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml") + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml")) api = OpenAPIServiceClient(config) params = {"owner": "deepset-ai", "repo": "haystack", "basehead": "main...add_default_adapter_filters"} diff --git a/test/components/tools/openapi/test_openapi_client_live_anthropic.py b/test/components/tools/openapi/test_openapi_client_live_anthropic.py index 91ca6334..5467915a 100644 --- a/test/components/tools/openapi/test_openapi_client_live_anthropic.py +++ b/test/components/tools/openapi/test_openapi_client_live_anthropic.py @@ -9,6 +9,7 @@ from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient, \ LLMProvider +from test.components.tools.openapi.conftest import create_openapi_spec class TestClientLiveAnthropic: @@ -17,7 +18,7 @@ class TestClientLiveAnthropic: @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") @pytest.mark.integration def test_serperdev(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "serper.yml"), credentials=os.getenv("SERPERDEV_API_KEY"), llm_provider=LLMProvider.ANTHROPIC) client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) @@ -41,7 +42,7 @@ def test_serperdev(self, test_files_path): @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") @pytest.mark.integration def test_github(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml"), llm_provider=LLMProvider.ANTHROPIC) client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) diff --git a/test/components/tools/openapi/test_openapi_client_live_cohere.py b/test/components/tools/openapi/test_openapi_client_live_cohere.py index 891bb5fa..49e8cde6 100644 --- a/test/components/tools/openapi/test_openapi_client_live_cohere.py +++ b/test/components/tools/openapi/test_openapi_client_live_cohere.py @@ -8,6 +8,7 @@ from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient, \ LLMProvider +from test.components.tools.openapi.conftest import create_openapi_spec # Copied from Cohere's documentation preamble = """ @@ -29,7 +30,7 @@ class TestClientLiveCohere: @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") @pytest.mark.integration def test_serperdev(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "serper.yml"), credentials=os.getenv("SERPERDEV_API_KEY"), llm_provider=LLMProvider.COHERE) client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) @@ -53,7 +54,7 @@ def test_serperdev(self, test_files_path): @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") @pytest.mark.integration def test_github(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml"), llm_provider=LLMProvider.COHERE) client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) diff --git a/test/components/tools/openapi/test_openapi_client_live_openai.py b/test/components/tools/openapi/test_openapi_client_live_openai.py index c95f6614..716f04ea 100644 --- a/test/components/tools/openapi/test_openapi_client_live_openai.py +++ b/test/components/tools/openapi/test_openapi_client_live_openai.py @@ -8,6 +8,7 @@ from openai import OpenAI from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient +from test.components.tools.openapi.conftest import create_openapi_spec class TestClientLiveOpenAPI: @@ -17,7 +18,7 @@ class TestClientLiveOpenAPI: @pytest.mark.integration def test_serperdev(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "serper.yml"), credentials=os.getenv("SERPERDEV_API_KEY")) client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) response = client.chat.completions.create( @@ -39,7 +40,7 @@ def test_serperdev(self, test_files_path): @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") @pytest.mark.integration def test_github(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml") + config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml")) client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) response = client.chat.completions.create( model="gpt-3.5-turbo", @@ -61,7 +62,7 @@ def test_github(self, test_files_path): @pytest.mark.integration def test_firecrawl(self): openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" - config = ClientConfiguration(openapi_spec=openapi_spec_url, credentials=os.getenv("FIRECRAWL_API_KEY")) + config = ClientConfiguration(openapi_spec=create_openapi_spec(openapi_spec_url), credentials=os.getenv("FIRECRAWL_API_KEY")) client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) response = client.chat.completions.create( model="gpt-3.5-turbo", From 05e38f2d1b023a8c5a6143646706f36aa7098bc2 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 15:27:53 +0200 Subject: [PATCH 29/40] Update OpenAPITool --- .../components/tools/openapi/openapi_tool.py | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 90b372c6..7f412544 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -11,6 +11,7 @@ from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole from haystack.lazy_imports import LazyImport +from haystack.utils import Secret from haystack.utils.url_validation import is_valid_http_url from haystack_experimental.components.tools.openapi._openapi import ( @@ -57,27 +58,27 @@ class OpenAPITool: def __init__( self, generator_api: LLMProvider, - generator_api_params: Dict[str, Any], - tool_spec: Optional[Union[str, Path]] = None, - tool_credentials: Optional[str] = None, + generator_api_params: Optional[Dict[str, Any]] = None, + spec: Optional[Union[str, Path]] = None, + credentials: Optional[Secret] = None, ): """ Initialize the OpenAPITool component. :param generator_api: The API provider for the chat generator. :param generator_api_params: Parameters to pass for the chat generator creation. - :param tool_spec: OpenAPI specification for the tool/service. This can be a URL, a local file path, or + :param spec: OpenAPI specification for the tool/service. This can be a URL, a local file path, or an OpenAPI service specification provided as a string. - :param tool_credentials: Credentials for the tool/service. + :param credentials: Credentials for the tool/service. """ self.generator_api = generator_api - self.chat_generator = self._init_generator(generator_api, generator_api_params) + self.chat_generator = self._init_generator(generator_api, generator_api_params or {}) self.config_openapi: Optional[ClientConfiguration] = None self.open_api_service: Optional[OpenAPIServiceClient] = None - if tool_spec: + if spec: self.config_openapi = ClientConfiguration( - openapi_spec=self._create_openapi_spec(tool_spec), - credentials=tool_credentials, + openapi_spec=self._create_openapi_spec(spec), + credentials=credentials.resolve_value() if credentials else None, llm_provider=generator_api ) self.open_api_service = OpenAPIServiceClient(self.config_openapi) @@ -87,16 +88,18 @@ def run( self, messages: List[ChatMessage], fc_generator_kwargs: Optional[Dict[str, Any]] = None, - tool_spec: Optional[Union[str, Path]] = None, - tool_credentials: Optional[str] = None, + spec: Optional[Union[str, Path]] = None, + credentials: Optional[Secret] = None, ) -> Dict[str, List[ChatMessage]]: """ Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator. - :param messages: List of ChatMessages to generate function calling payload (e.g. human instructions). + :param messages: List of ChatMessages to generate function calling payload (e.g. human instructions). The last + message should be human instruction containing enough information to generate the function calling payload + suitable for the OpenAPI service/tool used. See the examples in the class docstring. :param fc_generator_kwargs: Additional arguments for the function calling payload generation process. - :param tool_spec: OpenAPI specification for the tool/service, overrides the one provided at initialization. - :param tool_credentials: Credentials for the tool/service, overrides the one provided at initialization. + :param spec: OpenAPI specification for the tool/service, overrides the one provided at initialization. + :param credentials: Credentials for the tool/service, overrides the one provided at initialization. :returns: a dictionary containing the service response with the following key: - `service_response`: List of ChatMessages containing the service response. ChatMessages are generated based on the response from the OpenAPI service/tool and contains the JSON response from the service. @@ -112,10 +115,10 @@ def run( # build a new ClientConfiguration and OpenAPIServiceClient if a runtime tool_spec is provided openapi_service: Optional[OpenAPIServiceClient] = self.open_api_service config_openapi: Optional[ClientConfiguration] = self.config_openapi - if tool_spec: + if spec: config_openapi = ClientConfiguration( - openapi_spec=self._create_openapi_spec(tool_spec), - credentials=tool_credentials, + openapi_spec=self._create_openapi_spec(spec), + credentials=credentials.resolve_value() if credentials else None, llm_provider=self.generator_api, ) openapi_service = OpenAPIServiceClient(config_openapi) From 5de700310472fb7b1cd0ac74efd76a79392025d5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 15:31:19 +0200 Subject: [PATCH 30/40] Final touches --- .../components/tools/openapi/openapi_tool.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 7f412544..2b5dd876 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -44,8 +44,8 @@ class OpenAPITool: tool = OpenAPITool(generator_api=LLMProvider.OPENAI, generator_api_params={"model":"gpt-3.5-turbo"}, - tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json", - tool_credentials="") + spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json", + credentials=Secret.from_token("")) results = tool.run(messages=[ChatMessage.from_user("Scrape URL: https://news.ycombinator.com/")]) print(results) @@ -167,6 +167,12 @@ def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict raise ValueError(f"Unsupported generator API: {generator_api}") def _create_openapi_spec(self, openapi_spec: Union[Path, str]) -> OpenAPISpecification: + """ + Create an OpenAPISpecification object from the provided OpenAPI specification. + + :param openapi_spec: OpenAPI specification for the tool/service. This can be a URL, a local file path, or + an OpenAPI service specification provided as a string. + """ if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): return OpenAPISpecification.from_file(openapi_spec) if isinstance(openapi_spec, str): From 32d8ca04fc0c853eb5542a0325f9e0e4c2177a3f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 15:32:57 +0200 Subject: [PATCH 31/40] Final touches - pydoc --- haystack_experimental/components/tools/openapi/openapi_tool.py | 1 + 1 file changed, 1 insertion(+) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 2b5dd876..d0bf7a48 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -41,6 +41,7 @@ class OpenAPITool: ```python from haystack.dataclasses import ChatMessage from haystack_experimental.components.tools.openapi import OpenAPITool, LLMProvider + from haystack.utils import Secret tool = OpenAPITool(generator_api=LLMProvider.OPENAI, generator_api_params={"model":"gpt-3.5-turbo"}, From 4f55c96606af473060c6af46fb64c3537a264d7d Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 15:57:57 +0200 Subject: [PATCH 32/40] Minor detail around ClientConfiguration LLMProvider setting --- haystack_experimental/components/tools/openapi/_openapi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py index a19c9c58..de3675fd 100644 --- a/haystack_experimental/components/tools/openapi/_openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -134,7 +134,7 @@ def __init__( openapi_spec: OpenAPISpecification, credentials: Optional[str] = None, request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, - llm_provider: Optional[LLMProvider] = None, + llm_provider: LLMProvider = LLMProvider.OPENAI, ): # noqa: PLR0913 """ Initialize a ClientConfiguration instance. @@ -148,7 +148,7 @@ def __init__( self.openapi_spec = openapi_spec self.credentials = credentials self.request_sender = request_sender or send_request - self.llm_provider: LLMProvider = llm_provider or LLMProvider.OPENAI + self.llm_provider: LLMProvider = llm_provider def get_auth_function(self) -> Callable[[Dict[str, Any], Dict[str, Any]], Any]: """ From dfcae5c258cf98dea5dd81c076d91ea9d68910dc Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 16:55:51 +0200 Subject: [PATCH 33/40] Make use of OpenAPISpecification explicit --- .../components/tools/openapi/openapi_tool.py | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index d0bf7a48..f8e72dcc 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -77,10 +77,17 @@ def __init__( self.config_openapi: Optional[ClientConfiguration] = None self.open_api_service: Optional[OpenAPIServiceClient] = None if spec: + if os.path.isfile(spec): + openapi_spec = OpenAPISpecification.from_file(spec) + elif is_valid_http_url(str(spec)): + openapi_spec = OpenAPISpecification.from_url(str(spec)) + else: + raise ValueError(f"Invalid OpenAPI specification source {spec}. Expected valid file path or URL") + self.config_openapi = ClientConfiguration( - openapi_spec=self._create_openapi_spec(spec), + openapi_spec=openapi_spec, credentials=credentials.resolve_value() if credentials else None, - llm_provider=generator_api + llm_provider=generator_api, ) self.open_api_service = OpenAPIServiceClient(self.config_openapi) @@ -117,8 +124,15 @@ def run( openapi_service: Optional[OpenAPIServiceClient] = self.open_api_service config_openapi: Optional[ClientConfiguration] = self.config_openapi if spec: + if os.path.isfile(spec): + openapi_spec = OpenAPISpecification.from_file(spec) + elif is_valid_http_url(str(spec)): + openapi_spec = OpenAPISpecification.from_url(str(spec)) + else: + raise ValueError(f"Invalid OpenAPI specification source {spec}. Expected valid file path or URL") + config_openapi = ClientConfiguration( - openapi_spec=self._create_openapi_spec(spec), + openapi_spec=openapi_spec, credentials=credentials.resolve_value() if credentials else None, llm_provider=self.generator_api, ) @@ -166,21 +180,3 @@ def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict anthropic_import.check() return AnthropicChatGenerator(**generator_api_params) raise ValueError(f"Unsupported generator API: {generator_api}") - - def _create_openapi_spec(self, openapi_spec: Union[Path, str]) -> OpenAPISpecification: - """ - Create an OpenAPISpecification object from the provided OpenAPI specification. - - :param openapi_spec: OpenAPI specification for the tool/service. This can be a URL, a local file path, or - an OpenAPI service specification provided as a string. - """ - if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): - return OpenAPISpecification.from_file(openapi_spec) - if isinstance(openapi_spec, str): - if is_valid_http_url(openapi_spec): - return OpenAPISpecification.from_url(openapi_spec) - return OpenAPISpecification.from_str(openapi_spec) - - raise ValueError( - "Invalid OpenAPI specification format. Expected file path or dictionary." - ) From 503e71f0047e93400b99bea845a8a1140c4755f8 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 17:21:08 +0200 Subject: [PATCH 34/40] First batch of OpenAPITool unit and integration tests --- .../tools/openapi/test_openapi_tool.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 test/components/tools/openapi/test_openapi_tool.py diff --git a/test/components/tools/openapi/test_openapi_tool.py b/test/components/tools/openapi/test_openapi_tool.py new file mode 100644 index 00000000..b10aaa2f --- /dev/null +++ b/test/components/tools/openapi/test_openapi_tool.py @@ -0,0 +1,124 @@ +import json +import os + +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage +from haystack.utils import Secret + +from haystack_experimental.components.tools.openapi import LLMProvider +from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool + +import pytest + + +class TestOpenAPITool: + + def test_initialize_with_valid_openapi_spec_url_and_credentials(self): + openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" + credentials = Secret.from_token("") + tool = OpenAPITool( + generator_api=LLMProvider.OPENAI, + generator_api_params={ + "model": "gpt-3.5-turbo", + "api_key": Secret.from_token("not_needed"), + }, + spec=openapi_spec_url, + credentials=credentials, + ) + + assert tool.generator_api == LLMProvider.OPENAI + assert isinstance(tool.chat_generator, OpenAIChatGenerator) + assert tool.config_openapi is not None + assert tool.open_api_service is not None + + @pytest.mark.skipif( + "SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set" + ) + @pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set" + ) + @pytest.mark.integration + def test_run_live_openai(self): + tool = OpenAPITool( + generator_api=LLMProvider.OPENAI, + spec="https://bit.ly/serper_dev_spec_yaml", + credentials=Secret.from_env_var("SERPERDEV_API_KEY"), + ) + + user_message = ChatMessage.from_user( + "Scrape URL: https://news.ycombinator.com/" + ) + + results = tool.run(messages=[user_message]) + + assert isinstance(results["service_response"], list) + assert len(results["service_response"]) == 1 + assert isinstance(results["service_response"][0], ChatMessage) + + try: + json_response = json.loads(results["service_response"][0].content) + assert isinstance(json_response, dict) + except json.JSONDecodeError: + pytest.fail("Response content is not valid JSON") + + @pytest.mark.skipif( + "SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set" + ) + @pytest.mark.skipif( + "ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set" + ) + @pytest.mark.integration + def test_run_live_anthropic(self): + tool = OpenAPITool( + generator_api=LLMProvider.ANTHROPIC, + generator_api_params={"model": "claude-3-opus-20240229"}, + spec="https://bit.ly/serper_dev_spec_yaml", + credentials=Secret.from_env_var("SERPERDEV_API_KEY"), + ) + + user_message = ChatMessage.from_user( + "Scrape URL: https://news.ycombinator.com/" + ) + + results = tool.run(messages=[user_message]) + + assert isinstance(results["service_response"], list) + assert len(results["service_response"]) == 1 + assert isinstance(results["service_response"][0], ChatMessage) + + try: + json_response = json.loads(results["service_response"][0].content) + assert isinstance(json_response, dict) + except json.JSONDecodeError: + pytest.fail("Response content is not valid JSON") + + @pytest.mark.skipif( + "SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set" + ) + @pytest.mark.skipif( + "COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set" + ) + @pytest.mark.integration + def test_run_live_cohere(self): + tool = OpenAPITool( + generator_api=LLMProvider.COHERE, + generator_api_params={"model": "command-r"}, + spec="https://bit.ly/serper_dev_spec_yaml", + credentials=Secret.from_env_var("SERPERDEV_API_KEY"), + ) + + user_message = ChatMessage.from_user( + "Scrape URL: https://news.ycombinator.com/" + ) + + results = tool.run(messages=[user_message]) + + assert isinstance(results["service_response"], list) + assert len(results["service_response"]) == 1 + assert isinstance(results["service_response"][0], ChatMessage) + + try: + json_response = json.loads(results["service_response"][0].content) + assert isinstance(json_response, dict) + except json.JSONDecodeError: + pytest.fail("Response content is not valid JSON") From d371dba7aff72f1a30de474d3237179529f9deb4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 18:22:48 +0200 Subject: [PATCH 35/40] Add serde and unit tests --- .../components/tools/openapi/openapi_tool.py | 43 +++++++++++++- .../tools/openapi/test_openapi_tool.py | 57 +++++++++++++++++++ 2 files changed, 97 insertions(+), 3 deletions(-) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index f8e72dcc..19d24f70 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -7,11 +7,11 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union -from haystack import component, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole from haystack.lazy_imports import LazyImport -from haystack.utils import Secret +from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.url_validation import is_valid_http_url from haystack_experimental.components.tools.openapi._openapi import ( @@ -73,6 +73,7 @@ def __init__( :param credentials: Credentials for the tool/service. """ self.generator_api = generator_api + self.generator_api_params = generator_api_params or {} # store the generator API parameters for serialization self.chat_generator = self._init_generator(generator_api, generator_api_params or {}) self.config_openapi: Optional[ClientConfiguration] = None self.open_api_service: Optional[OpenAPIServiceClient] = None @@ -83,7 +84,8 @@ def __init__( openapi_spec = OpenAPISpecification.from_url(str(spec)) else: raise ValueError(f"Invalid OpenAPI specification source {spec}. Expected valid file path or URL") - + self.spec = spec # store the spec for serialization + self.credentials = credentials # store the credentials for serialization self.config_openapi = ClientConfiguration( openapi_spec=openapi_spec, credentials=credentials.resolve_value() if credentials else None, @@ -167,6 +169,41 @@ def run( return {"service_response": response_messages} + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + if "api_key" in self.generator_api_params: + self.generator_api_params["api_key"] = self.generator_api_params["api_key"].to_dict() + + return default_to_dict( + self, + generator_api=self.generator_api.value, + generator_api_params=self.generator_api_params, + spec=self.spec, + credentials=self.credentials.to_dict() if self.credentials else None, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OpenAPITool": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["credentials"]) + if "generator_api_params" in data["init_parameters"]: + deserialize_secrets_inplace(data["init_parameters"]["generator_api_params"], keys=["api_key"]) + init_params = data.get("init_parameters", {}) + generator_api = init_params.get("generator_api") + data["init_parameters"]["generator_api"] = LLMProvider(generator_api) + return default_from_dict(cls, data) + def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict[str, Any]): """ Initialize the chat generator based on the specified API provider and parameters. diff --git a/test/components/tools/openapi/test_openapi_tool.py b/test/components/tools/openapi/test_openapi_tool.py index b10aaa2f..65358616 100644 --- a/test/components/tools/openapi/test_openapi_tool.py +++ b/test/components/tools/openapi/test_openapi_tool.py @@ -13,6 +13,63 @@ class TestOpenAPITool: + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + monkeypatch.setenv("SERPERDEV_API_KEY", "fake-api-key") + + openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" + + tool = OpenAPITool( + generator_api=LLMProvider.OPENAI, + generator_api_params={ + "model": "gpt-3.5-turbo", + "api_key": Secret.from_env_var("OPENAI_API_KEY"), + }, + spec=openapi_spec_url, + credentials=Secret.from_env_var("SERPERDEV_API_KEY"), + ) + + data = tool.to_dict() + assert data == { + "type": "haystack_experimental.components.tools.openapi.openapi_tool.OpenAPITool", + "init_parameters": { + "generator_api": "openai", + "generator_api_params": { + "model": "gpt-3.5-turbo", + "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, + }, + "spec": openapi_spec_url, + "credentials": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"}, + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + monkeypatch.setenv("SERPERDEV_API_KEY", "fake-api-key") + openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" + data = { + "type": "haystack_experimental.components.tools.openapi.openapi_tool.OpenAPITool", + "init_parameters": { + "generator_api": "openai", + "generator_api_params": { + "model": "gpt-3.5-turbo", + "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}, + }, + "spec": openapi_spec_url, + "credentials": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"}, + }, + } + + tool = OpenAPITool.from_dict(data) + + assert tool.generator_api == LLMProvider.OPENAI + assert tool.generator_api_params == { + "model": "gpt-3.5-turbo", + "api_key": Secret.from_env_var("OPENAI_API_KEY") + } + assert tool.spec == openapi_spec_url + assert tool.credentials == Secret.from_env_var("SERPERDEV_API_KEY") + def test_initialize_with_valid_openapi_spec_url_and_credentials(self): openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" credentials = Secret.from_token("") From e5c76c7f34855bca13769817829d8976ee397f74 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 23:49:44 +0200 Subject: [PATCH 36/40] Skip github test --- test/components/tools/openapi/test_openapi_client_live_openai.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/components/tools/openapi/test_openapi_client_live_openai.py b/test/components/tools/openapi/test_openapi_client_live_openai.py index 716f04ea..8be05a79 100644 --- a/test/components/tools/openapi/test_openapi_client_live_openai.py +++ b/test/components/tools/openapi/test_openapi_client_live_openai.py @@ -39,6 +39,7 @@ def test_serperdev(self, test_files_path): @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") @pytest.mark.integration + @pytest.mark.skip("This test hits rate limit on Github API. Skip for now.") def test_github(self, test_files_path): config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml")) client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) From ecdc37d9c738d8b36ddf99576d9cbf578036aca4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 18 Jun 2024 23:55:56 +0200 Subject: [PATCH 37/40] Skip github tests --- test/components/tools/openapi/test_openapi_client_live.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/components/tools/openapi/test_openapi_client_live.py b/test/components/tools/openapi/test_openapi_client_live.py index a248502b..3c3179d5 100644 --- a/test/components/tools/openapi/test_openapi_client_live.py +++ b/test/components/tools/openapi/test_openapi_client_live.py @@ -30,6 +30,7 @@ def test_serperdev(self, test_files_path): assert "invention" in str(response) @pytest.mark.integration + @pytest.mark.skip("This test hits rate limit on Github API. Skip for now.") def test_github(self, test_files_path): config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "github_compare.yml")) api = OpenAPIServiceClient(config) From a5db73bbdfabc96cf606c65b55b45ee6ff28249c Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 19 Jun 2024 06:08:01 +0200 Subject: [PATCH 38/40] Increase default request timeout to 30 sec --- haystack_experimental/components/tools/openapi/_openapi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack_experimental/components/tools/openapi/_openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py index de3675fd..2b13f70e 100644 --- a/haystack_experimental/components/tools/openapi/_openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -39,7 +39,7 @@ def send_request(request: Dict[str, Any]) -> Dict[str, Any]: params=request.get("params", {}), json=request.get("json"), auth=request.get("auth"), - timeout=10, + timeout=30, ) response.raise_for_status() return response.json() From 37e6e541c5851a605d718de01018b8c34c29ff06 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 19 Jun 2024 12:07:36 +0200 Subject: [PATCH 39/40] PR review --- .../components/tools/openapi/openapi_tool.py | 14 ++++------ .../components/tools/openapi/types.py | 15 ++++++++++ haystack_experimental/util/__init__.py | 4 +++ haystack_experimental/util/auth.py | 25 +++++++++++++++++ .../tools/openapi/test_openapi_tool.py | 28 +++++++++++++++++-- 5 files changed, 75 insertions(+), 11 deletions(-) create mode 100644 haystack_experimental/util/auth.py diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index 19d24f70..33c64d4a 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -19,6 +19,7 @@ OpenAPIServiceClient, ) from haystack_experimental.components.tools.openapi.types import LLMProvider, OpenAPISpecification +from haystack_experimental.util import serialize_secrets_inplace with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import: # pylint: disable=import-error @@ -77,6 +78,8 @@ def __init__( self.chat_generator = self._init_generator(generator_api, generator_api_params or {}) self.config_openapi: Optional[ClientConfiguration] = None self.open_api_service: Optional[OpenAPIServiceClient] = None + self.spec = spec # store the spec for serialization + self.credentials = credentials # store the credentials for serialization if spec: if os.path.isfile(spec): openapi_spec = OpenAPISpecification.from_file(spec) @@ -84,8 +87,6 @@ def __init__( openapi_spec = OpenAPISpecification.from_url(str(spec)) else: raise ValueError(f"Invalid OpenAPI specification source {spec}. Expected valid file path or URL") - self.spec = spec # store the spec for serialization - self.credentials = credentials # store the credentials for serialization self.config_openapi = ClientConfiguration( openapi_spec=openapi_spec, credentials=credentials.resolve_value() if credentials else None, @@ -176,9 +177,7 @@ def to_dict(self) -> Dict[str, Any]: :returns: The serialized component as a dictionary. """ - if "api_key" in self.generator_api_params: - self.generator_api_params["api_key"] = self.generator_api_params["api_key"].to_dict() - + serialize_secrets_inplace(self.generator_api_params, keys=["api_key"], recursive=True) return default_to_dict( self, generator_api=self.generator_api.value, @@ -197,11 +196,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAPITool": The deserialized component instance. """ deserialize_secrets_inplace(data["init_parameters"], keys=["credentials"]) - if "generator_api_params" in data["init_parameters"]: - deserialize_secrets_inplace(data["init_parameters"]["generator_api_params"], keys=["api_key"]) + deserialize_secrets_inplace(data["init_parameters"]["generator_api_params"], keys=["api_key"]) init_params = data.get("init_parameters", {}) generator_api = init_params.get("generator_api") - data["init_parameters"]["generator_api"] = LLMProvider(generator_api) + data["init_parameters"]["generator_api"] = LLMProvider.from_str(generator_api) return default_from_dict(cls, data) def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict[str, Any]): diff --git a/haystack_experimental/components/tools/openapi/types.py b/haystack_experimental/components/tools/openapi/types.py index 00f3c2aa..2562daa2 100644 --- a/haystack_experimental/components/tools/openapi/types.py +++ b/haystack_experimental/components/tools/openapi/types.py @@ -31,6 +31,21 @@ class LLMProvider(Enum): ANTHROPIC = "anthropic" COHERE = "cohere" + @staticmethod + def from_str(string: str) -> "LLMProvider": + """ + Convert a string to a LLMProvider enum. + """ + provider_map = {e.value: e for e in LLMProvider} + provider = provider_map.get(string) + if provider is None: + msg = ( + f"Invalid LLMProvider '{string}'" + f"Supported LLMProviders are: {list(provider_map.keys())}" + ) + raise ValueError(msg) + return provider + @dataclass class Operation: diff --git a/haystack_experimental/util/__init__.py b/haystack_experimental/util/__init__.py index c1764a6e..032a10bc 100644 --- a/haystack_experimental/util/__init__.py +++ b/haystack_experimental/util/__init__.py @@ -1,3 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.util.auth import serialize_secrets_inplace + +__all__ = ["serialize_secrets_inplace"] diff --git a/haystack_experimental/util/auth.py b/haystack_experimental/util/auth.py new file mode 100644 index 00000000..4db0d3ef --- /dev/null +++ b/haystack_experimental/util/auth.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, Iterable + +from haystack.utils import Secret + + +def serialize_secrets_inplace(data: Dict[str, Any], keys: Iterable[str], *, recursive: bool = False) -> None: + """ + Serialize secrets in a dictionary inplace. + + :param data: + The dictionary with the data containing secrets. + :param keys: + The keys of the secrets to serialize. + :param recursive: + Whether to recursively serialize nested dictionaries. + """ + for k, v in data.items(): + if isinstance(v, dict) and recursive: + serialize_secrets_inplace(v, keys, recursive=True) + elif k in keys and isinstance(v, Secret): + data[k] = v.to_dict() diff --git a/test/components/tools/openapi/test_openapi_tool.py b/test/components/tools/openapi/test_openapi_tool.py index 65358616..5119ba61 100644 --- a/test/components/tools/openapi/test_openapi_tool.py +++ b/test/components/tools/openapi/test_openapi_tool.py @@ -70,6 +70,28 @@ def test_from_dict(self, monkeypatch): assert tool.spec == openapi_spec_url assert tool.credentials == Secret.from_env_var("SERPERDEV_API_KEY") + def test_initialize_with_invalid_openapi_spec_url(self): + with pytest.raises(ConnectionError, match="Failed to fetch the specification from URL"): + OpenAPITool( + generator_api=LLMProvider.OPENAI, + generator_api_params={ + "model": "gpt-3.5-turbo", + "api_key": Secret.from_token("not_needed"), + }, + spec="https://raw.githubusercontent.com/invalid_openapi.json", + ) + + def test_initialize_with_invalid_openapi_spec_path(self): + with pytest.raises(ValueError, match="Invalid OpenAPI specification source"): + OpenAPITool( + generator_api=LLMProvider.OPENAI, + generator_api_params={ + "model": "gpt-3.5-turbo", + "api_key": Secret.from_token("not_needed"), + }, + spec="invalid_openapi.json", + ) + def test_initialize_with_valid_openapi_spec_url_and_credentials(self): openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" credentials = Secret.from_token("") @@ -103,7 +125,7 @@ def test_run_live_openai(self): ) user_message = ChatMessage.from_user( - "Scrape URL: https://news.ycombinator.com/" + "Search for 'Who was Nikola Tesla?'" ) results = tool.run(messages=[user_message]) @@ -134,7 +156,7 @@ def test_run_live_anthropic(self): ) user_message = ChatMessage.from_user( - "Scrape URL: https://news.ycombinator.com/" + "Search for 'Who was Nikola Tesla?'" ) results = tool.run(messages=[user_message]) @@ -165,7 +187,7 @@ def test_run_live_cohere(self): ) user_message = ChatMessage.from_user( - "Scrape URL: https://news.ycombinator.com/" + "Search for 'Who was Nikola Tesla?'" ) results = tool.run(messages=[user_message]) From 7eb140f107eb8dd4c874c8b90a57be2fe6f148a5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 24 Jun 2024 19:09:35 +0200 Subject: [PATCH 40/40] Add to Experiments catalog --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6af9aedf..a3ef99b2 100644 --- a/README.md +++ b/README.md @@ -37,10 +37,12 @@ the experiment will be either: The latest version of the package contains the following experiments: | Name | Type | Experiment end date | -| ------------------------ | ----------------------- | ------------------- | +|--------------------------|-------------------------| ------------------- | | [`EvaluationHarness`][1] | Evaluation orchestrator | August 2024 | +| [`OpenAPITool`][2] | OpenAPITool component | August 2024 | [1]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/evaluation/harness +[2]: https://github.com/deepset-ai/haystack-experimental/tree/main/haystack_experimental/components/tools/openapi ## Usage