Showing 3 changed files with 206 additions and 0 deletions.
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from haystack_experimental.components.tools.openapi.openapi import OpenAPITool

__all__ = ["OpenAPITool"]
85 changes: 85 additions & 0 deletions
haystack_experimental/components/tools/openapi/generator_factory.py

@@ -0,0 +1,85 @@
import importlib
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Tuple


class LLMProvider(Enum):
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    COHERE = "cohere"


PROVIDER_DETAILS = {
    LLMProvider.OPENAI: {
        "class_path": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
        "patterns": [re.compile(r"^gpt.*")],
    },
    LLMProvider.ANTHROPIC: {
        "class_path": "haystack_integrations.components.generators.anthropic.AnthropicChatGenerator",
        "patterns": [re.compile(r"^claude.*")],
    },
    LLMProvider.COHERE: {
        "class_path": "haystack_integrations.components.generators.cohere.CohereChatGenerator",
        "patterns": [re.compile(r"^command-r.*")],
    },
}


def load_class(full_class_path: str):
    """
    Load a class from a string representation of its path, e.g. "module.submodule.class_name".
    """
    module_path, _, class_name = full_class_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


@dataclass
class LLMIdentifier:
    provider: LLMProvider
    model_name: str

    def __post_init__(self):
        if not isinstance(self.provider, LLMProvider):
            raise ValueError(f"Invalid provider: {self.provider}")

        if not isinstance(self.model_name, str):
            raise ValueError(f"Model name must be a string: {self.model_name}")

        details = PROVIDER_DETAILS.get(self.provider)
        if not details or not any(
            pattern.match(self.model_name) for pattern in details["patterns"]
        ):
            raise ValueError(
                f"Invalid combination of provider {self.provider} and model name {self.model_name}"
            )


def create_generator(
    model_name: str, provider: Optional[str] = None, **model_kwargs
) -> Tuple[LLMIdentifier, Any]:
    """
    Create a ChatGenerator instance based on the model name and, optionally, an explicit provider.
    """
    if provider:
        try:
            # Enum subscripting looks up members by *name*, and the member names
            # are uppercase, so normalize with .upper() (not .lower()).
            provider_enum = LLMProvider[provider.upper()]
        except KeyError:
            raise ValueError(f"Invalid provider: {provider}")
    else:
        # No explicit provider: infer it from the model name patterns.
        provider_enum = None
        for prov, details in PROVIDER_DETAILS.items():
            if any(pattern.match(model_name) for pattern in details["patterns"]):
                provider_enum = prov
                break

    if provider_enum is None:
        raise ValueError(f"Could not infer provider for model name: {model_name}")

    llm_identifier = LLMIdentifier(provider=provider_enum, model_name=model_name)
    class_path = PROVIDER_DETAILS[llm_identifier.provider]["class_path"]
    return llm_identifier, load_class(class_path)(
        model=llm_identifier.model_name, **model_kwargs
    )
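For illustration, here is a minimal usage sketch of the factory, assuming `haystack-ai` is installed and an `OPENAI_API_KEY` environment variable is set (the model names below are only examples):

```python
from haystack_experimental.components.tools.openapi.generator_factory import (
    LLMProvider,
    create_generator,
)

# The provider is inferred from the model name: "gpt-3.5-turbo" matches
# the r"^gpt.*" pattern, so OpenAIChatGenerator is loaded lazily via load_class.
llm_id, generator = create_generator("gpt-3.5-turbo")
assert llm_id.provider is LLMProvider.OPENAI

# The provider can also be named explicitly; any extra kwargs are forwarded
# to the generator's constructor.
llm_id, generator = create_generator("gpt-4", provider="openai")
```

Because `load_class` imports lazily, the Anthropic and Cohere integrations only need to be installed when a matching model is actually requested.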
118 changes: 118 additions & 0 deletions
haystack_experimental/components/tools/openapi/openapi.py

@@ -1,3 +1,121 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ChatRole

from haystack_experimental.components.tools.openapi.generator_factory import (
    create_generator,
)
from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient

logger = logging.getLogger(__name__)


@component
class OpenAPITool:
    """
    The OpenAPITool calls an OpenAPI service using payloads generated by a chat generator from human instructions.

    Here is an example of how to use the OpenAPITool component to scrape a URL using the FireCrawl API:

    ```python
    from haystack_experimental.components.tools.openapi import OpenAPITool
    from haystack.dataclasses import ChatMessage

    tool = OpenAPITool(model="gpt-3.5-turbo",
                       tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json",
                       tool_credentials="<your-tool-token>")

    results = tool.run(messages=[ChatMessage.from_user("Scrape URL: https://news.ycombinator.com/")])
    print(results)
    ```

    Similarly, you can use the OpenAPITool component with any OpenAPI service/tool by providing its OpenAPI
    specification and credentials.
    """
    def __init__(
        self,
        model: str,
        tool_spec: Optional[Union[str, Path]] = None,
        tool_credentials: Optional[Union[str, Dict[str, Any]]] = None,
    ):
        self.llm_id, self.chat_generator = create_generator(model)
        self.config_openapi = (
            ClientConfiguration(
                openapi_spec=tool_spec,
                credentials=tool_credentials,
                llm_provider=self.llm_id.provider.value,
            )
            if tool_spec
            else None
        )

    @component.output_types(service_response=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        fc_generator_kwargs: Optional[Dict[str, Any]] = None,
        tool_spec: Optional[Union[str, Path, Dict[str, Any]]] = None,
        tool_credentials: Optional[Union[dict, str]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        """
        Invokes the underlying OpenAPI service/tool with the function-calling payload generated by the chat generator.

        :param messages: List of ChatMessages used to generate the function-calling payload (e.g. human instructions).
        :param fc_generator_kwargs: Additional arguments for the function-calling payload generation.
        :param tool_spec: OpenAPI specification for the tool/service.
        :param tool_credentials: Credentials for the tool/service.
        :returns: a dictionary containing the service response with the following key:
            - `service_response`: List of ChatMessages containing the service response.
        """
        last_message = messages[-1]
        if not last_message.is_from(ChatRole.USER):
            raise ValueError(f"{last_message} is not from the user")
        if not last_message.content:
            raise ValueError("Function calling instruction message content is empty.")

        # Build a new ClientConfiguration if a runtime tool_spec is provided
        config_openapi = (
            ClientConfiguration(
                openapi_spec=tool_spec,
                credentials=tool_credentials,
                llm_provider=self.llm_id.provider.value,
            )
            if tool_spec
            else self.config_openapi
        )

        if not config_openapi:
            raise ValueError(
                "OpenAPI specification not provided. Please provide an OpenAPI specification either at "
                "initialization or at run time."
            )
        # Merge fc_generator_kwargs: the tool definitions come from the OpenAPI spec,
        # any other kwargs are passed in by the user.
        fc_generator_kwargs = {
            "tools": config_openapi.get_tools_definitions(),
            **(fc_generator_kwargs or {}),
        }

        # Generate the function-calling payload with the chat generator.
        logger.debug(
            f"Invoking chat generator with {last_message.content} to generate function calling payload."
        )
        # Pass generation kwargs by keyword so they are not mistaken for another
        # positional parameter (e.g. OpenAIChatGenerator's streaming_callback).
        fc_payload = self.chat_generator.run(
            messages=messages, generation_kwargs=fc_generator_kwargs
        )
        openapi_service = OpenAPIServiceClient(config_openapi)
        try:
            invocation_payload = json.loads(fc_payload["replies"][0].content)
            logger.debug(f"Invoking tool with {invocation_payload}")
            service_response = openapi_service.invoke(invocation_payload)
        except Exception as e:
            logger.error(f"Error invoking OpenAPI endpoint. Error: {e}")
            service_response = {"error": str(e)}
        response_messages = [ChatMessage.from_user(json.dumps(service_response))]
        return {"service_response": response_messages}
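Building on the docstring example, here is a hedged sketch of the runtime-override path in `run`: the spec and credentials can be omitted at construction time and supplied per call instead (the spec URL is the FireCrawl one from the docstring; the token is a placeholder):

```python
from haystack.dataclasses import ChatMessage

from haystack_experimental.components.tools.openapi import OpenAPITool

# No tool_spec at init, so self.config_openapi stays None ...
tool = OpenAPITool(model="gpt-3.5-turbo")

# ... and the spec and credentials are passed at run time instead,
# exercising the runtime ClientConfiguration branch above.
result = tool.run(
    messages=[ChatMessage.from_user("Scrape URL: https://news.ycombinator.com/")],
    tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json",
    tool_credentials="<your-tool-token>",
)
print(result["service_response"])
```

If neither the constructor nor `run` receives a spec, the component raises the "OpenAPI specification not provided" error shown above.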