From 42525a3f237a903366411f4504df3af3329b56e3 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 28 May 2024 09:36:28 +0200 Subject: [PATCH] Initial openapi impl --- .../components/connectors/__init__.py | 7 + .../components/connectors/openapi.py | 125 +++ .../components/converters/__init__.py | 7 + .../components/converters/openapi.py | 83 ++ haystack_experimental/util/__init__.py | 3 + haystack_experimental/util/openapi.py | 968 ++++++++++++++++++ pyproject.toml | 5 +- test/components/__init__.py | 3 + test/components/connectors/__init__.py | 3 + test/components/connectors/test_openapi.py | 114 +++ test/components/converters/__init__.py | 3 + test/components/converters/test_openapi.py | 246 +++++ .../json/complex_types_openai_spec.json | 64 ++ .../json/complex_types_openapi_service.json | 103 ++ .../json/openapi_order_service.json | 109 ++ .../json/serperdev_openapi_spec.json | 62 ++ test/test_files/yaml/github_compare.yml | 438 ++++++++ test/test_files/yaml/openapi_edge_cases.yml | 13 + .../yaml/openapi_error_handling.yml | 24 + .../yaml/openapi_greeting_service.yml | 272 +++++ .../test_files/yaml/openapi_order_service.yml | 75 ++ test/test_files/yaml/serper.yml | 39 + test/util/__init__.py | 3 + test/util/conftest.py | 59 ++ test/util/test_openapi_client.py | 129 +++ test/util/test_openapi_client_auth.py | 238 +++++ ...est_openapi_client_complex_request_body.py | 87 ++ ...enapi_client_complex_request_body_mixed.py | 90 ++ test/util/test_openapi_client_edge_cases.py | 38 + .../test_openapi_client_error_handling.py | 46 + test/util/test_openapi_client_live.py | 73 ++ .../test_openapi_client_live_anthropic.py | 69 ++ test/util/test_openapi_client_live_cohere.py | 73 ++ test/util/test_openapi_client_live_openai.py | 99 ++ test/util/test_openapi_cohere_conversion.py | 79 ++ test/util/test_openapi_openai_conversion.py | 104 ++ test/util/test_openapi_spec.py | 127 +++ 37 files changed, 4079 insertions(+), 1 deletion(-) create mode 100644 haystack_experimental/components/connectors/__init__.py create mode 100644 haystack_experimental/components/connectors/openapi.py create mode 100644 haystack_experimental/components/converters/__init__.py create mode 100644 haystack_experimental/components/converters/openapi.py create mode 100644 haystack_experimental/util/__init__.py create mode 100644 haystack_experimental/util/openapi.py create mode 100644 test/components/__init__.py create mode 100644 test/components/connectors/__init__.py create mode 100644 test/components/connectors/test_openapi.py create mode 100644 test/components/converters/__init__.py create mode 100644 test/components/converters/test_openapi.py create mode 100644 test/test_files/json/complex_types_openai_spec.json create mode 100644 test/test_files/json/complex_types_openapi_service.json create mode 100644 test/test_files/json/openapi_order_service.json create mode 100644 test/test_files/json/serperdev_openapi_spec.json create mode 100644 test/test_files/yaml/github_compare.yml create mode 100644 test/test_files/yaml/openapi_edge_cases.yml create mode 100644 test/test_files/yaml/openapi_error_handling.yml create mode 100644 test/test_files/yaml/openapi_greeting_service.yml create mode 100644 test/test_files/yaml/openapi_order_service.yml create mode 100644 test/test_files/yaml/serper.yml create mode 100644 test/util/__init__.py create mode 100644 test/util/conftest.py create mode 100644 test/util/test_openapi_client.py create mode 100644 test/util/test_openapi_client_auth.py create mode 100644 test/util/test_openapi_client_complex_request_body.py create mode 100644 test/util/test_openapi_client_complex_request_body_mixed.py create mode 100644 test/util/test_openapi_client_edge_cases.py create mode 100644 test/util/test_openapi_client_error_handling.py create mode 100644 test/util/test_openapi_client_live.py create mode 100644 test/util/test_openapi_client_live_anthropic.py create mode 100644 test/util/test_openapi_client_live_cohere.py create mode 100644 test/util/test_openapi_client_live_openai.py create mode 100644 test/util/test_openapi_cohere_conversion.py create mode 100644 test/util/test_openapi_openai_conversion.py create mode 100644 test/util/test_openapi_spec.py diff --git a/haystack_experimental/components/connectors/__init__.py b/haystack_experimental/components/connectors/__init__.py new file mode 100644 index 00000000..9c3d11ec --- /dev/null +++ b/haystack_experimental/components/connectors/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.components.connectors.openapi import OpenAPIServiceConnector + +__all__ = ["OpenAPIServiceConnector"] diff --git a/haystack_experimental/components/connectors/openapi.py b/haystack_experimental/components/connectors/openapi.py new file mode 100644 index 00000000..a9afce5c --- /dev/null +++ b/haystack_experimental/components/connectors/openapi.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Any, Dict, List, Optional, Union + +from haystack import component, logging +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, validate_provider + +logger = logging.getLogger(__name__) + + +@component +class OpenAPIServiceConnector: + """ + The `OpenAPIServiceConnector` component connects the Haystack framework to OpenAPI services. + + It integrates with `ChatMessage` dataclass, where the payload in messages is used to determine the method to be + called and the parameters to be passed. The response from the service is returned as a `ChatMessage`. + + Function calling payloads from OpenAI, Anthropic, and Cohere LLMs are supported. + + Before using this component, users usually resolve function calling function definitions with a help of + `OpenAPIServiceToFunctions` component. + + The example below demonstrates how to use the `OpenAPIServiceConnector` to invoke a method on a + https://serper.dev/ service specified via OpenAPI specification. + + Note, however, that `OpenAPIServiceConnector` is usually not meant to be used directly, but rather as part of a + pipeline that includes the `OpenAPIServiceToFunctions` component and an `OpenAIChatGenerator` component using LLM + with the function calling capabilities. In the example below we use the function calling payload directly, but in a + real-world scenario, the function calling payload would usually be generated by the `OpenAIChatGenerator` + component. + + Usage example: + + ```python + import json + import requests + + from haystack_experimental.components.connectors import OpenAPIServiceConnector + from haystack.dataclasses import ChatMessage + + + fc_payload = [{'function': {'arguments': '{"q": "Why was Sam Altman ousted from OpenAI?"}', 'name': 'search'}, + 'id': 'call_PmEBYvZ7mGrQP5PUASA5m9wO', 'type': 'function'}] + + serper_token = + serperdev_openapi_spec = json.loads(requests.get("https://bit.ly/serper_dev_spec").text) + service_connector = OpenAPIServiceConnector() + result = service_connector.run(messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], + service_openapi_spec=serperdev_openapi_spec, service_credentials=serper_token) + print(result) + + >> {'service_response': [ChatMessage(content='{"searchParameters": {"q": "Why was Sam Altman ousted from OpenAI?", + >> "type": "search", "engine": "google"}, "answerBox": {"snippet": "Concerns over AI safety and OpenAI\'s role + >> in protecting were at the center of Altman\'s brief ouster from the company."... + ``` + + """ + + def __init__(self, provider: Optional[str] = None): + """ + Initializes the OpenAPIServiceConnector instance. + """ + self.llm_provider = validate_provider(provider or "openai") + + @component.output_types(service_response=Dict[str, Any]) + def run( + self, + messages: List[ChatMessage], + service_openapi_spec: Dict[str, Any], + service_credentials: Optional[Union[dict, str]] = None, + ) -> Dict[str, List[ChatMessage]]: + """ + Processes a list of chat messages to invoke a method on an OpenAPI service. + + It parses the last message in the list, expecting it to contain an OpenAI function calling descriptor + (name & parameters) in JSON format. + + :param messages: A list of `ChatMessage` objects containing the messages to be processed. The last message + should contain the function invocation payload in OpenAI function calling format. See the example in the class + docstring for the expected format. + :param service_openapi_spec: The OpenAPI JSON specification object of the service to be invoked. + :param service_credentials: The credentials to be used for authentication with the service. + Currently, only the http and apiKey OpenAPI security schemes are supported. + + :return: A dictionary with the following keys: + - `service_response`: a list of `ChatMessage` objects, each containing the response from the service. The + response is in JSON format, and the `content` attribute of the `ChatMessage` + contains the JSON string. + + :raises ValueError: If the last message is not from the assistant or if it does not contain the correct + payload to invoke a method on the service. + """ + + last_message = messages[-1] + if not last_message.is_from(ChatRole.ASSISTANT): + raise ValueError(f"{last_message} is not from the assistant.") + if not last_message.content: + raise ValueError("Function calling message content is empty.") + + builder = ClientConfigurationBuilder() + config_openapi = ( + builder.with_openapi_spec(service_openapi_spec) + .with_credentials(service_credentials or {}) + .with_provider(self.llm_provider) + .build() + ) + logger.debug(f"Invoking service {config_openapi.get_openapi_spec().get_name()} with {last_message.content}") + openapi_service = OpenAPIServiceClient(config_openapi) + try: + payload = ( + json.loads(last_message.content) if isinstance(last_message.content, str) else last_message.content + ) + service_response = openapi_service.invoke(payload) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error(f"Error invoking OpenAPI endpoint. Error: {e}") + service_response = {"error": str(e)} + response_messages = [ChatMessage.from_user(json.dumps(service_response))] + + return {"service_response": response_messages} diff --git a/haystack_experimental/components/converters/__init__.py b/haystack_experimental/components/converters/__init__.py new file mode 100644 index 00000000..6fd23f7d --- /dev/null +++ b/haystack_experimental/components/converters/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.components.converters.openapi import OpenAPIServiceToFunctions + +__all__ = ["OpenAPIServiceToFunctions"] diff --git a/haystack_experimental/components/converters/openapi.py b/haystack_experimental/components/converters/openapi.py new file mode 100644 index 00000000..17d0eb77 --- /dev/null +++ b/haystack_experimental/components/converters/openapi.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from haystack import component, logging +from haystack.dataclasses.byte_stream import ByteStream + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, validate_provider + +logger = logging.getLogger(__name__) + + +@component +class OpenAPIServiceToFunctions: + """ + Converts OpenAPI service schemas to a format suitable for OpenAI, Anthropic, or Cohere function calling. + + The definition must respect OpenAPI specification 3.0.0 or higher. + It can be specified in JSON or YAML format. + Each function must have: + - unique operationId + - description + - requestBody and/or parameters + - schema for the requestBody and/or parameters + For more details on OpenAPI specification see the + [official documentation](https://github.com/OAI/OpenAPI-Specification). + + Usage example: + ```python + from haystack_experimental.components.converters import OpenAPIServiceToFunctions + + converter = OpenAPIServiceToFunctions() + result = converter.run(sources=["path/to/openapi_definition.yaml"]) + assert result["functions"] + ``` + """ + + MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 + + def __init__(self, provider: Optional[str] = None): + """ + Create an OpenAPIServiceToFunctions component. + + :param provider: The LLM provider to use, defaults to "openai". + """ + self.llm_provider = validate_provider(provider or "openai") + + @component.output_types(functions=List[Dict[str, Any]], openapi_specs=List[Dict[str, Any]]) + def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, Any]: + """ + Converts OpenAPI definitions into LLM specific function calling format. + + :param sources: + File paths or ByteStream objects of OpenAPI definitions (in JSON or YAML format). + + :returns: + A dictionary with the following keys: + - functions: Function definitions in JSON object format + - openapi_specs: OpenAPI specs in JSON/YAML object format with resolved references + + :raises RuntimeError: + If the OpenAPI definitions cannot be downloaded or processed. + :raises ValueError: + If the source type is not recognized or no functions are found in the OpenAPI definitions. + """ + all_extracted_fc_definitions: List[Dict[str, Any]] = [] + all_openapi_specs = [] + + builder = ClientConfigurationBuilder() + for source in sources: + source = source.to_string() if isinstance(source, ByteStream) else source + # to get tools definitions all we need is the openapi spec + config_openapi = builder.with_openapi_spec(source).with_provider(self.llm_provider).build() + + all_extracted_fc_definitions.extend(config_openapi.get_tools_definitions()) + all_openapi_specs.append(config_openapi.get_openapi_spec().to_dict(resolve_references=True)) + if not all_extracted_fc_definitions: + logger.warning("No OpenAI function definitions extracted from the provided OpenAPI specification sources.") + + return {"functions": all_extracted_fc_definitions, "openapi_specs": all_openapi_specs} diff --git a/haystack_experimental/util/__init__.py b/haystack_experimental/util/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/haystack_experimental/util/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/haystack_experimental/util/openapi.py b/haystack_experimental/util/openapi.py new file mode 100644 index 00000000..552892ca --- /dev/null +++ b/haystack_experimental/util/openapi.py @@ -0,0 +1,968 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import json +import logging +import os +from base64 import b64encode +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Union, runtime_checkable +from urllib.parse import urlparse + +import jsonref +import requests +import yaml +from requests.adapters import HTTPAdapter +from urllib3 import Retry + +VALID_HTTP_METHODS = ["get", "put", "post", "delete", "options", "head", "patch", "trace"] + +MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 + +logger = logging.getLogger(__name__) + + +def validate_provider(provider: str) -> str: + """ + Check if the selected provider is supported. + + :param provider: The selected provider to validate. + :return: The validated provider. + :raises ValueError: If the selected provider is not supported. + """ + available_providers = ["openai", "anthropic", "cohere"] + if provider not in available_providers: + raise ValueError(f"LLM provider {provider} is not supported. Available providers: {available_providers}") + return provider + + +@runtime_checkable +class AuthenticationStrategy(Protocol): + """ + Represents an authentication strategy that can be applied to an HTTP request. + """ + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + + +class PassThroughAuthentication(AuthenticationStrategy): + """No-op authentication strategy that does nothing.""" + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + No-op authentication strategy that does nothing. + """ + + +@dataclass +class ApiKeyAuthentication(AuthenticationStrategy): + """ API key authentication strategy.""" + api_key: Optional[str] = None + + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the API key authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + if security_scheme["in"] == "header": + request.setdefault("headers", {})[security_scheme["name"]] = self.api_key + elif security_scheme["in"] == "query": + request.setdefault("params", {})[security_scheme["name"]] = self.api_key + elif security_scheme["in"] == "cookie": + request.setdefault("cookies", {})[security_scheme["name"]] = self.api_key + else: + raise ValueError( + f"Unsupported apiKey authentication location: {security_scheme['in']}, " + f"must be one of 'header', 'query', or 'cookie'" + ) + + +@dataclass +class HTTPAuthentication(AuthenticationStrategy): + """ HTTP authentication strategy.""" + username: Optional[str] = None + password: Optional[str] = None + token: Optional[str] = None + + def __post_init__(self): + if not self.token and (not self.username or not self.password): + raise ValueError( + "For HTTP Basic Auth, both username and password must be provided. " + "For Bearer Auth, a token must be provided." + ) + + def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + """ + Apply the HTTP authentication strategy to the given request. + + :param security_scheme: the security scheme from the OpenAPI spec. + :param request: the request to apply the authentication to. + """ + if security_scheme["type"] == "http": + if security_scheme["scheme"].lower() == "basic": + if not self.username or not self.password: + raise ValueError("Username and password must be provided for Basic Auth.") + credentials = f"{self.username}:{self.password}" + encoded_credentials = b64encode(credentials.encode("utf-8")).decode("utf-8") + request.setdefault("headers", {})["Authorization"] = f"Basic {encoded_credentials}" + elif security_scheme["scheme"].lower() == "bearer": + if not self.token: + raise ValueError("Token must be provided for Bearer Auth.") + request.setdefault("headers", {})["Authorization"] = f"Bearer {self.token}" + else: + raise ValueError(f"Unsupported HTTP authentication scheme: {security_scheme['scheme']}") + else: + raise ValueError("HTTPAuthentication strategy received a non-HTTP security scheme.") + + +@dataclass +class HttpClientConfig: + """ Configuration for the HTTP client. """ + timeout: int = 10 + max_retries: int = 3 + backoff_factor: float = 0.3 + retry_on_status: set = field(default_factory=lambda: {500, 502, 503, 504}) + default_headers: Dict[str, str] = field(default_factory=dict) + + +class HttpClient: + """ HTTP client for sending requests. """ + def __init__(self, config: Optional[HttpClientConfig] = None): + self.config = config or HttpClientConfig() + self.session = requests.Session() + self._initialize_session() + + def _initialize_session(self) -> None: + retries = Retry( + total=self.config.max_retries, + backoff_factor=self.config.backoff_factor, + status_forcelist=self.config.retry_on_status, + ) + adapter = HTTPAdapter(max_retries=retries) + self.session.mount("http://", adapter) + self.session.mount("https://", adapter) + self.session.headers.update(self.config.default_headers) + + def send_request(self, request: Dict[str, Any]) -> Any: + """ + Send an HTTP request using the provided request dictionary. + + :param request: A dictionary containing the request details. + """ + url = request["url"] + method = request["method"] + headers = {**self.config.default_headers, **request.get("headers", {})} + params = request.get("params", {}) + json_data = request.get("json") + auth = request.get("auth") + try: + response = self.session.request(method, url, headers=headers, params=params, json=json_data, auth=auth) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + logger.warning("HTTP error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except requests.exceptions.RequestException as e: + logger.warning("Request error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except Exception as e: + logger.warning("An error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"An error occurred: {e}") from e + + +class HttpClientError(Exception): + """Exception raised for errors in the HTTP client.""" + + +class Operation: + """ Represents an operation in an OpenAPI specification.""" + def __init__(self, path: str, method: str, operation_dict: Dict[str, Any], spec_dict: Dict[str, Any]): + if method.lower() not in VALID_HTTP_METHODS: + raise ValueError(f"Invalid HTTP method: {method}") + self.path = path + self.method = method.lower() + self.operation_dict = operation_dict + self.spec_dict = spec_dict + + def get_parameters(self, location: Optional[Literal["header", "query", "path"]] = None) -> List[Dict[str, Any]]: + """ + Get the parameters for the operation. + + :param location: The location of the parameters to retrieve. If None, all parameters are returned. + """ + parameters = self.operation_dict.get("parameters", []) + path_item = self.spec_dict.get("paths", {}).get(self.path, {}) + parameters.extend(path_item.get("parameters", [])) + if location: + return [param for param in parameters if param["in"] == location] + return parameters + + def get_request_body(self) -> Dict[str, Any]: + """ + Get the request body for the operation. + """ + return self.operation_dict.get("requestBody", {}) + + def get_responses(self) -> Dict[str, Any]: + """ + Get the responses for the operation. + """ + return self.operation_dict.get("responses", {}) + + def get_security_requirements(self) -> List[Dict[str, List[str]]]: + """ + Get the security requirements for the operation. + """ + security_requirements = self.operation_dict.get("security", []) + if not security_requirements: + security_requirements = self.spec_dict.get("security", []) + return security_requirements + + def get_server_url(self) -> str: + """ + Get the server URL for the operation. + """ + servers = self.operation_dict.get("servers", []) + if not servers: + servers = self.spec_dict.get("servers", []) + if servers: + return servers[0].get("url", "") + return "" + + def get_field(self, key: str, default: Any = None) -> Any: + """ + Get a field from the operation dictionary. + """ + return self.operation_dict.get(key, default) + + +class OpenAPISpecification: + """ Represents an OpenAPI specification.""" + def __init__(self, spec_dict: Dict[str, Any]): + if not isinstance(spec_dict, Dict): + raise ValueError(f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}") + # just a crude sanity check, by no means a full validation + if "openapi" not in spec_dict or "paths" not in spec_dict or "servers" not in spec_dict: + raise ValueError( + "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.", spec_dict + ) + self.spec_dict = spec_dict + + @classmethod + def from_dict(cls, spec_dict: Dict[str, Any]) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a dictionary. + """ + parser = cls(spec_dict) + return parser + + @classmethod + def from_str(cls, content: str) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a string. + """ + try: + loaded_spec = json.loads(content) + except json.JSONDecodeError: + try: + loaded_spec = yaml.safe_load(content) + except yaml.YAMLError as e: + raise ValueError("Content cannot be decoded as JSON or YAML: " + str(e)) from e + return cls(loaded_spec) + + @classmethod + def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a file. + """ + with open(spec_file, encoding="utf-8") as file: + content = file.read() + return cls.from_str(content) + + @classmethod + def from_url(cls, url: str) -> "OpenAPISpecification": + """ + Create an OpenAPISpecification instance from a URL. + """ + try: + response = requests.get(url, timeout=10) + response.raise_for_status() + content = response.text + except requests.RequestException as e: + raise ConnectionError(f"Failed to fetch the specification from URL: {url}. {e!s}") from e + return cls.from_str(content) + + def get_name(self) -> str: + """ + Get the title of the OpenAPI specification. + """ + return self.spec_dict.get("info", {}).get("title", "") + + def get_paths(self) -> Dict[str, Dict[str, Any]]: + """ + Get the paths from the OpenAPI specification. + """ + return self.spec_dict.get("paths", {}) + + def get_operation(self, path: str, method: Optional[str] = None) -> Operation: + """ + Retrieve an operation from the OpenAPI specification. + """ + path_item = self.get_paths().get(path, {}) + return self.get_operation_item(path, path_item, method) + + def find_operation_by_path_substring(self, path_partial: str, method: Optional[str] = None) -> Operation: + """ + Find an operation by a substring of the path. + """ + for path, path_item in self.get_paths().items(): + if path_partial in path: + return self.get_operation_item(path, path_item, method) + raise ValueError(f"No operation found with path containing {path_partial}") + + def find_operation_by_id(self, op_id: str, method: Optional[str] = None) -> Operation: + """ + Find an operation by operationId. + """ + for path, path_item in self.get_paths().items(): + op: Operation = self.get_operation_item(path, path_item, method) + if op_id in op.get_field("operationId", ""): + return self.get_operation_item(path, path_item, method) + raise ValueError(f"No operation found with operationId {op_id}") + + def get_operation_item(self, path: str, path_item: Dict[str, Any], method: Optional[str] = None) -> Operation: + """ + Get an operation item from the OpenAPI specification. + + :param path: The path of the operation. + :param path_item: The path item from the OpenAPI specification. + :param method: The HTTP method of the operation. + """ + if method: + operation_dict = path_item.get(method.lower(), {}) + if not operation_dict: + raise ValueError(f"No operation found for method {method} at path {path}") + return Operation(path, method.lower(), operation_dict, self.spec_dict) + if len(path_item) == 1: + method, operation_dict = next(iter(path_item.items())) + return Operation(path, method, operation_dict, self.spec_dict) + if len(path_item) > 1: + raise ValueError(f"Multiple operations found at path {path}, method parameter is required.") + raise ValueError(f"No operations found at path {path}.") + + def get_operations(self) -> List[Operation]: + """ + Get all operations from the OpenAPI specification. + """ + operations = [] + for path, path_item in self.get_paths().items(): + for method, operation_dict in path_item.items(): + if method.lower() in VALID_HTTP_METHODS: + operations.append(Operation(path, method, operation_dict, self.spec_dict)) + return operations + + def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: + """ + Get the security schemes from the OpenAPI specification. + """ + components = self.spec_dict.get("components", {}) + return components.get("securitySchemes", {}) + + def to_dict(self, *, resolve_references: Optional[bool] = False) -> Dict[str, Any]: + """ + Converts the OpenAPI specification to a dictionary format. + + Optionally resolves all $ref references within the spec, returning a fully resolved specification + dictionary if `resolve_references` is set to True. + + :param resolve_references: If True, resolve references in the specification. + :return: A dictionary representation of the OpenAPI specification, optionally fully resolved. + """ + return jsonref.replace_refs(self.spec_dict, proxies=False) if resolve_references else self.spec_dict + + +class ClientConfiguration: + """ Configuration for the OpenAPI client. """ + + def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments + self, + openapi_spec: Union[str, Path, Dict[str, Any]], + credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None, + http_client: Optional[HttpClient] = None, + http_client_config: Optional[HttpClientConfig] = None, + llm_provider: Optional[str] = None, + ): # noqa: PLR0913 + if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): + self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) + elif isinstance(openapi_spec, dict): + self.openapi_spec = OpenAPISpecification.from_dict(openapi_spec) + elif isinstance(openapi_spec, str): + if self.is_valid_http_url(openapi_spec): + self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) + else: + self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) + else: + raise ValueError("Invalid OpenAPI specification format. Expected file path or dictionary.") + + self.credentials = credentials + self.http_client = http_client or HttpClient(http_client_config) + self.http_client_config = http_client_config or HttpClientConfig() + self.llm_provider = llm_provider or "openai" + + def get_openapi_spec(self) -> OpenAPISpecification: + """ + Get the OpenAPI specification. + """ + return self.openapi_spec + + def get_http_client(self) -> HttpClient: + """ + Get the HTTP client. + """ + return self.http_client + + def get_http_client_config(self) -> HttpClientConfig: + """ + Get the HTTP client configuration. + """ + return self.http_client_config + + def get_auth_config(self) -> AuthenticationStrategy: + """ + Get the authentication configuration. + """ + if not self.credentials: + return PassThroughAuthentication() + if isinstance(self.credentials, AuthenticationStrategy): + return self.credentials + security_schemes = self.openapi_spec.get_security_schemes() + if isinstance(self.credentials, str): + return self._create_authentication_from_string(self.credentials, security_schemes) + if isinstance(self.credentials, dict): + return self._create_authentication_from_dict(self.credentials) + raise ValueError(f"Unsupported credentials type: {type(self.credentials)}") + + def get_tools_definitions(self) -> List[Dict[str, Any]]: + """ + Get the tools definitions used as tools LLM parameter. + """ + provider_to_converter = {"anthropic": anthropic_converter, "cohere": cohere_converter} + converter = provider_to_converter.get(self.llm_provider, openai_converter) + return converter(self.openapi_spec) + + def get_payload_extractor(self): + """ + Get the payload extractor for the LLM provider. + """ + provider_to_arguments_field_name = {"anthropic": "input", "cohere": "parameters"} # add more providers here + # default to OpenAI "arguments" + arguments_field_name = provider_to_arguments_field_name.get(self.llm_provider, "arguments") + return LLMFunctionPayloadExtractor(arguments_field_name=arguments_field_name) + + def _create_authentication_from_string( + self, credentials: str, security_schemes: Dict[str, Any] + ) -> AuthenticationStrategy: + for scheme in security_schemes.values(): + if scheme["type"] == "apiKey": + return ApiKeyAuthentication(api_key=credentials) + if scheme["type"] == "http": + return HTTPAuthentication(token=credentials) + if scheme["type"] == "oauth2": + raise NotImplementedError("OAuth2 authentication is not yet supported.") + raise ValueError(f"Unable to create authentication from provided credentials: {credentials}") + + def _create_authentication_from_dict(self, credentials: Dict[str, Any]) -> AuthenticationStrategy: + if "username" in credentials and "password" in credentials: + return HTTPAuthentication(username=credentials["username"], password=credentials["password"]) + if "api_key" in credentials: + return ApiKeyAuthentication(api_key=credentials["api_key"]) + if "token" in credentials: + return HTTPAuthentication(token=credentials["token"]) + if "access_token" in credentials: + raise NotImplementedError("OAuth2 authentication is not yet supported.") + raise ValueError("Unable to create authentication from provided credentials: {credentials}") + + def is_valid_http_url(self, url: str) -> bool: + """Check if a URL is a valid HTTP/HTTPS URL.""" + r = urlparse(url) + return all([r.scheme in ["http", "https"], r.netloc]) + + +class LLMFunctionPayloadExtractor: + """ + Implements a recursive search for extracting LLM generated function payloads. + """ + def __init__(self, arguments_field_name: str): + self.arguments_field_name = arguments_field_name + + def extract_function_invocation(self, payload: Any) -> Dict[str, Any]: + """ + Extract the function invocation details from the payload. + """ + fields_and_values = self._search(payload) + if fields_and_values: + arguments = fields_and_values.get(self.arguments_field_name) + if not isinstance(arguments, (str, dict)): + raise ValueError( + f"Invalid {self.arguments_field_name} type {type(arguments)} for function call, expected str/dict" + ) + return { + "name": fields_and_values.get("name"), + "arguments": json.loads(arguments) if isinstance(arguments, str) else arguments, + } + return {} + + def _required_fields(self) -> List[str]: + return ["name", self.arguments_field_name] + + def _search(self, payload: Any) -> Dict[str, Any]: + if self._is_primitive(payload): + return {} + if dict_converter := self._get_dict_converter(payload): + payload = dict_converter() + elif dataclasses.is_dataclass(payload): + payload = dataclasses.asdict(payload) + if isinstance(payload, dict): + if all(field in payload for field in self._required_fields()): + # this is the payload we are looking for + return payload + for value in payload.values(): + result = self._search(value) + if result: + return result + elif isinstance(payload, list): + for item in payload: + result = self._search(item) + if result: + return result + return {} + + def _get_dict_converter( + self, obj: Any, method_names: Optional[List[str]] = None + ) -> Union[Callable[[], Dict[str, Any]], None]: + method_names = method_names or ["model_dump", "dict"] # search for pydantic v2 then v1 + for attr in method_names: + if hasattr(obj, attr) and callable(getattr(obj, attr)): + return getattr(obj, attr) + return None + + def _is_primitive(self, obj) -> bool: + return isinstance(obj, (int, float, str, bool, type(None))) + + +class ClientConfigurationBuilder: + """ + ClientConfigurationBuilder provides a fluent interface for constructing a `ClientConfiguration`. + + This builder allows for the step-by-step configuration of all necessary components to interact with an + API defined by an OpenAPI specification. + """ + + def __init__(self): + self._openapi_spec: Union[str, Path, Dict[str, Any], None] = None + self._credentials: Optional[Union[str, Dict[str, Any], AuthenticationStrategy]] = None + self._http_client: Optional[HttpClient] = None + self._http_client_config: Optional[HttpClientConfig] = None + self._llm_provider: Optional[str] = None + + def with_openapi_spec(self, openapi_spec: Union[str, Path, Dict[str, Any]]) -> "ClientConfigurationBuilder": + """ + Sets the OpenAPI specification for the configuration. + + :param openapi_spec: The OpenAPI specification as a URL, file path, or dictionary. + :return: The instance of this builder to allow for method chaining. + """ + self._openapi_spec = openapi_spec + return self + + def with_credentials( + self, credentials: Union[str, Dict[str, Any], AuthenticationStrategy] + ) -> "ClientConfigurationBuilder": + """ + Specifies the credentials used for authenticating requests made by the client. + + :param credentials: Credentials as a string, dictionary, or an AuthenticationStrategy instance. + :return: The instance of this builder to allow for method chaining. + """ + self._credentials = credentials + return self + + def with_http_client(self, http_client: HttpClient) -> "ClientConfigurationBuilder": + """ + Specifies the HTTP client to be used for making API calls. + + :param http_client: The HTTP client implementation. + :return: The instance of this builder to allow for method chaining. + """ + self._http_client = http_client + return self + + def with_http_client_config(self, http_client_config: HttpClientConfig) -> "ClientConfigurationBuilder": + """ + Specifies the HTTP client configuration. + + If not set, the default configuration is used. + + :param http_client_config: Configuration settings for the HTTP client. + :return: The instance of this builder to allow for method chaining. + """ + self._http_client_config = http_client_config + return self + + def with_provider(self, llm_provider: str) -> "ClientConfigurationBuilder": + """ + Specifies the Large Language Model (LLM) provider to be used for generating function calls. + + :param llm_provider: The LLM provider name. + :return: The instance of this builder to allow for method chaining. + """ + self._llm_provider = llm_provider + return self + + def build(self) -> ClientConfiguration: + """ + Constructs a `ClientConfiguration` instance using the settings provided. + + It validates that an OpenAPI specification has been set before proceeding with the build. + + :return: A configured instance of ClientConfiguration. + :raises ValueError: If the OpenAPI specification is not set. + """ + if self._openapi_spec is None: + raise ValueError("OpenAPI specification must be provided to build a configuration.") + + return ClientConfiguration( + openapi_spec=self._openapi_spec, + credentials=self._credentials, + http_client=self._http_client, + http_client_config=self._http_client_config, + llm_provider=self._llm_provider or "openai", + ) + + +class RequestBuilder: + """ Builds an HTTP request based on an OpenAPI operation""" + def __init__(self, client_config: ClientConfiguration): + self.openapi_parser = client_config.get_openapi_spec() + self.http_client = client_config.get_http_client() + self.auth_config = client_config.get_auth_config() or PassThroughAuthentication() + + def build_request(self, operation: Operation, **kwargs) -> Any: + """ + Build an HTTP request based on the operation and arguments provided. + """ + url = self._build_url(operation, **kwargs) + method = operation.method.lower() + headers = self._build_headers(operation) + query_params = self._build_query_params(operation, **kwargs) + body = self._build_request_body(operation, **kwargs) + request = { + "url": url, + "method": method, + "headers": headers, + "params": query_params, + "json": body, + } + self._apply_authentication(operation, request) + return request + + def _build_headers(self, operation: Operation, **kwargs) -> Dict[str, str]: + headers = {} + for parameter in operation.get_parameters("header"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + headers[parameter["name"]] = str(param_value) + elif parameter.get("required", False): + raise ValueError(f"Missing required header parameter: {parameter['name']}") + return headers + + def _build_url(self, operation: Operation, **kwargs) -> str: + server_url = operation.get_server_url() + path = operation.path + for parameter in operation.get_parameters("path"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + path = path.replace(f"{{{parameter['name']}}}", str(param_value)) + elif parameter.get("required", False): + raise ValueError(f"Missing required path parameter: {parameter['name']}") + return server_url + path + + def _build_query_params(self, operation: Operation, **kwargs) -> Dict[str, Any]: + query_params = {} + + # Simplify query parameter assembly using _get_parameter_value + for parameter in operation.get_parameters("query"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + query_params[parameter["name"]] = param_value + elif parameter.get("required", False): + raise ValueError(f"Missing required query parameter: {parameter['name']}") + return query_params + + def _build_request_body(self, operation: Operation, **kwargs) -> Any: + request_body = operation.get_request_body() + if request_body: + content = request_body.get("content", {}) + if "application/json" in content: + return {**kwargs} + raise NotImplementedError("Request body content type not supported") + return None + + def _apply_authentication(self, operation: Operation, request: Dict[str, Any]): + # security requirements specify which authentication scheme to apply (the "what/which") + security_requirements = operation.get_security_requirements() + # security schemes define how to authenticate (the "how") + security_schemes = operation.spec_dict.get("components", {}).get("securitySchemes", {}) + if security_requirements: + for requirement in security_requirements: + for scheme_name in requirement: + if scheme_name in security_schemes: + security_scheme = security_schemes[scheme_name] + self.auth_config.apply_auth(security_scheme, request) + break + + +class OpenAPIServiceClient: + """ + A client for invoking operations on REST services defined by OpenAPI specifications. + + Together with the `ClientConfiguration`, its `ClientConfigurationBuilder`, the `OpenAPIServiceClient` + simplifies the process of (LLMs) with services defined by OpenAPI specifications. + """ + + def __init__(self, client_config: ClientConfiguration): + self.openapi_spec = client_config.get_openapi_spec() + self.http_client = client_config.get_http_client() + self.request_builder = RequestBuilder(client_config) + self.payload_extractor = client_config.get_payload_extractor() + + def invoke(self, function_payload: Any) -> Any: + """ + Invokes a function specified in the function payload. + + :param function_payload: The function payload containing the details of the function to be invoked. + :returns: The response from the service after invoking the function. + :raises OpenAPIClientError: If the function invocation payload cannot be extracted from the function payload. + :raises HttpClientError: If an error occurs while sending the request and receiving the response. + """ + fn_invocation_payload = self.payload_extractor.extract_function_invocation(function_payload) + if not fn_invocation_payload: + raise OpenAPIClientError( + f"Failed to extract function invocation payload from {function_payload} using " + f"{self.payload_extractor.__class__.__name__}. Ensure the payload format matches the expected " + "structure for the designated LLM extractor." + ) + # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on + operation = self.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) + request = self.request_builder.build_request(operation, **fn_invocation_payload.get("arguments")) + return self.http_client.send_request(request) + + +class OpenAPIClientError(Exception): + """Exception raised for errors in the OpenAPI client.""" + + +def openai_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: + """ + Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + fn_definitions = _openapi_to_functions(resolved_schema, "parameters", _parse_endpoint_spec_openai) + return [{"type": "function", "function": fn} for fn in fn_definitions] + + +def anthropic_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: + """ + Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + return _openapi_to_functions(resolved_schema, "input_schema", _parse_endpoint_spec_openai) + + +def cohere_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: + """ + Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. + + :param schema: The OpenAPI specification to convert. + :return: A list of dictionaries, each representing a function definition. + """ + resolved_schema = jsonref.replace_refs(schema.spec_dict) + return _openapi_to_functions(resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere) + + +def _openapi_to_functions(service_openapi_spec: Dict[str, Any], parameters_name: str, + parse_endpoint_fn: Callable[[Dict[str, Any], str], Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Extracts functions from the OpenAPI specification, converts them into a function schema. + """ + + # Doesn't enforce rigid spec validation because that would require a lot of dependencies + # We check the version and require minimal fields to be present, so we can extract functions + spec_version = service_openapi_spec.get("openapi") + if not spec_version: + raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}") + service_openapi_spec_version = int(spec_version.split(".")[0]) + # Compare the versions + if service_openapi_spec_version < MIN_REQUIRED_OPENAPI_SPEC_VERSION: + raise ValueError( + f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be " + f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}." + ) + functions: List[Dict[str, Any]] = [] + for paths in service_openapi_spec["paths"].values(): + for path_spec in paths.values(): + function_dict = parse_endpoint_fn(path_spec, parameters_name) + if function_dict: + functions.append(function_dict) + return functions + + +def _parse_endpoint_spec_openai(resolved_spec: Dict[str, Any], parameters_name: str) -> Dict[str, Any]: + """ + Parses an OpenAPI endpoint specification for OpenAI. + """ + if not isinstance(resolved_spec, dict): + logger.warning("Invalid OpenAPI spec format provided. Could not extract function.") + return {} + function_name = resolved_spec.get("operationId") + description = resolved_spec.get("description") or resolved_spec.get("summary", "") + schema: Dict[str, Any] = {"type": "object", "properties": {}} + # requestBody section + req_body_schema = ( + resolved_spec.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {}) + ) + if "properties" in req_body_schema: + for prop_name, prop_schema in req_body_schema["properties"].items(): + schema["properties"][prop_name] = _parse_property_attributes(prop_schema) + if "required" in req_body_schema: + schema.setdefault("required", []).extend(req_body_schema["required"]) + + # parameters section + for param in resolved_spec.get("parameters", []): + if "schema" in param: + schema_dict = _parse_property_attributes(param["schema"]) + # these attributes are not in param[schema] level but on param level + useful_attributes = ["description", "pattern", "enum"] + schema_dict.update({key: param[key] for key in useful_attributes if param.get(key)}) + schema["properties"][param["name"]] = schema_dict + if param.get("required", False): + schema.setdefault("required", []).append(param["name"]) + + if function_name and description and schema["properties"]: + return {"name": function_name, "description": description, parameters_name: schema} + logger.warning("Invalid OpenAPI spec format provided. Could not extract function from %s", resolved_spec) + return {} + + +def _parse_property_attributes(property_schema: Dict[str, Any], include_attributes: Optional[List[str]] = None + ) -> Dict[str, Any]: + """ + Recursively parses the attributes of a property schema. + """ + include_attributes = include_attributes or ["description", "pattern", "enum"] + schema_type = property_schema.get("type") + parsed_schema = {"type": schema_type} if schema_type else {} + for attr in include_attributes: + if attr in property_schema: + parsed_schema[attr] = property_schema[attr] + if schema_type == "object": + properties = property_schema.get("properties", {}) + parsed_properties = { + prop_name: _parse_property_attributes(prop, include_attributes) + for prop_name, prop in properties.items() + } + parsed_schema["properties"] = parsed_properties + if "required" in property_schema: + parsed_schema["required"] = property_schema["required"] + elif schema_type == "array": + items = property_schema.get("items", {}) + parsed_schema["items"] = _parse_property_attributes(items, include_attributes) + return parsed_schema + + +def _parse_endpoint_spec_cohere(operation: Dict[str, Any], ignored_param: str) -> Dict[str, Any]: + """ + Parses an endpoint specification for Cohere. + """ + function_name = operation.get("operationId") + description = operation.get("description") or operation.get("summary", "") + parameter_definitions = _parse_parameters(operation) + if function_name: + return { + "name": function_name, + "description": description, + "parameter_definitions": parameter_definitions, + } + logger.warning("Operation missing operationId, cannot create function definition.") + return {} + + +def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: + """ + Parses the parameters from an operation specification. + """ + parameters = {} + for param in operation.get("parameters", []): + if "schema" in param: + parameters[param["name"]] = _parse_schema( + param["schema"], param.get("required", False), param.get("description", "") + ) + if "requestBody" in operation: + content = operation["requestBody"].get("content", {}).get("application/json", {}) + if "schema" in content: + schema_properties = content["schema"].get("properties", {}) + required_properties = content["schema"].get("required", []) + for name, schema in schema_properties.items(): + parameters[name] = _parse_schema( + schema, name in required_properties, schema.get("description", "") + ) + return parameters + + +def _parse_schema(schema: Dict[str, Any], required: bool, description: str) -> Dict[str, Any]: # noqa: FBT001 + """ + Parses a schema part of an operation specification. + """ + schema_type = _get_type(schema) + if schema_type == "object": + # Recursive call for complex types + properties = schema.get("properties", {}) + nested_parameters = { + name: _parse_schema( + schema=prop_schema, + required=bool(name in schema.get("required", False)), + description=prop_schema.get("description", ""), + ) + for name, prop_schema in properties.items() + } + return { + "type": schema_type, + "description": description, + "properties": nested_parameters, + "required": required, + } + return {"type": schema_type, "description": description, "required": required} + + +def _get_type(schema: Dict[str, Any]) -> str: + type_mapping = {"integer": "int", "string": "str", "boolean": "bool", "number": "float", "object": "object", + "array": "list"} + schema_type = schema.get("type", "object") + if schema_type not in type_mapping: + raise ValueError(f"Unsupported schema type {schema_type}") + return type_mapping[schema_type] diff --git a/pyproject.toml b/pyproject.toml index 85214d0e..fc3454d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] -dependencies = ["haystack-ai"] +dependencies = ["jsonref", "haystack-ai"] [project.urls] "CI: GitHub" = "https://github.com/deepset-ai/haystack-experimental/actions" "GitHub: issues" = "https://github.com/deepset-ai/haystack-experimental/issues" @@ -39,6 +39,9 @@ dependencies = [ # Test "pytest", "pytest-cov", + "fastapi", + "cohere", + "anthropic", # Linting "pylint", "ruff", diff --git a/test/components/__init__.py b/test/components/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/connectors/__init__.py b/test/components/connectors/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/connectors/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/connectors/test_openapi.py b/test/components/connectors/test_openapi.py new file mode 100644 index 00000000..f9b5e413 --- /dev/null +++ b/test/components/connectors/test_openapi.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import json +import os +from unittest.mock import patch + +import pytest + +from haystack_experimental.components.connectors import OpenAPIServiceConnector +from haystack.dataclasses import ChatMessage + + +class TestOpenAPIServiceConnector: + @pytest.fixture + def setup_mock(self): + with patch("haystack_experimental.components.connectors.openapi.OpenAPIServiceClient") as mock_client: + mock_client_instance = mock_client.return_value + mock_client_instance.invoke.return_value = {"service_response": "Yes, he was fired and rehired"} + yield mock_client_instance + + def test_init(self): + service_connector = OpenAPIServiceConnector() + assert service_connector is not None + assert service_connector.llm_provider == "openai" + + def test_init_with_anthropic_provider(self): + service_connector = OpenAPIServiceConnector(provider="anthropic") + assert service_connector is not None + assert service_connector.llm_provider == "anthropic" + + def test_run_with_mock(self, setup_mock, test_files_path): + fc_payload = [ + { + "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, + "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", + "type": "function", + } + ] + with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: + serperdev_openapi_spec = json.load(file) + + service_connector = OpenAPIServiceConnector() + result = service_connector.run( + messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], + service_openapi_spec=serperdev_openapi_spec, + service_credentials="fake_api_key", + ) + + assert "service_response" in result + assert len(result["service_response"]) == 1 + assert isinstance(result["service_response"][0], ChatMessage) + response_content = json.loads(result["service_response"][0].content) + assert response_content == {"service_response": "Yes, he was fired and rehired"} + + # verify invocation payload + setup_mock.invoke.assert_called_once() + invocation_payload = [ + { + "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, + "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", + "type": "function", + } + ] + setup_mock.invoke.assert_called_with(invocation_payload) + + @pytest.mark.integration + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SerperDev API key is not available") + def test_run(self, test_files_path): + fc_payload = [ + { + "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, + "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", + "type": "function", + } + ] + + with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: + serperdev_openapi_spec = json.load(file) + + service_connector = OpenAPIServiceConnector() + result = service_connector.run( + messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], + service_openapi_spec=serperdev_openapi_spec, + service_credentials=os.environ["SERPERDEV_API_KEY"], + ) + assert "service_response" in result + assert len(result["service_response"]) == 1 + assert isinstance(result["service_response"][0], ChatMessage) + response_text = result["service_response"][0].content + assert "Sam" in response_text or "Altman" in response_text + + @pytest.mark.integration + def test_run_no_credentials(self, test_files_path): + fc_payload = [ + { + "function": {"arguments": '{"q": "Why was Sam Altman ousted from OpenAI?"}', "name": "search"}, + "id": "call_PmEBYvZ7mGrQP5PUASA5m9wO", + "type": "function", + } + ] + + with open(os.path.join(test_files_path, "json/serperdev_openapi_spec.json"), "r") as file: + serperdev_openapi_spec = json.load(file) + + service_connector = OpenAPIServiceConnector() + result = service_connector.run( + messages=[ChatMessage.from_assistant(json.dumps(fc_payload))], service_openapi_spec=serperdev_openapi_spec + ) + assert "service_response" in result + assert len(result["service_response"]) == 1 + assert isinstance(result["service_response"][0], ChatMessage) + response_text = result["service_response"][0].content + assert "403" in response_text diff --git a/test/components/converters/__init__.py b/test/components/converters/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/components/converters/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/components/converters/test_openapi.py b/test/components/converters/test_openapi.py new file mode 100644 index 00000000..fabb4624 --- /dev/null +++ b/test/components/converters/test_openapi.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import json +import sys +import tempfile + +import pytest + +from haystack_experimental.components.converters import OpenAPIServiceToFunctions +from haystack.dataclasses import ByteStream + + +@pytest.fixture +def json_serperdev_openapi_spec(): + serper_spec = """ + { + "openapi": "3.0.0", + "info": { + "title": "SerperDev", + "version": "1.0.0", + "description": "API for performing search queries" + }, + "servers": [ + { + "url": "https://google.serper.dev" + } + ], + "paths": { + "/search": { + "post": { + "operationId": "search", + "description": "Search the web with Google", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "q": { + "type": "string" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "searchParameters": { + "type": "undefined" + }, + "knowledgeGraph": { + "type": "undefined" + }, + "answerBox": { + "type": "undefined" + }, + "organic": { + "type": "undefined" + }, + "topStories": { + "type": "undefined" + }, + "peopleAlsoAsk": { + "type": "undefined" + }, + "relatedSearches": { + "type": "undefined" + } + } + } + } + } + } + }, + "security": [ + { + "apikey": [] + } + ] + } + } + }, + "components": { + "securitySchemes": { + "apikey": { + "type": "apiKey", + "name": "x-api-key", + "in": "header" + } + } + } + } + """ + return serper_spec + + +@pytest.fixture +def yaml_serperdev_openapi_spec(): + serper_spec = """ + openapi: 3.0.0 + info: + title: SerperDev + version: 1.0.0 + description: API for performing search queries + servers: + - url: 'https://google.serper.dev' + paths: + /search: + post: + operationId: search + description: Search the web with Google + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + q: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + type: object + properties: + searchParameters: + type: undefined + knowledgeGraph: + type: undefined + answerBox: + type: undefined + organic: + type: undefined + topStories: + type: undefined + peopleAlsoAsk: + type: undefined + relatedSearches: + type: undefined + security: + - apikey: [] + components: + securitySchemes: + apikey: + type: apiKey + name: x-api-key + in: header + """ + return serper_spec + + +@pytest.fixture +def fn_definition_transform(): + return lambda function_def: {"type": "function", "function": function_def} + + +class TestOpenAPIServiceToFunctions: + # test we can extract functions from openapi spec given + def test_run_with_bytestream_source(self, json_serperdev_openapi_spec, fn_definition_transform): + service = OpenAPIServiceToFunctions() + spec_stream = ByteStream.from_string(json_serperdev_openapi_spec) + result = service.run(sources=[spec_stream]) + assert len(result["functions"]) == 1 + fc = result["functions"][0] + + # check that fc definition is as expected + assert fc == fn_definition_transform( + { + "name": "search", + "description": "Search the web with Google", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + } + ) + + @pytest.mark.skipif( + sys.platform in ["win32", "cygwin"], + reason="Can't run on Windows Github CI, need access temp file but windows does not allow it", + ) + def test_run_with_file_source(self, json_serperdev_openapi_spec, fn_definition_transform): + # test we can extract functions from openapi spec given in file + service = OpenAPIServiceToFunctions() + # write the spec to NamedTemporaryFile and check that it is parsed correctly + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(json_serperdev_openapi_spec.encode("utf-8")) + tmp.seek(0) + result = service.run(sources=[tmp.name]) + assert len(result["functions"]) == 1 + fc = result["functions"][0] + + # check that fc definition is as expected + assert fc == fn_definition_transform( + { + "name": "search", + "description": "Search the web with Google", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + } + ) + + def test_run_with_invalid_bytestream_source(self, caplog): + # test invalid source + service = OpenAPIServiceToFunctions() + with pytest.raises(ValueError, match="Invalid OpenAPI specification"): + service.run(sources=[ByteStream.from_string("")]) + + def test_complex_types_conversion(self, test_files_path, fn_definition_transform): + # ensure that complex types from OpenAPI spec are converted to the expected format in OpenAI function calling + service = OpenAPIServiceToFunctions() + result = service.run(sources=[test_files_path / "json" / "complex_types_openapi_service.json"]) + assert len(result["functions"]) == 1 + + with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file: + desired_output = json.load(openai_spec_file) + assert result["functions"][0] == fn_definition_transform(desired_output) + + def test_simple_and_complex_at_once(self, test_files_path, json_serperdev_openapi_spec, fn_definition_transform): + # ensure multiple functions are extracted from multiple paths in OpenAPI spec + service = OpenAPIServiceToFunctions() + sources = [ + ByteStream.from_string(json_serperdev_openapi_spec), + test_files_path / "json" / "complex_types_openapi_service.json", + ] + result = service.run(sources=sources) + assert len(result["functions"]) == 2 + + with open(test_files_path / "json" / "complex_types_openai_spec.json") as openai_spec_file: + desired_output = json.load(openai_spec_file) + assert result["functions"][0] == fn_definition_transform( + { + "name": "search", + "description": "Search the web with Google", + "parameters": {"type": "object", "properties": {"q": {"type": "string"}}}, + } + ) + assert result["functions"][1] == fn_definition_transform(desired_output) diff --git a/test/test_files/json/complex_types_openai_spec.json b/test/test_files/json/complex_types_openai_spec.json new file mode 100644 index 00000000..ebaf9556 --- /dev/null +++ b/test/test_files/json/complex_types_openai_spec.json @@ -0,0 +1,64 @@ +{ + "name": "processPayment", + "description": "Process a new payment using the specified payment method", + "parameters": { + "type": "object", + "properties": { + "transaction_amount": { + "type": "number", + "description": "The amount to be paid" + }, + "description": { + "type": "string", + "description": "A brief description of the payment" + }, + "payment_method_id": { + "type": "string", + "description": "The payment method to be used" + }, + "payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": { + "type": "string", + "description": "The payer's name" + }, + "email": { + "type": "string", + "description": "The payer's email address" + }, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "string", + "description": "The type of identification document (e.g., CPF, CNPJ)" + }, + "number": { + "type": "string", + "description": "The identification number" + } + }, + "required": [ + "type", + "number" + ] + } + }, + "required": [ + "name", + "email", + "identification" + ] + } + }, + "required": [ + "transaction_amount", + "description", + "payment_method_id", + "payer" + ] + } +} diff --git a/test/test_files/json/complex_types_openapi_service.json b/test/test_files/json/complex_types_openapi_service.json new file mode 100644 index 00000000..3ea04f8a --- /dev/null +++ b/test/test_files/json/complex_types_openapi_service.json @@ -0,0 +1,103 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Payment API", + "version": "1.0.0" + }, + "servers": [ + { + "url": "http://localhost:8080" + } + ], + "paths": { + "/new_payment": { + "post": { + "summary": "Process a new payment", + "description": "Process a new payment using the specified payment method", + "operationId": "processPayment", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "transaction_amount": { + "type": "number", + "description": "The amount to be paid" + }, + "description": { + "type": "string", + "description": "A brief description of the payment" + }, + "payment_method_id": { + "type": "string", + "description": "The payment method to be used" + }, + "payer": { + "$ref": "#/components/schemas/Payer" + } + }, + "required": [ + "transaction_amount", + "description", + "payment_method_id", + "payer" + ] + } + } + } + }, + "responses": { + "200": { + "description": "Payment processed successfully" + }, + "400": { + "description": "Invalid request" + } + } + } + } + }, + "components": { + "schemas": { + "Payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": { + "type": "string", + "description": "The payer's name" + }, + "email": { + "type": "string", + "description": "The payer's email address" + }, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "string", + "description": "The type of identification document (e.g., CPF, CNPJ)" + }, + "number": { + "type": "string", + "description": "The identification number" + } + }, + "required": [ + "type", + "number" + ] + } + }, + "required": [ + "name", + "email", + "identification" + ] + } + } + } +} diff --git a/test/test_files/json/openapi_order_service.json b/test/test_files/json/openapi_order_service.json new file mode 100644 index 00000000..3e24661a --- /dev/null +++ b/test/test_files/json/openapi_order_service.json @@ -0,0 +1,109 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Order Service", + "version": "1.0.0" + }, + "servers": [{"url": "http://localhost"}], + "paths": { + "/orders": { + "post": { + "summary": "Create a new order", + "operationId": "createOrder", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Order" + } + } + } + }, + "responses": { + "201": { + "description": "Created", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OrderResponse" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Order": { + "type": "object", + "properties": { + "customer": { + "$ref": "#/components/schemas/Customer" + }, + "items": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OrderItem" + } + } + }, + "required": [ + "customer", + "items" + ] + }, + "Customer": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "email": { + "type": "string" + } + }, + "required": [ + "name", + "email" + ] + }, + "OrderItem": { + "type": "object", + "properties": { + "product": { + "type": "string" + }, + "quantity": { + "type": "integer" + } + }, + "required": [ + "product", + "quantity" + ] + }, + "OrderResponse": { + "type": "object", + "properties": { + "orderId": { + "type": "string" + }, + "status": { + "type": "string" + }, + "totalAmount": { + "type": "number" + } + }, + "required": [ + "orderId", + "status", + "totalAmount" + ] + } + } + } +} \ No newline at end of file diff --git a/test/test_files/json/serperdev_openapi_spec.json b/test/test_files/json/serperdev_openapi_spec.json new file mode 100644 index 00000000..123c993a --- /dev/null +++ b/test/test_files/json/serperdev_openapi_spec.json @@ -0,0 +1,62 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "SerperDev", + "version": "1.0.0", + "description": "API for performing search queries" + }, + "servers": [ + { + "url": "https://google.serper.dev" + } + ], + "paths": { + "/search": { + "post": { + "operationId": "search", + "description": "Search the web with Google", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "q": { + "type": "string" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object" + } + } + } + } + }, + "security": [ + { + "apikey": [] + } + ] + } + } + }, + "components": { + "securitySchemes": { + "apikey": { + "type": "apiKey", + "name": "x-api-key", + "in": "header" + } + } + } +} diff --git a/test/test_files/yaml/github_compare.yml b/test/test_files/yaml/github_compare.yml new file mode 100644 index 00000000..e14575b7 --- /dev/null +++ b/test/test_files/yaml/github_compare.yml @@ -0,0 +1,438 @@ +openapi: 3.1.0 +info: + title: Github API + description: Enables interaction with OpenAPI + version: v1.0.0 +servers: + - url: https://api.github.com +paths: + /repos/{owner}/{repo}/compare/{basehead}: + get: + summary: Compare two branches + description: Compares two branches against one another. + tags: + - repos + operationId: compare_branches + externalDocs: + description: API method documentation + url: >- + https://docs.github.com/enterprise-server@3.9/rest/commits/commits#compare-two-commits + parameters: + - name: basehead + description: >- + The base branch and head branch to compare. This parameter expects + the format `BASE...HEAD` + in: path + required: true + x-multi-segment: true + schema: + type: string + - name: owner + description: The repository owner, usually a company or orgnization + in: path + required: true + x-multi-segment: true + schema: + type: string + - name: repo + description: The repository itself, the project + in: path + required: true + x-multi-segment: true + schema: + type: string + responses: + '200': + description: Response + content: + application/json: + schema: + $ref: '#/components/schemas/commit-comparison' + x-github: + githubCloudOnly: false + enabledForGitHubApps: true + category: commits + subcategory: commits +components: + schemas: + commit-comparison: + title: Commit Comparison + description: Commit Comparison + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/compare/master...topic + html_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic + permalink_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/compare/octocat:bbcd538c8e72b8c175046e27cc8f907076331401...octocat:0328041d1152db8ae77652d1618a02e57f745f17 + diff_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic.diff + patch_url: + type: string + format: uri + example: https://github.com/octocat/Hello-World/compare/master...topic.patch + base_commit: + $ref: '#/components/schemas/commit' + merge_base_commit: + $ref: '#/components/schemas/commit' + status: + type: string + enum: + - diverged + - ahead + - behind + - identical + example: ahead + ahead_by: + type: integer + example: 4 + behind_by: + type: integer + example: 5 + total_commits: + type: integer + example: 6 + commits: + type: array + items: + $ref: '#/components/schemas/commit' + files: + type: array + items: + $ref: '#/components/schemas/diff-entry' + required: + - url + - html_url + - permalink_url + - diff_url + - patch_url + - base_commit + - merge_base_commit + - status + - ahead_by + - behind_by + - total_commits + - commits + nullable-git-user: + title: Git User + description: Metaproperties for Git author/committer information. + type: object + properties: + name: + type: string + example: '"Chris Wanstrath"' + email: + type: string + example: '"chris@ozmm.org"' + date: + type: string + example: '"2007-10-29T02:42:39.000-07:00"' + nullable: true + nullable-simple-user: + title: Simple User + description: A GitHub user. + type: object + properties: + name: + nullable: true + type: string + email: + nullable: true + type: string + login: + type: string + example: octocat + id: + type: integer + example: 1 + node_id: + type: string + example: MDQ6VXNlcjE= + avatar_url: + type: string + format: uri + example: https://github.com/images/error/octocat_happy.gif + gravatar_id: + type: string + example: 41d064eb2195891e12d0413f63227ea7 + nullable: true + url: + type: string + format: uri + example: https://api.github.com/users/octocat + html_url: + type: string + format: uri + example: https://github.com/octocat + followers_url: + type: string + format: uri + example: https://api.github.com/users/octocat/followers + following_url: + type: string + example: https://api.github.com/users/octocat/following{/other_user} + gists_url: + type: string + example: https://api.github.com/users/octocat/gists{/gist_id} + starred_url: + type: string + example: https://api.github.com/users/octocat/starred{/owner}{/repo} + subscriptions_url: + type: string + format: uri + example: https://api.github.com/users/octocat/subscriptions + organizations_url: + type: string + format: uri + example: https://api.github.com/users/octocat/orgs + repos_url: + type: string + format: uri + example: https://api.github.com/users/octocat/repos + events_url: + type: string + example: https://api.github.com/users/octocat/events{/privacy} + received_events_url: + type: string + format: uri + example: https://api.github.com/users/octocat/received_events + type: + type: string + example: User + site_admin: + type: boolean + starred_at: + type: string + example: '"2020-07-09T00:17:55Z"' + required: + - avatar_url + - events_url + - followers_url + - following_url + - gists_url + - gravatar_id + - html_url + - id + - node_id + - login + - organizations_url + - received_events_url + - repos_url + - site_admin + - starred_url + - subscriptions_url + - type + - url + nullable: true + verification: + title: Verification + type: object + properties: + verified: + type: boolean + reason: + type: string + payload: + type: string + nullable: true + signature: + type: string + nullable: true + required: + - verified + - reason + - payload + - signature + diff-entry: + title: Diff Entry + description: Diff Entry + type: object + properties: + sha: + type: string + example: bbcd538c8e72b8c175046e27cc8f907076331401 + filename: + type: string + example: file1.txt + status: + type: string + enum: + - added + - removed + - modified + - renamed + - copied + - changed + - unchanged + example: added + additions: + type: integer + example: 103 + deletions: + type: integer + example: 21 + changes: + type: integer + example: 124 + blob_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/blob/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt + raw_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/raw/6dcb09b5b57875f334f61aebed695e2e4193db5e/file1.txt + contents_url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/contents/file1.txt?ref=6dcb09b5b57875f334f61aebed695e2e4193db5e + patch: + type: string + example: '@@ -132,7 +132,7 @@ module Test @@ -1000,7 +1000,7 @@ module Test' + previous_filename: + type: string + example: file.txt + required: + - additions + - blob_url + - changes + - contents_url + - deletions + - filename + - raw_url + - sha + - status + commit: + title: Commit + description: Commit + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e + sha: + type: string + example: 6dcb09b5b57875f334f61aebed695e2e4193db5e + node_id: + type: string + example: MDY6Q29tbWl0NmRjYjA5YjViNTc4NzVmMzM0ZjYxYWViZWQ2OTVlMmU0MTkzZGI1ZQ== + html_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/commit/6dcb09b5b57875f334f61aebed695e2e4193db5e + comments_url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e/comments + commit: + type: object + properties: + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/6dcb09b5b57875f334f61aebed695e2e4193db5e + author: + $ref: '#/components/schemas/nullable-git-user' + committer: + $ref: '#/components/schemas/nullable-git-user' + message: + type: string + example: Fix all the bugs + comment_count: + type: integer + example: 0 + tree: + type: object + properties: + sha: + type: string + example: 827efc6d56897b048c772eb4087f854f46256132 + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/tree/827efc6d56897b048c772eb4087f854f46256132 + required: + - sha + - url + verification: + $ref: '#/components/schemas/verification' + required: + - author + - committer + - comment_count + - message + - tree + - url + author: + $ref: '#/components/schemas/nullable-simple-user' + committer: + $ref: '#/components/schemas/nullable-simple-user' + parents: + type: array + items: + type: object + properties: + sha: + type: string + example: 7638417db6d59f3c431d3e1f261cc637155684cd + url: + type: string + format: uri + example: >- + https://api.github.com/repos/octocat/Hello-World/commits/7638417db6d59f3c431d3e1f261cc637155684cd + html_url: + type: string + format: uri + example: >- + https://github.com/octocat/Hello-World/commit/7638417db6d59f3c431d3e1f261cc637155684cd + required: + - sha + - url + stats: + type: object + properties: + additions: + type: integer + deletions: + type: integer + total: + type: integer + files: + type: array + items: + $ref: '#/components/schemas/diff-entry' + required: + - url + - sha + - node_id + - html_url + - comments_url + - commit + - author + - committer + - parents + securitySchemes: + apikey: + type: apiKey + name: x-api-key + in: header diff --git a/test/test_files/yaml/openapi_edge_cases.yml b/test/test_files/yaml/openapi_edge_cases.yml new file mode 100644 index 00000000..cef304c5 --- /dev/null +++ b/test/test_files/yaml/openapi_edge_cases.yml @@ -0,0 +1,13 @@ +openapi: 3.0.0 +info: + title: Edge Cases API + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /missing-operation-id: + get: + summary: Missing operationId + responses: + '200': + description: OK diff --git a/test/test_files/yaml/openapi_error_handling.yml b/test/test_files/yaml/openapi_error_handling.yml new file mode 100644 index 00000000..5cf23fe5 --- /dev/null +++ b/test/test_files/yaml/openapi_error_handling.yml @@ -0,0 +1,24 @@ +openapi: 3.0.0 +info: + title: Error Handling API + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /error/{status_code}: + get: + summary: Raise HTTP error + operationId: raiseHttpError + parameters: + - name: status_code + in: path + required: true + schema: + type: integer + responses: + '400': + description: Bad Request + '401': + description: Unauthorized + '404': + description: Not Found \ No newline at end of file diff --git a/test/test_files/yaml/openapi_greeting_service.yml b/test/test_files/yaml/openapi_greeting_service.yml new file mode 100644 index 00000000..701dee33 --- /dev/null +++ b/test/test_files/yaml/openapi_greeting_service.yml @@ -0,0 +1,272 @@ +openapi: 3.0.0 +info: + title: Greeting Service + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /greet/{name}: + post: + operationId: greet + parameters: + - name: name + in: path + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/MessageBody' + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-params/{name}: + get: + operationId: greetParams + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-body: + post: + operationId: greetBody + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/GreetBody' + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + + /greet-api-key/{name}: + get: + operationId: greetApiKey + security: + - ApiKeyAuth: [] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-basic-auth/{name}: + get: + operationId: greetBasicAuth + security: + - BasicAuth: [] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /greet-api-key-query/{name}: + get: + operationId: greetApiKeyQuery + security: + - ApiKeyAuthQuery: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-api-key-cookie/{name}: + get: + operationId: greetApiKeyCookie + security: + - ApiKeyAuthCookie: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-bearer-auth/{name}: + get: + operationId: greetBearerAuth + security: + - BearerAuth: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /greet-oauth/{name}: + get: + operationId: greetOAuth + security: + - OAuth2: [ ] + parameters: + - name: name + in: path + required: true + schema: + type: string + responses: + '200': + description: Successful response + content: + application/json: + schema: + $ref: '#/components/schemas/GreetingResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' +components: + securitySchemes: + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key + BasicAuth: + type: http + scheme: basic + ApiKeyAuthQuery: + type: apiKey + in: query + name: api_key + ApiKeyAuthCookie: + type: apiKey + in: cookie + name: api_key + BearerAuth: + type: http + scheme: bearer + OAuth2: + type: oauth2 + flows: + authorizationCode: + authorizationUrl: https://example.com/oauth/authorize + tokenUrl: https://example.com/oauth/token + scopes: + read:greet: Read access to greeting service + + schemas: + GreetBody: + type: object + properties: + message: + type: string + name: + type: string + required: + - message + - name + + MessageBody: + type: object + properties: + message: + type: string + required: + - message + + GreetingResponse: + type: object + properties: + greeting: + type: string + + ErrorResponse: + type: object + properties: + detail: + type: string \ No newline at end of file diff --git a/test/test_files/yaml/openapi_order_service.yml b/test/test_files/yaml/openapi_order_service.yml new file mode 100644 index 00000000..07360ea5 --- /dev/null +++ b/test/test_files/yaml/openapi_order_service.yml @@ -0,0 +1,75 @@ +openapi: 3.0.0 +info: + title: Order Service + version: 1.0.0 +servers: + - url: http://localhost # not used anyway +paths: + /orders: + post: + summary: Create a new order + operationId: createOrder + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/Order' + responses: + '201': + description: Created + content: + application/json: + schema: + $ref: '#/components/schemas/OrderResponse' + +components: + schemas: + Order: + type: object + properties: + customer: + $ref: '#/components/schemas/Customer' + items: + type: array + items: + $ref: '#/components/schemas/OrderItem' + required: + - customer + - items + + Customer: + type: object + properties: + name: + type: string + email: + type: string + required: + - name + - email + + OrderItem: + type: object + properties: + product: + type: string + quantity: + type: integer + required: + - product + - quantity + + OrderResponse: + type: object + properties: + orderId: + type: string + status: + type: string + totalAmount: + type: number + required: + - orderId + - status + - totalAmount \ No newline at end of file diff --git a/test/test_files/yaml/serper.yml b/test/test_files/yaml/serper.yml new file mode 100644 index 00000000..9d5b1a85 --- /dev/null +++ b/test/test_files/yaml/serper.yml @@ -0,0 +1,39 @@ +openapi: 3.0.0 +info: + title: SerperDev + version: 1.0.0 + description: API for performing search queries +servers: + - url: https://google.serper.dev +paths: + /search: + post: + operationId: serperdev_search + description: Search the web with Google + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + q: + type: string + required: + - q + responses: + '200': + description: Successful response + content: + application/json: + schema: + type: object + additionalProperties: true + security: + - apikey: [] +components: + securitySchemes: + apikey: + type: apiKey + name: x-api-key + in: header diff --git a/test/util/__init__.py b/test/util/__init__.py new file mode 100644 index 00000000..c1764a6e --- /dev/null +++ b/test/util/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/util/conftest.py b/test/util/conftest.py new file mode 100644 index 00000000..bb575703 --- /dev/null +++ b/test/util/conftest.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +from pathlib import Path +from urllib.parse import urlparse + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from haystack_experimental.util.openapi import HttpClient, HttpClientError + + +@pytest.fixture() +def test_files_path(): + return Path(__file__).parent.parent / "test_files" + + +class FastAPITestClient(HttpClient): + + def __init__(self, app: FastAPI): + self.app = app + self.client = TestClient(app) + + def strip_host(self, url: str) -> str: + parsed_url = urlparse(url) + new_path = parsed_url.path + if parsed_url.query: + new_path += "?" + parsed_url.query + return new_path + + def send_request(self, request: dict) -> dict: + method = request["method"] + # OAS spec will list a server URL, but FastAPI doesn't need it for local testing, in fact it will fail + # if the URL has a host. So we strip it here. + url = self.strip_host(request["url"]) + headers = request.get("headers", {}) + params = request.get("params", {}) + json_data = request.get("json", None) + auth = request.get("auth", None) + cookies = request.get("cookies", {}) + + try: + response = self.client.request( + method, + url, + headers=headers, + params=params, + json=json_data, + auth=auth, + cookies=cookies, + ) + response.raise_for_status() + return response.json() + except Exception as e: + # Handle HTTP errors + raise HttpClientError(f"HTTP error occurred: {e}") from e diff --git a/test/util/test_openapi_client.py b/test/util/test_openapi_client.py new file mode 100644 index 00000000..581dfdf8 --- /dev/null +++ b/test/util/test_openapi_client.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from test.util.conftest import FastAPITestClient + +""" +Tests OpenAPIServiceClient with three FastAPI apps for different parameter types: + +- **greet_mix_params_body**: A POST endpoint `/greet/` accepting a JSON payload with a message, returning a +greeting with the name from the URL and the message from the payload. + +- **greet_params_only**: A GET endpoint `/greet-params/` taking a URL parameter, returning a greeting with +the name from the URL. + +- **greet_request_body_only**: A POST endpoint `/greet-body` accepting a JSON payload with a name and message, +returning a greeting with both. + +OpenAPI specs for these endpoints are in `openapi_greeting_service.yml` in `test/test_files` directory. +""" + + +class GreetBody(BaseModel): + message: str + name: str + + +class MessageBody(BaseModel): + message: str + + +# FastAPI app definitions +def create_greet_mix_params_body_app() -> FastAPI: + app = FastAPI() + + @app.post("/greet/{name}") + def greet(name: str, body: MessageBody): + greeting = f"{body.message}, {name} from mix_params_body!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_params_only_app() -> FastAPI: + app = FastAPI() + + @app.get("/greet-params/{name}") + def greet_params(name: str): + greeting = f"Hello, {name} from params_only!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_request_body_only_app() -> FastAPI: + app = FastAPI() + + @app.post("/greet-body") + def greet_request_body(body: GreetBody): + greeting = f"{body.message}, {body.name} from request_body_only!" + return JSONResponse(content={"greeting": greeting}) + + return app + + +class TestOpenAPI: + + def test_greet_mix_params_body(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_mix_params_body_app())) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John", "message": "Bonjour"}', + "name": "greet", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Bonjour, John from mix_params_body!"} + + def test_greet_params_only(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_params_only_app())) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetParams", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from params_only!"} + + def test_greet_request_body_only(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_request_body_only_app())) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John", "message": "Hola"}', + "name": "greetBody", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hola, John from request_body_only!"} diff --git a/test/util/test_openapi_client_auth.py b/test/util/test_openapi_client_auth.py new file mode 100644 index 00000000..d326d164 --- /dev/null +++ b/test/util/test_openapi_client_auth.py @@ -0,0 +1,238 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +from fastapi import Depends, FastAPI, HTTPException, status +from fastapi.responses import JSONResponse +from fastapi.security import ( + APIKeyCookie, + APIKeyHeader, + APIKeyQuery, + HTTPAuthorizationCredentials, + HTTPBasic, + HTTPBasicCredentials, + HTTPBearer, +) + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, ApiKeyAuthentication, \ + HTTPAuthentication +from test.util.conftest import FastAPITestClient + +API_KEY = "secret_api_key" +BASIC_AUTH_USERNAME = "admin" +BASIC_AUTH_PASSWORD = "secret_password" + +API_KEY_QUERY = "secret_api_key_query" +API_KEY_COOKIE = "secret_api_key_cookie" +BEARER_TOKEN = "secret_bearer_token" + +OAUTH_TOKEN = "secret-oauth-token" + +api_key_query = APIKeyQuery(name="api_key") +api_key_cookie = APIKeyCookie(name="api_key") +bearer_auth = HTTPBearer() + +api_key_header = APIKeyHeader(name="X-API-Key") +basic_auth_http = HTTPBasic() + + +def create_greet_api_key_query_app() -> FastAPI: + app = FastAPI() + + def api_key_query_auth(api_key: str = Depends(api_key_query)): + if api_key != API_KEY_QUERY: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return api_key + + @app.get("/greet-api-key-query/{name}") + def greet_api_key_query(name: str, api_key: str = Depends(api_key_query_auth)): + greeting = f"Hello, {name} from api_key_query_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_api_key_cookie_app() -> FastAPI: + app = FastAPI() + + def api_key_cookie_auth(api_key: str = Depends(api_key_cookie)): + if api_key != API_KEY_COOKIE: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return api_key + + @app.get("/greet-api-key-cookie/{name}") + def greet_api_key_cookie(name: str, api_key: str = Depends(api_key_cookie_auth)): + greeting = f"Hello, {name} from api_key_cookie_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_bearer_auth_app() -> FastAPI: + app = FastAPI() + + def bearer_auth_scheme( + credentials: HTTPAuthorizationCredentials = Depends(bearer_auth), # noqa: B008 + ): + if credentials.scheme != "Bearer" or credentials.credentials != BEARER_TOKEN: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + return credentials.credentials + + @app.get("/greet-bearer-auth/{name}") + def greet_bearer_auth(name: str, token: str = Depends(bearer_auth_scheme)): + greeting = f"Hello, {name} from bearer_auth, using {token}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_api_key_auth_app() -> FastAPI: + app = FastAPI() + + def api_key_auth(api_key: str = Depends(api_key_header)): + if api_key != API_KEY: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + return api_key + + @app.get("/greet-api-key/{name}") + def greet_api_key(name: str, api_key: str = Depends(api_key_auth)): + greeting = f"Hello, {name} from api_key_auth, using {api_key}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_basic_auth_app() -> FastAPI: + app = FastAPI() + + def basic_auth(credentials: HTTPBasicCredentials = Depends(basic_auth_http)): # noqa: B008 + if credentials.username != BASIC_AUTH_USERNAME or credentials.password != BASIC_AUTH_PASSWORD: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") + return credentials.username + + @app.get("/greet-basic-auth/{name}") + def greet_basic_auth(name: str, username: str = Depends(basic_auth)): + greeting = f"Hello, {name} from basic_auth, using {username}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +def create_greet_oauth_auth_app() -> FastAPI: + app = FastAPI() + + def oauth_auth(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): # noqa: B008 + if token.credentials != OAUTH_TOKEN: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + return token + + @app.get("/greet-oauth/{name}") + def greet_oauth(name: str, token: HTTPAuthorizationCredentials = Depends(oauth_auth)): # noqa: B008 + greeting = f"Hello, {name} from oauth_auth, using {token}" + return JSONResponse(content={"greeting": greeting}) + + return app + + +class TestOpenAPIAuth: + + def test_greet_api_key_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_api_key_auth_app())) + .with_credentials(ApiKeyAuthentication(API_KEY)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKey", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_auth, using secret_api_key"} + + def test_greet_basic_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_basic_auth_app())) + .with_credentials(HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetBasicAuth", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from basic_auth, using admin"} + + def test_greet_api_key_query_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_api_key_query_app())) + .with_credentials(ApiKeyAuthentication(API_KEY_QUERY)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKeyQuery", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_query_auth, using secret_api_key_query"} + + def test_greet_api_key_cookie_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_api_key_cookie_app())) + .with_credentials(ApiKeyAuthentication(API_KEY_COOKIE)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetApiKeyCookie", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from api_key_cookie_auth, using secret_api_key_cookie"} + + def test_greet_bearer_auth(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_greeting_service.yml") + .with_http_client(FastAPITestClient(create_greet_bearer_auth_app())) + .with_credentials(HTTPAuthentication(token=BEARER_TOKEN)) + .build() + ) + client = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"name": "John"}', + "name": "greetBearerAuth", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == {"greeting": "Hello, John from bearer_auth, using secret_bearer_token"} diff --git a/test/util/test_openapi_client_complex_request_body.py b/test/util/test_openapi_client_complex_request_body.py new file mode 100644 index 00000000..17702cb3 --- /dev/null +++ b/test/util/test_openapi_client_complex_request_body.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import json +from typing import List + +import pytest +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from test.util.conftest import FastAPITestClient + + +class Customer(BaseModel): + name: str + email: str + + +class OrderItem(BaseModel): + product: str + quantity: int + + +class Order(BaseModel): + customer: Customer + items: List[OrderItem] + + +class OrderResponse(BaseModel): + orderId: str # noqa: N815 + status: str + totalAmount: float # noqa: N815 + + +def create_order_app() -> FastAPI: + app = FastAPI() + + @app.post("/orders") + def create_order(order: Order): + total_amount = sum(item.quantity * 10 for item in order.items) + response = OrderResponse( + orderId="ORDER-001", + status="CREATED", + totalAmount=total_amount, + ) + return JSONResponse(content=response.model_dump(), status_code=201) + + return app + + +class TestComplexRequestBody: + + @pytest.mark.parametrize("spec_file_path", ["openapi_order_service.yml", "openapi_order_service.json"]) + def test_create_order(self, spec_file_path, test_files_path): + builder = ClientConfigurationBuilder() + path_element = "yaml" if spec_file_path.endswith(".yml") else "json" + config = ( + builder.with_openapi_spec(test_files_path / path_element / spec_file_path) + .with_http_client(FastAPITestClient(create_order_app())) + .build() + ) + client = OpenAPIServiceClient(config) + order_json = { + "customer": {"name": "John Doe", "email": "john@example.com"}, + "items": [ + {"product": "Product A", "quantity": 2}, + {"product": "Product B", "quantity": 1}, + ], + } + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": json.dumps(order_json), + "name": "createOrder", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == { + "orderId": "ORDER-001", + "status": "CREATED", + "totalAmount": 30, + } diff --git a/test/util/test_openapi_client_complex_request_body_mixed.py b/test/util/test_openapi_client_complex_request_body_mixed.py new file mode 100644 index 00000000..907520eb --- /dev/null +++ b/test/util/test_openapi_client_complex_request_body_mixed.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import json + +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from test.util.conftest import FastAPITestClient + + +class Identification(BaseModel): + type: str + number: str + + +class Payer(BaseModel): + name: str + email: str + identification: Identification + + +class PaymentRequest(BaseModel): + transaction_amount: float + description: str + payment_method_id: str + payer: Payer + + +class PaymentResponse(BaseModel): + transaction_id: str + status: str + message: str + + +def create_payment_app() -> FastAPI: + app = FastAPI() + + @app.post("/new_payment") + def process_payment(payment: PaymentRequest): + # sanity + assert payment.transaction_amount == 100.0 + response = PaymentResponse( + transaction_id="TRANS-12345", status="SUCCESS", message="Payment processed successfully." + ) + return JSONResponse(content=response.model_dump(), status_code=200) + + return app + + +# Write the unit test +class TestPaymentProcess: + + def test_process_payment(self, test_files_path): + config = ( + ClientConfigurationBuilder() + .with_openapi_spec(test_files_path / "json" / "complex_types_openapi_service.json") + .with_http_client(FastAPITestClient(create_payment_app())) + .build() + ) + client = OpenAPIServiceClient(config) + + payment_json = { + "transaction_amount": 100.0, + "description": "Test Payment", + "payment_method_id": "CARD-123", + "payer": { + "name": "Alice Smith", + "email": "alice@example.com", + "identification": {"type": "CPF", "number": "123.456.789-00"}, + }, + } + payload = { + "id": "call_uniqueID123", + "function": { + "arguments": json.dumps(payment_json), + "name": "processPayment", + }, + "type": "function", + } + response = client.invoke(payload) + assert response == { + "transaction_id": "TRANS-12345", + "status": "SUCCESS", + "message": "Payment processed successfully.", + } diff --git a/test/util/test_openapi_client_edge_cases.py b/test/util/test_openapi_client_edge_cases.py new file mode 100644 index 00000000..f421e73e --- /dev/null +++ b/test/util/test_openapi_client_edge_cases.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import pytest + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient +from test.util.conftest import FastAPITestClient + + +class TestEdgeCases: + def test_invalid_openapi_spec(self): + builder = ClientConfigurationBuilder() + with pytest.raises(ValueError, match="Invalid OpenAPI specification"): + config = builder.with_openapi_spec("invalid_spec.yml").build() + OpenAPIServiceClient(config) + + def test_missing_operation_id(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml") + .with_http_client(FastAPITestClient(None)) + .build() + ) + client = OpenAPIServiceClient(config) + + payload = { + "type": "function", + "function": { + "arguments": '{"name": "John", "message": "Hola"}', + "name": "missingOperationId", + }, + } + with pytest.raises(ValueError, match="No operation found with operationId"): + client.invoke(payload) + + # TODO: Add more tests for edge cases diff --git a/test/util/test_openapi_client_error_handling.py b/test/util/test_openapi_client_error_handling.py new file mode 100644 index 00000000..3201ad65 --- /dev/null +++ b/test/util/test_openapi_client_error_handling.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + + +import json + +import pytest +from fastapi import FastAPI, HTTPException + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient, HttpClientError +from test.util.conftest import FastAPITestClient + + +def create_error_handling_app() -> FastAPI: + app = FastAPI() + + @app.get("/error/{status_code}") + def raise_http_error(status_code: int): + raise HTTPException(status_code=status_code, detail=f"HTTP {status_code} error") + + return app + + +class TestErrorHandling: + @pytest.mark.parametrize("status_code", [400, 401, 403, 404, 500]) + def test_http_error_handling(self, test_files_path, status_code): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "openapi_error_handling.yml") + .with_http_client(FastAPITestClient(create_error_handling_app())) + .build() + ) + client = OpenAPIServiceClient(config) + json_error = {"status_code": status_code} + payload = { + "type": "function", + "function": { + "arguments": json.dumps(json_error), + "name": "raiseHttpError", + }, + } + with pytest.raises(HttpClientError) as exc_info: + client.invoke(payload) + + assert str(status_code) in str(exc_info.value) diff --git a/test/util/test_openapi_client_live.py b/test/util/test_openapi_client_live.py new file mode 100644 index 00000000..869c56a8 --- /dev/null +++ b/test/util/test_openapi_client_live.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import os + +import pytest +import yaml +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient + + +class TestClientLive: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "serper.yml") + .with_credentials(os.getenv("SERPERDEV_API_KEY")) + .build() + ) + serper_api = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"q": "Who was Nikola Tesla?"}', + "name": "serperdev_search", + }, + "type": "function", + } + response = serper_api.invoke(payload) + assert "invention" in str(response) + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.integration + def test_serperdev_load_spec_first(self, test_files_path): + with open(test_files_path / "yaml" / "serper.yml") as file: + loaded_spec = yaml.safe_load(file) + + # use builder with dict spec + builder = ClientConfigurationBuilder() + config = builder.with_openapi_spec(loaded_spec).with_credentials(os.getenv("SERPERDEV_API_KEY")).build() + serper_api = OpenAPIServiceClient(config) + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": '{"q": "Who was Nikola Tesla?"}', + "name": "serperdev_search", + }, + "type": "function", + } + response = serper_api.invoke(payload) + assert "invention" in str(response) + + @pytest.mark.integration + def test_github(self, test_files_path): + builder = ClientConfigurationBuilder() + config = builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").build() + api = OpenAPIServiceClient(config) + + params = {"owner": "deepset-ai", "repo": "haystack", "basehead": "main...add_default_adapter_filters"} + payload = { + "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", + "function": { + "arguments": json.dumps(params), + "name": "compare", + }, + "type": "function", + } + response = api.invoke(payload) + assert "deepset" in str(response) diff --git a/test/util/test_openapi_client_live_anthropic.py b/test/util/test_openapi_client_live_anthropic.py new file mode 100644 index 00000000..0ac5c09c --- /dev/null +++ b/test/util/test_openapi_client_live_anthropic.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os + +import anthropic +import pytest + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient + + +class TestClientLiveAnthropic: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec("https://bit.ly/serper_dev_spec_yaml") + .with_credentials(os.getenv("SERPERDEV_API_KEY")) + .with_provider("anthropic") + .build() + ) + client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + response = client.beta.tools.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + tools=config.get_tools_definitions(), + messages=[{"role": "user", "content": "Do a google search: Who was Nikola Tesla?"}], + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "inventions" in str(service_response) + + # make a few more requests to test the same tool + service_response = service_api.invoke(response) + assert "Serbian" in str(service_response) + + service_response = service_api.invoke(response) + assert "American" in str(service_response) + + @pytest.mark.skipif("ANTHROPIC_API_KEY" not in os.environ, reason="ANTHROPIC_API_KEY not set") + @pytest.mark.integration + def test_github(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml") + .with_provider("anthropic") + .build() + ) + + client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) + response = client.beta.tools.messages.create( + model="claude-3-opus-20240229", + max_tokens=1024, + tools=config.get_tools_definitions(), + messages=[ + { + "role": "user", + "content": "Compare branches main and add_default_adapter_filters in repo" + " haystack and owner deepset-ai", + } + ], + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "deepset" in str(service_response) diff --git a/test/util/test_openapi_client_live_cohere.py b/test/util/test_openapi_client_live_cohere.py new file mode 100644 index 00000000..901b474d --- /dev/null +++ b/test/util/test_openapi_client_live_cohere.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import cohere +import pytest + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient + +# Copied from Cohere's documentation +preamble = """ +## Task & Context +You help people answer their questions and other requests interactively. You will be asked a very wide array of + requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to + help you, which you use to research your answer. You should focus on serving the user's needs as best you can, + which will be wide-ranging. + +## Style Guide +Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and + spelling. +""" + + +class TestClientLiveCohere: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "serper.yml") + .with_credentials(os.getenv("SERPERDEV_API_KEY")) + .with_provider("cohere") + .build() + ) + client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) + response = client.chat( + model="command-r", + preamble=preamble, + tools=config.get_tools_definitions(), + message="Do a google search: Who was Nikola Tesla?", + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "inventions" in str(service_response) + + # make a few more requests to test the same tool + service_response = service_api.invoke(response) + assert "Serbian" in str(service_response) + + service_response = service_api.invoke(response) + assert "American" in str(service_response) + + @pytest.mark.skipif("COHERE_API_KEY" not in os.environ, reason="COHERE_API_KEY not set") + @pytest.mark.integration + def test_github(self, test_files_path): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").with_provider("cohere").build() + ) + + client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) + response = client.chat( + model="command-r", + preamble=preamble, + tools=config.get_tools_definitions(), + message="Compare branches main and add_default_adapter_filters in repo haystack and owner deepset-ai", + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "deepset" in str(service_response) diff --git a/test/util/test_openapi_client_live_openai.py b/test/util/test_openapi_client_live_openai.py new file mode 100644 index 00000000..bfc00e06 --- /dev/null +++ b/test/util/test_openapi_client_live_openai.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +from openai import OpenAI + +from haystack_experimental.util.openapi import ClientConfigurationBuilder, OpenAPIServiceClient + + +class TestClientLiveOpenAPI: + + @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_serperdev(self): + builder = ClientConfigurationBuilder() + config = ( + builder.with_openapi_spec("https://bit.ly/serper_dev_spec_yaml") + .with_credentials(os.getenv("SERPERDEV_API_KEY")) + .build() + ) + client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Do a serperdev google search: Who was Nikola Tesla?"}], + tools=config.get_tools_definitions(), + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "inventions" in str(service_response) + + # make a few more requests to test the same tool + service_response = service_api.invoke(response) + assert "Serbian" in str(service_response) + + service_response = service_api.invoke(response) + assert "American" in str(service_response) + + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_github(self, test_files_path): + builder = ClientConfigurationBuilder() + config = builder.with_openapi_spec(test_files_path / "yaml" / "github_compare.yml").build() + + client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": "Compare branches main and add_default_adapter_filters in repo" + " haystack and owner deepset-ai", + } + ], + tools=config.get_tools_definitions(), + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert "deepset" in str(service_response) + + @pytest.mark.skipif("FIRECRAWL_API_KEY" not in os.environ, reason="FIRECRAWL_API_KEY not set") + @pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") + @pytest.mark.integration + def test_firecrawl(self): + openapi_spec_url = "https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json" + builder = ClientConfigurationBuilder() + config = builder.with_openapi_spec(openapi_spec_url).with_credentials(os.getenv("FIRECRAWL_API_KEY")).build() + + client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Scrape URL: https://news.ycombinator.com/"}], + tools=config.get_tools_definitions(), + ) + service_api = OpenAPIServiceClient(config) + service_response = service_api.invoke(response) + assert isinstance(service_response, dict) + assert service_response.get("success", False), "Firecrawl scrape API call failed" + + # now test the same openapi service but different endpoint/tool + top_k = 2 + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + { + "role": "user", + "content": f"Search Google for `Why was Sam Altman ousted from OpenAI?`, limit to {top_k} results", + } + ], + tools=config.get_tools_definitions(), + ) + service_response = service_api.invoke(response) + assert isinstance(service_response, dict) + assert service_response.get("success", False), "Firecrawl search API call failed" + assert len(service_response.get("data", [])) == top_k + assert "Sam" in str(service_response) diff --git a/test/util/test_openapi_cohere_conversion.py b/test/util/test_openapi_cohere_conversion.py new file mode 100644 index 00000000..5afb1673 --- /dev/null +++ b/test/util/test_openapi_cohere_conversion.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack_experimental.util.openapi import OpenAPISpecification, cohere_converter + + +class TestOpenAPISchemaConversion: + + def test_serperdev(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "serper.yml") + functions = cohere_converter(schema=spec) + + assert functions + assert len(functions) == 1 + function = functions[0] + assert function["name"] == "serperdev_search" + assert function["description"] == "Search the web with Google" + assert function["parameter_definitions"] == {"q": {"description": "", "type": "str", "required": True}} + + def test_github(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "github_compare.yml") + functions = cohere_converter(schema=spec) + assert functions + assert len(functions) == 1 + function = functions[0] + assert function["name"] == "compare_branches" + assert function["description"] == "Compares two branches against one another." + assert function["parameter_definitions"] == { + "basehead": { + "description": "The base branch and head branch to compare." + " This parameter expects the format `BASE...HEAD`", + "type": "str", + "required": True, + }, + "owner": { + "description": "The repository owner, usually a company or orgnization", + "type": "str", + "required": True, + }, + "repo": {"description": "The repository itself, the project", "type": "str", "required": True}, + } + + def test_complex_types(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json") + functions = cohere_converter(schema=spec) + + assert functions + assert len(functions) == 1 + function = functions[0] + assert function["name"] == "processPayment" + assert function["description"] == "Process a new payment using the specified payment method" + assert function["parameter_definitions"] == { + "transaction_amount": {"type": "float", "description": "The amount to be paid", "required": True}, + "description": {"type": "str", "description": "A brief description of the payment", "required": True}, + "payment_method_id": {"type": "str", "description": "The payment method to be used", "required": True}, + "payer": { + "type": "object", + "description": "Information about the payer, including their name, email, and identification number", + "properties": { + "name": {"type": "str", "description": "The payer's name", "required": True}, + "email": {"type": "str", "description": "The payer's email address", "required": True}, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "str", + "description": "The type of identification document (e.g., CPF, CNPJ)", + "required": True, + }, + "number": {"type": "str", "description": "The identification number", "required": True}, + }, + "required": True, + }, + }, + "required": True, + }, + } diff --git a/test/util/test_openapi_openai_conversion.py b/test/util/test_openapi_openai_conversion.py new file mode 100644 index 00000000..fb3afa0e --- /dev/null +++ b/test/util/test_openapi_openai_conversion.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from haystack_experimental.util.openapi import openai_converter, anthropic_converter, OpenAPISpecification + + +class TestOpenAPISchemaConversion: + + @pytest.mark.parametrize("provider", ["openai", "anthropic"]) + def test_serperdev(self, test_files_path, provider): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "serper.yml") + functions = openai_converter(schema=spec) if provider == "openai" else anthropic_converter(schema=spec) + assert functions + assert len(functions) == 1 + function = functions[0]["function"] if provider == "openai" else functions[0] + assert function["name"] == "serperdev_search" + assert function["description"] == "Search the web with Google" + assert ( + function["parameters"] + if provider == "openai" + else function["input_schema"] + == {"type": "object", "properties": {"q": {"type": "string"}}, "required": ["q"]} + ) + + @pytest.mark.parametrize("provider", ["openai", "anthropic"]) + def test_github(self, test_files_path, provider: str): + spec = OpenAPISpecification.from_file(test_files_path / "yaml" / "github_compare.yml") + functions = openai_converter(schema=spec) if provider == "openai" else anthropic_converter(schema=spec) + assert functions + assert len(functions) == 1 + function = functions[0]["function"] if provider == "openai" else functions[0] + assert function["name"] == "compare_branches" + assert function["description"] == "Compares two branches against one another." + assert ( + function["parameters"] + if provider == "openai" + else function["input_schema"] + == { + "type": "object", + "properties": { + "basehead": { + "type": "string", + "description": "The base branch and head branch to compare. " + "This parameter expects the format `BASE...HEAD`", + }, + "owner": { + "type": "string", + "description": "The repository owner, usually a company or orgnization", + }, + "repo": {"type": "string", "description": "The repository itself, the project"}, + }, + "required": ["basehead", "owner", "repo"], + } + ) + + @pytest.mark.parametrize("provider", ["openai", "anthropic"]) + def test_complex_types(self, test_files_path, provider: str): + spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json") + functions = openai_converter(schema=spec) if provider == "openai" else anthropic_converter(schema=spec) + + assert functions + assert len(functions) == 1 + function = functions[0]["function"] if provider == "openai" else functions[0] + assert function["name"] == "processPayment" + assert function["description"] == "Process a new payment using the specified payment method" + assert ( + function["parameters"] + if provider == "openai" + else function["input_schema"] + == { + "type": "object", + "properties": { + "transaction_amount": {"type": "number", "description": "The amount to be paid"}, + "description": {"type": "string", "description": "A brief description of the payment"}, + "payment_method_id": {"type": "string", "description": "The payment method to be used"}, + "payer": { + "type": "object", + "description": "Information about the payer, including their name, email, " + "and identification number", + "properties": { + "name": {"type": "string", "description": "The payer's name"}, + "email": {"type": "string", "description": "The payer's email address"}, + "identification": { + "type": "object", + "description": "The payer's identification number", + "properties": { + "type": { + "type": "string", + "description": "The type of identification document (e.g., CPF, CNPJ)", + }, + "number": {"type": "string", "description": "The identification number"}, + }, + "required": ["type", "number"], + }, + }, + "required": ["name", "email", "identification"], + }, + }, + "required": ["transaction_amount", "description", "payment_method_id", "payer"], + } + ) diff --git a/test/util/test_openapi_spec.py b/test/util/test_openapi_spec.py new file mode 100644 index 00000000..722946f1 --- /dev/null +++ b/test/util/test_openapi_spec.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json + +import pytest + +from haystack_experimental.util.openapi import OpenAPISpecification + + +class TestOpenAPISpecification: + + # can be initialized from a dictionary + def test_initialized_from_dictionary(self): + spec_dict = { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/users": { + "get": {"summary": "Get all users", "responses": {"200": {"description": "Successful response"}}} + } + }, + } + openapi_spec = OpenAPISpecification.from_dict(spec_dict) + assert openapi_spec.spec_dict == spec_dict + + # can be initialized from a string + def test_initialized_from_string(self): + content = """ + openapi: 3.0.0 + info: + title: Test API + version: 1.0.0 + servers: + - url: https://api.example.com + paths: + /users: + get: + summary: Get all users + responses: + '200': + description: Successful response + """ + openapi_spec = OpenAPISpecification.from_str(content) + assert openapi_spec.spec_dict == { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/users": { + "get": {"summary": "Get all users", "responses": {"200": {"description": "Successful response"}}} + } + }, + } + + # can be initialized from a file + def test_initialized_from_file(self, tmp_path): + content = """ + openapi: 3.0.0 + info: + title: Test API + version: 1.0.0 + servers: + - url: https://api.example.com + paths: + /users: + get: + summary: Get all users + responses: + '200': + description: Successful response + """ + file_path = tmp_path / "spec.yaml" + file_path.write_text(content) + openapi_spec = OpenAPISpecification.from_file(file_path) + assert openapi_spec.spec_dict == { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/users": { + "get": {"summary": "Get all users", "responses": {"200": {"description": "Successful response"}}} + } + }, + } + + # can get all paths + def test_get_all_paths(self): + spec_dict = { + "openapi": "3.0.0", + "info": {"title": "Test API", "version": "1.0.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": {"/users": {}, "/products": {}, "/orders": {}}, + } + openapi_spec = OpenAPISpecification(spec_dict) + paths = openapi_spec.get_paths() + assert paths == {"/users": {}, "/products": {}, "/orders": {}} + + # raises ValueError if initialized from an invalid schema + def test_raises_value_error_invalid_schema(self): + spec_dict = {"info": {"title": "Test API", "version": "1.0.0"}, "paths": {"/users": {}}} + with pytest.raises(ValueError): + OpenAPISpecification(spec_dict) + + # Should return the raw OpenAPI specification dictionary with resolved references. + def test_return_raw_spec_with_resolved_references(self, test_files_path): + spec = OpenAPISpecification.from_file(test_files_path / "json" / "complex_types_openapi_service.json") + raw_spec = spec.to_dict(resolve_references=True) + + assert "$ref" not in str(raw_spec) + assert "#/" not in str(raw_spec) + + # verify that we can serialize the raw spec to a string + schema_ser = json.dumps(raw_spec, indent=2) + + # and that the serialized string does not contain any $ref or #/ references + assert "$ref" not in schema_ser + assert "#/" not in schema_ser + + # and that we can deserialize the serialized string back to a dictionary + schema = json.loads(schema_ser) + assert "$ref" not in schema + assert "#/" not in schema + + assert schema == raw_spec