Skip to content

Commit

Permalink
fix: Centralize OpenAPI schema reference resolution (#40)
Browse files Browse the repository at this point in the history
* Simplify jsonref_import check, spec has replaced refs by default

* Fix mypy issue (new with mypy upgrade)

* Use parameter reference in yaml as additional test for schema ref resolution
  • Loading branch information
vblagoje authored Jul 22, 2024
1 parent e5841a7 commit 3258230
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import dataclasses
import json
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union, cast


def create_function_payload_extractor(
Expand Down Expand Up @@ -68,7 +68,8 @@ def _search(payload: Any, arguments_field_name: str) -> Dict[str, Any]:
if dict_converter := _get_dict_converter(payload):
payload = dict_converter()
elif dataclasses.is_dataclass(payload):
payload = dataclasses.asdict(payload)
# Cast payload to Any to satisfy mypy 1.11.0
payload = dataclasses.asdict(cast(Any, payload))
if isinstance(payload, dict):
if all(field in payload for field in _required_fields(arguments_field_name)):
# this is the payload we are looking for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,12 @@
import logging
from typing import Any, Callable, Dict, List, Optional

from haystack.lazy_imports import LazyImport

from haystack_experimental.components.tools.openapi.types import (
VALID_HTTP_METHODS,
OpenAPISpecification,
path_to_operation_id,
)

with LazyImport("Run 'pip install jsonref'") as jsonref_import:
# pylint: disable=import-error
import jsonref


MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3

logger = logging.getLogger(__name__)
Expand All @@ -31,10 +24,8 @@ def openai_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]:
:param schema: The OpenAPI specification to convert.
:returns: A list of dictionaries, each dictionary representing an OpenAI function definition.
"""
jsonref_import.check()
resolved_schema = jsonref.replace_refs(schema.spec_dict)
fn_definitions = _openapi_to_functions(
resolved_schema, "parameters", _parse_endpoint_spec_openai
schema.spec_dict, "parameters", _parse_endpoint_spec_openai
)
return [{"type": "function", "function": fn} for fn in fn_definitions]

Expand All @@ -48,10 +39,9 @@ def anthropic_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]:
:param schema: The OpenAPI specification to convert.
:returns: A list of dictionaries, each dictionary representing Anthropic function definition.
"""
jsonref_import.check()
resolved_schema = jsonref.replace_refs(schema.spec_dict)

return _openapi_to_functions(
resolved_schema, "input_schema", _parse_endpoint_spec_openai
schema.spec_dict, "input_schema", _parse_endpoint_spec_openai
)


Expand All @@ -64,10 +54,8 @@ def cohere_converter(schema: OpenAPISpecification) -> List[Dict[str, Any]]:
:param schema: The OpenAPI specification to convert.
:returns: A list of dictionaries, each representing a Cohere style function definition.
"""
jsonref_import.check()
resolved_schema = jsonref.replace_refs(schema.spec_dict)
return _openapi_to_functions(
resolved_schema, "not important for cohere", _parse_endpoint_spec_cohere
schema.spec_dict,"not important for cohere",_parse_endpoint_spec_cohere
)


Expand Down
9 changes: 8 additions & 1 deletion haystack_experimental/components/tools/openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@

import requests
import yaml
from haystack.lazy_imports import LazyImport

with LazyImport("Run 'pip install jsonref'") as jsonref_import:
# pylint: disable=import-error
import jsonref


VALID_HTTP_METHODS = [
"get",
Expand Down Expand Up @@ -149,6 +155,7 @@ def __init__(self, spec_dict: Dict[str, Any]):
:param spec_dict: The OpenAPI specification as a dictionary.
"""
jsonref_import.check()
if not isinstance(spec_dict, Dict):
raise ValueError(
f"Invalid OpenAPI specification, expected a dictionary: {spec_dict}"
Expand All @@ -159,7 +166,7 @@ def __init__(self, spec_dict: Dict[str, Any]):
"Invalid OpenAPI specification format. See https://swagger.io/specification/ for details.",
spec_dict,
)
self.spec_dict = spec_dict
self.spec_dict = jsonref.replace_refs(spec_dict)

@classmethod
def from_str(cls, content: str) -> "OpenAPISpecification":
Expand Down
15 changes: 9 additions & 6 deletions test/test_files/yaml/openapi_greeting_service.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@ paths:
post:
operationId: greet
parameters:
- name: name
in: path
required: true
schema:
type: string
- $ref: '#/components/parameters/NameParameter'
requestBody:
required: true
content:
Expand Down Expand Up @@ -238,6 +234,13 @@ components:
tokenUrl: https://example.com/oauth/token
scopes:
read:greet: Read access to greeting service
parameters:
NameParameter:
name: name
in: path
required: true
schema:
type: string

schemas:
GreetBody:
Expand Down Expand Up @@ -269,4 +272,4 @@ components:
type: object
properties:
detail:
type: string
type: string

0 comments on commit 3258230

Please sign in to comment.