Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve OpenAPITool corner cases handling (missing operationId, servers under paths, etc) #37

Merged
merged 2 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from haystack.lazy_imports import LazyImport

from haystack_experimental.components.tools.openapi.types import OpenAPISpecification
from haystack_experimental.components.tools.openapi.types import (
VALID_HTTP_METHODS,
OpenAPISpecification,
path_to_operation_id,
)

with LazyImport("Run 'pip install jsonref'") as jsonref_import:
# pylint: disable=import-error
Expand Down Expand Up @@ -96,11 +100,14 @@ def _openapi_to_functions(
f"at least {MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
)
functions: List[Dict[str, Any]] = []
for paths in service_openapi_spec["paths"].values():
for path_spec in paths.values():
function_dict = parse_endpoint_fn(path_spec, parameters_name)
if function_dict:
functions.append(function_dict)
for path, path_value in service_openapi_spec["paths"].items():
for path_key, operation_spec in path_value.items():
if path_key.lower() in VALID_HTTP_METHODS:
if "operationId" not in operation_spec:
operation_spec["operationId"] = path_to_operation_id(path, path_key)
function_dict = parse_endpoint_fn(operation_spec, parameters_name)
if function_dict:
functions.append(function_dict)
return functions


Expand Down
107 changes: 50 additions & 57 deletions haystack_experimental/components/tools/openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,24 @@
]


def path_to_operation_id(path: str, http_method: str = "get") -> str:
"""
Converts a path to an operationId.

:param path: The path to convert.
:param http_method: The HTTP method to use for the operationId.
:returns: The operationId.
"""
if http_method.lower() not in VALID_HTTP_METHODS:
raise ValueError(f"Invalid HTTP method: {http_method}")
return path.replace("/", "_").lstrip("_").rstrip("_") + "_" + http_method.lower()


class LLMProvider(Enum):
"""
LLM providers supported by `OpenAPITool`.
"""

OPENAI = "openai"
ANTHROPIC = "anthropic"
COHERE = "cohere"
Expand All @@ -50,18 +64,18 @@ def from_str(string: str) -> "LLMProvider":
@dataclass
class Operation:
"""
Represents an operation in an OpenAPI specification

See https://spec.openapis.org/oas/latest.html#paths-object for details.
Path objects can contain multiple operations, each with a unique combination of path and method.

:param path: Path of the operation.
:param method: HTTP method of the operation.
:param operation_dict: Operation details from OpenAPI spec
:param spec_dict: The encompassing OpenAPI specification.
:param security_requirements: A list of security requirements for the operation.
:param request_body: Request body details.
:param parameters: Parameters for the operation.
Represents an operation in an OpenAPI specification

See https://spec.openapis.org/oas/latest.html#paths-object for details.
Path objects can contain multiple operations, each with a unique combination of path and method.

:param path: Path of the operation.
:param method: HTTP method of the operation.
:param operation_dict: Operation details from OpenAPI spec
:param spec_dict: The encompassing OpenAPI specification.
:param security_requirements: A list of security requirements for the operation.
:param request_body: Request body details.
:param parameters: Parameters for the operation.
"""

path: str
Expand Down Expand Up @@ -105,8 +119,12 @@ def get_server(self, server_index: int = 0) -> str:
:returns: The server URL.
:raises ValueError: If no servers are found in the specification.
"""
servers = self.operation_dict.get("servers", []) or self.spec_dict.get(
"servers", []
# servers can be defined at the operation level, path level, or at the root level
# search for servers in the following order: operation, path, root
servers = (
self.operation_dict.get("servers", [])
or self.spec_dict.get("paths", {}).get(self.path, {}).get("servers", [])
or self.spec_dict.get("servers", [])
)
if not servers:
raise ValueError("No servers found in the provided specification.")
Expand Down Expand Up @@ -136,11 +154,7 @@ def __init__(self, spec_dict: Dict[str, Any]):
f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}"
)
# just a crude sanity check, by no means a full validation
if (
"openapi" not in spec_dict
or "paths" not in spec_dict
or "servers" not in spec_dict
):
if "openapi" not in spec_dict or "paths" not in spec_dict:
raise ValueError(
"Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.",
spec_dict,
Expand Down Expand Up @@ -201,51 +215,30 @@ def from_url(cls, url: str) -> "OpenAPISpecification":
) from e
return cls.from_str(content)

def find_operation_by_id(
self, op_id: str, method: Optional[str] = None
) -> Operation:
def find_operation_by_id(self, op_id: str) -> Operation:
"""
Find an Operation by operationId.

:param op_id: The operationId of the operation.
:param method: The HTTP method of the operation.
:returns: The matching operation
:raises ValueError: If no operation is found with the given operationId.
"""
for path, path_item in self.spec_dict.get("paths", {}).items():
op: Operation = self.get_operation_item(path, path_item, method)
if op_id in op.operation_dict.get("operationId", ""):
return self.get_operation_item(path, path_item, method)
raise ValueError(
f"No operation found with operationId {op_id}, method {method}"
)

def get_operation_item(
self, path: str, path_item: Dict[str, Any], method: Optional[str] = None
) -> Operation:
"""
Gets a particular Operation item from the OpenAPI specification given the path and method.

:param path: The path of the operation.
:param path_item: The path item from the OpenAPI specification.
:param method: The HTTP method of the operation.
:returns: The operation
"""
if method:
operation_dict = path_item.get(method.lower(), {})
if not operation_dict:
raise ValueError(
f"No operation found for method {method} at path {path}"
)
return Operation(path, method.lower(), operation_dict, self.spec_dict)
if len(path_item) == 1:
method, operation_dict = next(iter(path_item.items()))
return Operation(path, method, operation_dict, self.spec_dict)
if len(path_item) > 1:
raise ValueError(
f"Multiple operations found at path {path}, method parameter is required."
)
raise ValueError(f"No operations found at path {path} and method {method}")
for path, path_value in self.spec_dict.get("paths", {}).items():
operations = {
method: operation_dict
for method, operation_dict in path_value.items()
if method.lower() in VALID_HTTP_METHODS
}

for method, operation_dict in operations.items():
if (
operation_dict.get(
"operationId", path_to_operation_id(path, method)
)
== op_id
):
return Operation(path, method, operation_dict, self.spec_dict)
raise ValueError(f"No operation found with operationId {op_id}")

def get_security_schemes(self) -> Dict[str, Dict[str, Any]]:
"""
Expand Down
29 changes: 27 additions & 2 deletions test/components/tools/openapi/test_openapi_client_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration
from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient
from test.components.tools.openapi.conftest import FastAPITestClient, create_openapi_spec


Expand All @@ -26,4 +26,29 @@ def test_missing_operation_id(self, test_files_path):
with pytest.raises(ValueError, match="No operation found with operationId"):
client.invoke(payload)

# TODO: Add more tests for edge cases
def test_missing_operation_id_in_operation(self, test_files_path):
"""
Test that the tool definition is generated correctly when the operationId is missing in the specification.
"""
config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml"),
request_sender=FastAPITestClient(None))

tools = config.get_tools_definitions(),
tool_def = tools[0][0]
assert tool_def["type"] == "function"
assert tool_def["function"]["name"] == "missing-operation-id_get"

def test_servers_order(self, test_files_path):
"""
Test that servers defined in different locations in the specification are used correctly.
"""

config = ClientConfiguration(openapi_spec=create_openapi_spec(test_files_path / "yaml" / "openapi_edge_cases.yml"),
request_sender=FastAPITestClient(None))

op = config.openapi_spec.find_operation_by_id("servers-order-path")
assert op.get_server() == "https://inpath.example.com"
op = config.openapi_spec.find_operation_by_id("servers-order-operation")
assert op.get_server() == "https://inoperation.example.com"
op = config.openapi_spec.find_operation_by_id("missing-operation-id_get")
assert op.get_server() == "http://localhost"
21 changes: 21 additions & 0 deletions test/components/tools/openapi/test_openapi_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,24 @@ def test_run_live_cohere(self):
assert isinstance(json_response, dict)
except json.JSONDecodeError:
pytest.fail("Response content is not valid JSON")

@pytest.mark.integration
@pytest.mark.parametrize("provider", ["openai", "anthropic", "cohere"])
def test_run_live_meteo_forecast(self, provider: str):
tool = OpenAPITool(
generator_api=LLMProvider.from_str(provider),
spec="https://raw.githubusercontent.com/open-meteo/open-meteo/main/openapi.yml"
)
results = tool.run(messages=[ChatMessage.from_user(
"weather forecast for latitude 52.52 and longitude 13.41 and set hourly=temperature_2m")])

assert isinstance(results["service_response"], list)
assert len(results["service_response"]) == 1
assert isinstance(results["service_response"][0], ChatMessage)

try:
json_response = json.loads(results["service_response"][0].content)
assert isinstance(json_response, dict)
assert "hourly" in json_response
except json.JSONDecodeError:
pytest.fail("Response content is not valid JSON")
38 changes: 38 additions & 0 deletions test/test_files/yaml/openapi_edge_cases.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,44 @@ paths:
/missing-operation-id:
get:
summary: Missing operationId
parameters:
- name: name
in: path
required: true
schema:
type: string
responses:
'200':
description: OK

/servers-order-in-path:
servers:
- url: https://inpath.example.com
get:
summary: Servers order
operationId: servers-order-path
parameters:
- name: name
in: path
required: true
schema:
type: string
responses:
'200':
description: OK

/servers-order-in-operation:
get:
summary: Servers order
operationId: servers-order-operation
parameters:
- name: name
in: path
required: true
schema:
type: string
responses:
'200':
description: OK
servers:
- url: https://inoperation.example.com
Loading