diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py
index a0d4c36e..0fb992ab 100644
--- a/haystack_experimental/components/tools/openapi/generator_factory.py
+++ b/haystack_experimental/components/tools/openapi/generator_factory.py
@@ -5,88 +5,101 @@
 import importlib
 import re
 from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 
-class LLMProvider(Enum):
-    """
-    Enum for different LLM providers
-    """
-    OPENAI = "openai"
-    ANTHROPIC = "anthropic"
-    COHERE = "cohere"
-
-
-PROVIDER_DETAILS: Dict[LLMProvider, Dict[str, Any]] = {
-    LLMProvider.OPENAI: {
-        "class_path": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
-        "patterns": [re.compile(r"^gpt.*")],
-    },
-    LLMProvider.ANTHROPIC: {
-        "class_path": "haystack_integrations.components.generators.anthropic.AnthropicChatGenerator",
-        "patterns": [re.compile(r"^claude.*")],
-    },
-    LLMProvider.COHERE: {
-        "class_path": "haystack_integrations.components.generators.cohere.CohereChatGenerator",
-        "patterns": [re.compile(r"^command-r.*")],
-    },
-}
-
-
-def load_class(full_class_path: str):
+@dataclass
+class ChatGeneratorDescriptor:
     """
-    Load a class from a string representation of its path e.g. "module.submodule.class_name"
+    Describes a chat generator: its class path, model-name patterns, provider name, and default model name.
     """
-    module_path, _, class_name = full_class_path.rpartition(".")
-    module = importlib.import_module(module_path)
-    return getattr(module, class_name)
 
-
-@dataclass
-class LLMIdentifier:
-    """Dataclass to hold the LLM provider and model name"""
-
-    provider: LLMProvider
+    class_path: str
+    patterns: List[re.Pattern]
+    name: str
     model_name: str
 
-    def __post_init__(self):
-        if not isinstance(self.provider, LLMProvider):
-            raise ValueError(f"Invalid provider: {self.provider}")
-
-        if not isinstance(self.model_name, str):
-            raise ValueError(f"Model name must be a string: {self.model_name}")
-
-        details = PROVIDER_DETAILS.get(self.provider)
-        if not details or not any(
-            pattern.match(self.model_name) for pattern in details["patterns"]
-        ):
-            raise ValueError(
-                f"Invalid combination of provider {self.provider} and model name {self.model_name}"
-            )
-
-
-def create_generator(
-    model_name: str, provider: Optional[str] = None, **model_kwargs
-) -> Tuple[LLMIdentifier, Any]:
+
+class ChatGeneratorDescriptorManager:
     """
-    Create ChatGenerator instance based on the model name and provider.
+    Manages ChatGeneratorDescriptor registration, lookup, and chat generator creation.
     """
-    provider_enum = None
-    if provider:
-        if provider.lower() not in LLMProvider.__members__:
-            raise ValueError(f"Invalid provider: {provider}")
-        provider_enum = LLMProvider[provider.lower()]
-    else:
-        for prov, details in PROVIDER_DETAILS.items():
-            if any(pattern.match(model_name) for pattern in details["patterns"]):
-                provider_enum = prov
-                break
-
-    if provider_enum is None:
-        raise ValueError(f"Could not infer provider for model name: {model_name}")
-
-    llm_identifier = LLMIdentifier(provider=provider_enum, model_name=model_name)
-    class_path = PROVIDER_DETAILS[llm_identifier.provider]["class_path"]
-    return llm_identifier, load_class(class_path)(
-        model=llm_identifier.model_name, **model_kwargs
-    )
+
+    def __init__(self):
+        self._descriptors: Dict[str, ChatGeneratorDescriptor] = {}
+        self._register_default_descriptors()
+
+    def _register_default_descriptors(self):
+        """
+        Register the default descriptors for the OpenAI, Anthropic, and Cohere chat generators.
+        """
+        default_descriptors = [
+            ChatGeneratorDescriptor(
+                class_path="haystack.components.generators.chat.openai.OpenAIChatGenerator",
+                patterns=[re.compile(r"^gpt.*")],
+                name="openai",
+                model_name="gpt-3.5-turbo",
+            ),
+            ChatGeneratorDescriptor(
+                class_path="haystack_integrations.components.generators.anthropic.AnthropicChatGenerator",
+                patterns=[re.compile(r"^claude.*")],
+                name="anthropic",
+                model_name="claude-1",
+            ),
+            ChatGeneratorDescriptor(
+                class_path="haystack_integrations.components.generators.cohere.CohereChatGenerator",
+                patterns=[re.compile(r"^command-r.*")],
+                name="cohere",
+                model_name="command-r",
+            ),
+        ]
+
+        for descriptor in default_descriptors:
+            self.register_descriptor(descriptor)
+
+    def _load_class(self, full_class_path: str):
+        """
+        Load a class from a string representation of its path, e.g. "module.submodule.ClassName".
+        """
+        module_path, _, class_name = full_class_path.rpartition(".")
+        module = importlib.import_module(module_path)
+        return getattr(module, class_name)
+
+    def register_descriptor(self, descriptor: ChatGeneratorDescriptor):
+        """
+        Register a new ChatGeneratorDescriptor, rejecting duplicate names.
+        """
+        if descriptor.name in self._descriptors:
+            raise ValueError(f"Descriptor {descriptor.name} already exists.")
+
+        self._descriptors[descriptor.name] = descriptor
+
+    def _infer_descriptor(self, model_name: str) -> Optional[ChatGeneratorDescriptor]:
+        """
+        Infer the descriptor by matching the model name against each registered pattern.
+        """
+        for descriptor in self._descriptors.values():
+            if any(pattern.match(model_name) for pattern in descriptor.patterns):
+                return descriptor
+        return None
+
+    def create_generator(
+        self, model_name: str, descriptor_name: Optional[str] = None, **model_kwargs
+    ) -> Tuple[ChatGeneratorDescriptor, Any]:
+        """
+        Create a chat generator instance based on the model name and an optional descriptor name.
+        """
+        if descriptor_name:
+            descriptor = self._descriptors.get(descriptor_name)
+            if not descriptor:
+                raise ValueError(f"Invalid descriptor name: {descriptor_name}")
+        else:
+            descriptor = self._infer_descriptor(model_name)
+            if not descriptor:
+                raise ValueError(
+                    f"Could not infer descriptor for model name: {model_name}"
+                )
+
+        return descriptor, self._load_class(descriptor.class_path)(
+            model=model_name, **model_kwargs
+        )
diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py
index 0b2e1ac7..dc232941 100644
--- a/haystack_experimental/components/tools/openapi/openapi_tool.py
+++ b/haystack_experimental/components/tools/openapi/openapi_tool.py
@@ -9,7 +9,9 @@
 from haystack import component, logging
 from haystack.dataclasses import ChatMessage, ChatRole
 
-from haystack_experimental.components.tools.openapi.generator_factory import create_generator
+from haystack_experimental.components.tools.openapi.generator_factory import (
+    ChatGeneratorDescriptorManager,
+)
 from haystack_experimental.components.tools.openapi.openapi import (
     ClientConfiguration,
     OpenAPIServiceClient,
@@ -48,16 +50,20 @@ def __init__(
         tool_spec: Optional[Union[str, Path]] = None,
         tool_credentials: Optional[Union[str, Dict[str, Any]]] = None,
     ):
-        self.llm_id, self.chat_generator = create_generator(model)
-        self.config_openapi = (
-            ClientConfiguration(
+
+        manager = ChatGeneratorDescriptorManager()
+        self.descriptor, self.chat_generator = manager.create_generator(
+            model_name=model
+        )
+        self.config_openapi: Optional[ClientConfiguration] = None
+        self.open_api_service: Optional[OpenAPIServiceClient] = None
+        if tool_spec:
+            self.config_openapi = ClientConfiguration(
                 openapi_spec=tool_spec,
                 credentials=tool_credentials,
-                llm_provider=self.llm_id.provider.value,
+                llm_provider=self.descriptor.name,
             )
-            if tool_spec
-            else None
-        )
+            self.open_api_service = OpenAPIServiceClient(self.config_openapi)
 
     @component.output_types(service_response=List[ChatMessage])
     def run(
@@ -88,17 +94,20 @@ def run(
-        config_openapi = (
-            ClientConfiguration(
-                openapi_spec=tool_spec,
-                credentials=tool_credentials,
-                llm_provider=self.llm_id.provider.value,
-            )
-            if tool_spec
-            else self.config_openapi
-        )
-        if not config_openapi:
+        openapi_service: Optional[OpenAPIServiceClient] = self.open_api_service
+        if tool_spec:
+            config_openapi = ClientConfiguration(
+                openapi_spec=tool_spec,
+                credentials=tool_credentials,
+                llm_provider=self.descriptor.name,
+            )
+            openapi_service = OpenAPIServiceClient(config_openapi)
+        else:
+            config_openapi = self.config_openapi
+
+        if not openapi_service or not config_openapi:
             raise ValueError(
                 "OpenAPI specification not provided. Please provide an OpenAPI specification either at initialization "
                 "or during runtime."
             )
 
+        # Merge fc_generator_kwargs: tool definitions come from the OpenAPI spec, the remaining kwargs from the user.
         fc_generator_kwargs = {
             "tools": config_openapi.get_tools_definitions(),
@@ -110,11 +119,10 @@ def run(
             f"Invoking chat generator with {last_message.content} to generate function calling payload."
         )
         fc_payload = self.chat_generator.run(messages, fc_generator_kwargs)
-
-        openapi_service = OpenAPIServiceClient(config_openapi)
         try:
             invocation_payload = json.loads(fc_payload["replies"][0].content)
             logger.debug(f"Invoking tool with {invocation_payload}")
+            # The guard above ensures openapi_service is not None at this point.
             service_response = openapi_service.invoke(invocation_payload)
         except Exception as e:  # pylint: disable=broad-exception-caught
             logger.error(f"Error invoking OpenAPI endpoint. Error: {e}")
diff --git a/haystack_experimental/components/tools/openapi/payload_extraction.py b/haystack_experimental/components/tools/openapi/payload_extraction.py
index 157070ae..416841bf 100644
--- a/haystack_experimental/components/tools/openapi/payload_extraction.py
+++ b/haystack_experimental/components/tools/openapi/payload_extraction.py
@@ -27,9 +27,9 @@ def _extract_function_invocation(payload: Any) -> Dict[str, Any]:
         )
         return {
             "name": fields_and_values.get("name"),
-            "arguments": json.loads(arguments)
-            if isinstance(arguments, str)
-            else arguments,
+            "arguments": (
+                json.loads(arguments) if isinstance(arguments, str) else arguments
+            ),
         }
     return {}
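
A minimal usage sketch of the descriptor registry introduced in generator_factory.py, for reviewers. It assumes the default descriptors shipped in this diff and a configured OPENAI_API_KEY for the instantiation calls; the MistralChatGenerator class path, pattern, and names in the last block are hypothetical placeholders for a third-party integration, not part of this change:

import re

from haystack_experimental.components.tools.openapi.generator_factory import (
    ChatGeneratorDescriptor,
    ChatGeneratorDescriptorManager,
)

manager = ChatGeneratorDescriptorManager()

# The descriptor is inferred from the model name via the registered patterns ("^gpt.*" -> "openai").
descriptor, generator = manager.create_generator(model_name="gpt-3.5-turbo")
assert descriptor.name == "openai"

# Passing descriptor_name explicitly bypasses pattern matching.
descriptor, generator = manager.create_generator(model_name="gpt-4", descriptor_name="openai")

# Custom descriptors can be registered at runtime; _load_class only imports the
# generator class when create_generator is called, so registration itself is cheap.
manager.register_descriptor(
    ChatGeneratorDescriptor(
        class_path="my_package.generators.MistralChatGenerator",  # hypothetical class path
        patterns=[re.compile(r"^mistral.*")],
        name="mistral",
        model_name="mistral-small",
    )
)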