Skip to content

Commit

Permalink
Add OpenAPITool initial impl
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Jun 5, 2024
1 parent 83e8f88 commit edf18e3
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 0 deletions.
3 changes: 3 additions & 0 deletions haystack_experimental/components/tools/openapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from haystack_experimental.components.tools.openapi.openapi import OpenAPITool

__all__ = ["OpenAPITool"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import importlib
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Tuple


class LLMProvider(Enum):
    """
    Supported LLM chat-generator providers.

    Member values are the lowercase provider identifiers used as the
    `llm_provider` string elsewhere (e.g. passed to ClientConfiguration).
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    COHERE = "cohere"


# Per-provider wiring used by LLMIdentifier validation and create_generator:
#   - "class_path": dotted import path of the provider's ChatGenerator class,
#     resolved at call time via load_class (so the module is only imported when used)
#   - "patterns": compiled regexes a model name must match to belong to the provider;
#     also used to infer the provider from a bare model name
PROVIDER_DETAILS = {
    LLMProvider.OPENAI: {
        "class_path": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
        "patterns": [re.compile(r"^gpt.*")],
    },
    LLMProvider.ANTHROPIC: {
        "class_path": "haystack_integrations.components.generators.anthropic.AnthropicChatGenerator",
        "patterns": [re.compile(r"^claude.*")],
    },
    LLMProvider.COHERE: {
        "class_path": "haystack_integrations.components.generators.cohere.CohereChatGenerator",
        "patterns": [re.compile(r"^command-r.*")],
    },
}


def load_class(full_class_path: str):
    """
    Resolve a dotted path such as "module.submodule.ClassName" to the named object.

    :param full_class_path: Fully qualified dotted path; everything before the
        last dot is imported as a module, the final segment is looked up on it.
    :returns: The attribute (typically a class) found at the given path.
    """
    module_path, attribute_name = full_class_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attribute_name)


@dataclass
class LLMIdentifier:
    """
    A provider/model-name pair validated against PROVIDER_DETAILS on construction.

    Raises ValueError if the provider is not an LLMProvider member, the model
    name is not a string, or the model name does not match any of the
    provider's registered patterns.
    """

    provider: LLMProvider
    model_name: str

    def __post_init__(self):
        # Type checks first so the pattern lookup below can assume valid inputs.
        if not isinstance(self.provider, LLMProvider):
            raise ValueError(f"Invalid provider: {self.provider}")
        if not isinstance(self.model_name, str):
            raise ValueError(f"Model name must be a string: {self.model_name}")

        details = PROVIDER_DETAILS.get(self.provider)
        name_matches = details is not None and any(
            pattern.match(self.model_name) for pattern in details["patterns"]
        )
        if not name_matches:
            raise ValueError(
                f"Invalid combination of provider {self.provider} and model name {self.model_name}"
            )


def create_generator(
    model_name: str, provider: Optional[str] = None, **model_kwargs
) -> Tuple[LLMIdentifier, Any]:
    """
    Create a ChatGenerator instance based on the model name and provider.

    :param model_name: Name of the model, e.g. "gpt-3.5-turbo".
    :param provider: Optional provider name (case-insensitive, e.g. "openai").
        When omitted, the provider is inferred from the model name via the
        patterns in PROVIDER_DETAILS.
    :param model_kwargs: Extra keyword arguments forwarded to the generator's constructor.
    :returns: Tuple of the validated LLMIdentifier and the instantiated generator.
    :raises ValueError: If the provider is unknown, or cannot be inferred from
        the model name, or the provider/model combination is invalid.
    """
    if provider:
        try:
            # Look up by enum *value* ("openai"), not by member name: member
            # names are upper-case (OPENAI), so LLMProvider[provider.lower()]
            # would raise KeyError for every input, including valid providers.
            provider_enum = LLMProvider(provider.lower())
        except ValueError:
            raise ValueError(f"Invalid provider: {provider}") from None
    else:
        # Infer the provider from the first pattern set the model name matches.
        provider_enum = None
        for prov, details in PROVIDER_DETAILS.items():
            if any(pattern.match(model_name) for pattern in details["patterns"]):
                provider_enum = prov
                break

    if provider_enum is None:
        raise ValueError(f"Could not infer provider for model name: {model_name}")

    # LLMIdentifier re-validates the provider/model combination in __post_init__.
    llm_identifier = LLMIdentifier(provider=provider_enum, model_name=model_name)
    class_path = PROVIDER_DETAILS[llm_identifier.provider]["class_path"]
    # The generator class is imported lazily so optional integrations are only
    # required when actually selected.
    return llm_identifier, load_class(class_path)(
        model=llm_identifier.model_name, **model_kwargs
    )
118 changes: 118 additions & 0 deletions haystack_experimental/components/tools/openapi/openapi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,121 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ChatRole

from haystack_experimental.components.tools.openapi.generator_factory import (
create_generator,
)
from haystack_experimental.util.openapi import ClientConfiguration, OpenAPIServiceClient

logger = logging.getLogger(__name__)


@component
class OpenAPITool:
    """
    The OpenAPITool calls an OpenAPI service using payloads generated by the chat generator from human instructions.

    Here is an example of how to use the OpenAPITool component to scrape a URL using the FireCrawl API:
    ```python
    from haystack_experimental.components.tools.openapi import OpenAPITool
    from haystack.dataclasses import ChatMessage

    # The LLM provider is inferred from the model name; __init__ does not take
    # an llm_provider argument.
    tool = OpenAPITool(model="gpt-3.5-turbo",
                       tool_spec="https://raw.githubusercontent.com/mendableai/firecrawl/main/apps/api/openapi.json",
                       tool_credentials="<your-tool-token>")

    results = tool.run(messages=[ChatMessage.from_user("Scrape URL: https://news.ycombinator.com/")])
    print(results)
    ```
    Similarly, you can use the OpenAPITool component to use any OpenAPI service/tool by providing the OpenAPI
    specification and credentials.
    """

    def __init__(
        self,
        model: str,
        tool_spec: Optional[Union[str, Path]] = None,
        tool_credentials: Optional[Union[str, Dict[str, Any]]] = None,
    ):
        """
        Initialize the OpenAPITool component.

        :param model: Model name; used both to infer the LLM provider and as the
            model of the chat generator created via create_generator.
        :param tool_spec: OpenAPI specification (URL or local path) of the
            tool/service. May be omitted here and supplied to run() instead.
        :param tool_credentials: Credentials for the tool/service.
        """
        # llm_id carries the inferred provider + model name; chat_generator is the
        # provider-specific ChatGenerator instance.
        self.llm_id, self.chat_generator = create_generator(model)
        # Build the client configuration eagerly only when a spec is given;
        # otherwise a spec must be passed on every run() call.
        self.config_openapi = (
            ClientConfiguration(
                openapi_spec=tool_spec,
                credentials=tool_credentials,
                llm_provider=self.llm_id.provider.value,
            )
            if tool_spec
            else None
        )

    @component.output_types(service_response=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        fc_generator_kwargs: Optional[Dict[str, Any]] = None,
        tool_spec: Optional[Union[str, Path, Dict[str, Any]]] = None,
        tool_credentials: Optional[Union[dict, str]] = None,
    ) -> Dict[str, List[ChatMessage]]:
        """
        Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator.

        :param messages: List of ChatMessages to generate function calling payload (e.g. human instructions). The
            last message must be a non-empty user message.
        :param fc_generator_kwargs: Additional arguments for the function calling payload generation process.
        :param tool_spec: OpenAPI specification for the tool/service; overrides the one given at initialization.
        :param tool_credentials: Credentials for the tool/service.
        :returns: a dictionary containing the service response with the following key:
            - `service_response`: List of ChatMessages containing the service response (or a JSON-encoded
              `{"error": ...}` payload if the invocation failed).
        :raises ValueError: If the last message is not a non-empty user message, or no OpenAPI
            specification was provided either at initialization or at runtime.
        """
        # The function-calling instruction is taken from the last message only.
        last_message = messages[-1]
        if not last_message.is_from(ChatRole.USER):
            raise ValueError(f"{last_message} not from the user")
        if not last_message.content:
            raise ValueError("Function calling instruction message content is empty.")

        # build a new ClientConfiguration if a runtime tool_spec is provided;
        # otherwise fall back to the one built in __init__ (which may be None).
        config_openapi = (
            ClientConfiguration(
                openapi_spec=tool_spec,
                credentials=tool_credentials,
                llm_provider=self.llm_id.provider.value,
            )
            if tool_spec
            else self.config_openapi
        )

        if not config_openapi:
            raise ValueError(
                "OpenAPI specification not provided. Please provide an OpenAPI specification either at initialization "
                "or during runtime."
            )

        # merge fc_generator_kwargs, tools definitions comes from the OpenAPI spec, other kwargs are passed by the user
        # (user-supplied keys win over the derived "tools" entry because they are spread last)
        fc_generator_kwargs = {
            "tools": config_openapi.get_tools_definitions(),
            **(fc_generator_kwargs or {}),
        }

        # generate function calling payload with the chat generator
        logger.debug(
            f"Invoking chat generator with {last_message.content} to generate function calling payload."
        )
        # NOTE(review): fc_generator_kwargs is passed positionally as run()'s second
        # argument — confirm the generator's second positional parameter accepts
        # generation kwargs (if it is e.g. a streaming callback, this must be passed
        # as a keyword argument instead).
        fc_payload = self.chat_generator.run(messages, fc_generator_kwargs)

        openapi_service = OpenAPIServiceClient(config_openapi)
        try:
            # The first generator reply is expected to be a JSON function-calling payload.
            invocation_payload = json.loads(fc_payload["replies"][0].content)
            logger.debug(f"Invoking tool with {invocation_payload}")
            service_response = openapi_service.invoke(invocation_payload)
        except Exception as e:
            # Best-effort: invocation errors are surfaced to the caller as response
            # data rather than raised.
            logger.error(f"Error invoking OpenAPI endpoint. Error: {e}")
            service_response = {"error": str(e)}
        # NOTE(review): the service response is wrapped as a *user*-role message —
        # confirm this role is intended for downstream consumers.
        response_messages = [ChatMessage.from_user(json.dumps(service_response))]
        return {"service_response": response_messages}

0 comments on commit edf18e3

Please sign in to comment.