diff --git a/haystack_experimental/components/tools/openapi/_payload_extraction.py b/haystack_experimental/components/tools/openapi/_payload_extraction.py index 6247c56a..61bb1abf 100644 --- a/haystack_experimental/components/tools/openapi/_payload_extraction.py +++ b/haystack_experimental/components/tools/openapi/_payload_extraction.py @@ -4,7 +4,7 @@ import dataclasses import json -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, cast def create_function_payload_extractor( @@ -68,7 +68,8 @@ def _search(payload: Any, arguments_field_name: str) -> Dict[str, Any]: if dict_converter := _get_dict_converter(payload): payload = dict_converter() elif dataclasses.is_dataclass(payload): - payload = dataclasses.asdict(payload) + # Cast payload to Any to satisfy mypy 1.11.0 + payload = dataclasses.asdict(cast(Any, payload)) if isinstance(payload, dict): if all(field in payload for field in _required_fields(arguments_field_name)): # this is the payload we are looking for diff --git a/haystack_experimental/components/tools/openapi/_schema_conversion.py b/haystack_experimental/components/tools/openapi/_schema_conversion.py index 4a986752..35c8c36b 100644 --- a/haystack_experimental/components/tools/openapi/_schema_conversion.py +++ b/haystack_experimental/components/tools/openapi/_schema_conversion.py @@ -5,19 +5,12 @@ import logging from typing import Any, Callable, Dict, List, Optional -from haystack.lazy_imports import LazyImport - from haystack_experimental.components.tools.openapi.types import ( VALID_HTTP_METHODS, OpenAPISpecification, path_to_operation_id, ) -with LazyImport("Run 'pip install jsonref'") as jsonref_import: - # pylint: disable=import-error - import jsonref - - MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 logger = logging.getLogger(__name__) @@ -31,10 +24,8 @@ def openai_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: :param schema: The OpenAPI specification to convert. :returns: A list of dictionaries, each dictionary representing an OpenAI function definition. """ - jsonref_import.check() - resolved_schema = jsonref.replace_refs(schema.spec_dict) fn_definitions = _openapi_to_functions( - resolved_schema, "parameters", _parse_endpoint_spec_openai + schema.spec_dict, "parameters", _parse_endpoint_spec_openai ) return [{"type": "function", "function": fn} for fn in fn_definitions] @@ -48,10 +39,9 @@ def anthropic_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: :param schema: The OpenAPI specification to convert. :returns: A list of dictionaries, each dictionary representing Anthropic function definition. """ - jsonref_import.check() - resolved_schema = jsonref.replace_refs(schema.spec_dict) + return _openapi_to_functions( - resolved_schema, "input_schema", _parse_endpoint_spec_openai + schema.spec_dict, "input_schema", _parse_endpoint_spec_openai ) @@ -64,10 +54,8 @@ def cohere_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]: :param schema: The OpenAPI specification to convert. :returns: A list of dictionaries, each representing a Cohere style function definition. """ - jsonref_import.check() - resolved_schema = jsonref.replace_refs(schema.spec_dict) return _openapi_to_functions( - resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere + schema.spec_dict,"not important for cohere",_parse_endpoint_spec_cohere ) diff --git a/haystack_experimental/components/tools/openapi/types.py b/haystack_experimental/components/tools/openapi/types.py index 77e22421..3d9f23a2 100644 --- a/haystack_experimental/components/tools/openapi/types.py +++ b/haystack_experimental/components/tools/openapi/types.py @@ -10,6 +10,12 @@ import requests import yaml +from haystack.lazy_imports import LazyImport + +with LazyImport("Run 'pip install jsonref'") as jsonref_import: + # pylint: disable=import-error + import jsonref + VALID_HTTP_METHODS = [ "get", @@ -149,6 +155,7 @@ def __init__(self, spec_dict: Dict[str, Any]): :param spec_dict: The OpenAPI specification as a dictionary. """ + jsonref_import.check() if not isinstance(spec_dict, Dict): raise ValueError( f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}" @@ -159,7 +166,7 @@ def __init__(self, spec_dict: Dict[str, Any]): "Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.", spec_dict, ) - self.spec_dict = spec_dict + self.spec_dict = jsonref.replace_refs(spec_dict) @classmethod def from_str(cls, content: str) -> "OpenAPISpecification": diff --git a/test/test_files/yaml/openapi_greeting_service.yml b/test/test_files/yaml/openapi_greeting_service.yml index 701dee33..1ba40381 100644 --- a/test/test_files/yaml/openapi_greeting_service.yml +++ b/test/test_files/yaml/openapi_greeting_service.yml @@ -9,11 +9,7 @@ paths: post: operationId: greet parameters: - - name: name - in: path - required: true - schema: - type: string + - $ref: '#/components/parameters/NameParameter' requestBody: required: true content: @@ -238,6 +234,13 @@ components: tokenUrl: https://example.com/oauth/token scopes: read:greet: Read access to greeting service + parameters: + NameParameter: + name: name + in: path + required: true + schema: + type: string schemas: GreetBody: @@ -269,4 +272,4 @@ components: type: object properties: detail: - type: string \ No newline at end of file + type: string