Skip to content

Commit

Permalink
PR feedback - details
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Jun 18, 2024
1 parent df71f3a commit bfdcf94
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion haystack_experimental/components/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

from .openai.function_caller import OpenAIFunctionCaller

_all_ = ["OpenAIFunctionCaller"]
_all_ = ["OpenAIFunctionCaller"]
22 changes: 11 additions & 11 deletions haystack_experimental/components/tools/openapi/_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def send_request(request: Dict[str, Any]) -> Dict[str, Any]:


# Authentication strategies
def create_api_key_auth_function(api_key: str):
def create_api_key_auth_function(api_key: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]:
"""
Create a function that applies the API key authentication strategy to a given request.
Expand All @@ -78,7 +78,7 @@ def create_api_key_auth_function(api_key: str):
at the schema specified location.
"""

def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]):
def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None:
"""
Apply the API key authentication strategy to the given request.
Expand All @@ -100,7 +100,7 @@ def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]):
return apply_auth


def create_http_auth_function(token: str):
def create_http_auth_function(token: str) -> Callable[[Dict[str, Any], Dict[str, Any]], None]:
"""
Create a function that applies the http authentication strategy to a given request.
Expand All @@ -109,7 +109,7 @@ def create_http_auth_function(token: str):
at the schema specified location.
"""

def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]):
def apply_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]) -> None:
"""
Apply the HTTP authentication strategy to the given request.
Expand Down Expand Up @@ -165,7 +165,7 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments
if is_valid_http_url(openapi_spec):
self.openapi_spec = OpenAPISpecification.from_url(openapi_spec)
else:
self.openapi_spec = OpenAPISpecification._from_str(openapi_spec)
self.openapi_spec = OpenAPISpecification.from_str(openapi_spec)
else:
raise ValueError(
"Invalid OpenAPI specification format. Expected file path or dictionary."
Expand Down Expand Up @@ -203,11 +203,11 @@ def get_tools_definitions(self) -> List[Dict[str, Any]]:
provider_to_converter = defaultdict(
lambda: openai_converter,
{
LLMProvider.ANTHROPIC.value: anthropic_converter,
LLMProvider.COHERE.value: cohere_converter,
LLMProvider.ANTHROPIC: anthropic_converter,
LLMProvider.COHERE: cohere_converter,
}
)
converter = provider_to_converter[self.llm_provider.value]
converter = provider_to_converter[self.llm_provider]
return converter(self.openapi_spec)

def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
Expand All @@ -220,11 +220,11 @@ def get_payload_extractor(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
provider_to_arguments_field_name = defaultdict(
lambda: "arguments",
{
LLMProvider.ANTHROPIC.value: "input",
LLMProvider.COHERE.value: "parameters",
LLMProvider.ANTHROPIC: "input",
LLMProvider.COHERE: "parameters",
}
)
arguments_field_name = provider_to_arguments_field_name[self.llm_provider.value]
arguments_field_name = provider_to_arguments_field_name[self.llm_provider]
return create_function_payload_extractor(arguments_field_name)

def _create_authentication_from_string(
Expand Down
6 changes: 3 additions & 3 deletions haystack_experimental/components/tools/openapi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(self, spec_dict: Dict[str, Any]):
self.spec_dict = spec_dict

@classmethod
def _from_str(cls, content: str) -> "OpenAPISpecification":
def from_str(cls, content: str) -> "OpenAPISpecification":
"""
Create an OpenAPISpecification instance from a string.
Expand Down Expand Up @@ -165,7 +165,7 @@ def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification":
"""
with open(spec_file, encoding="utf-8") as file:
content = file.read()
return cls._from_str(content)
return cls.from_str(content)

@classmethod
def from_url(cls, url: str) -> "OpenAPISpecification":
Expand All @@ -184,7 +184,7 @@ def from_url(cls, url: str) -> "OpenAPISpecification":
raise ConnectionError(
f"Failed to fetch the specification from URL: {url}. {e!s}"
) from e
return cls._from_str(content)
return cls.from_str(content)

def find_operation_by_id(
self, op_id: str, method: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion test/components/tools/openapi/test_openapi_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_initialized_from_string(self):
'200':
description: Successful response
"""
openapi_spec = OpenAPISpecification._from_str(content)
openapi_spec = OpenAPISpecification.from_str(content)
assert openapi_spec.spec_dict == {
"openapi": "3.0.0",
"info": {"title": "Test API", "version": "1.0.0"},
Expand Down

0 comments on commit bfdcf94

Please sign in to comment.