Refactoring step 6 - simplify generator factory
vblagoje committed Jun 5, 2024
1 parent d9ebb8c commit bdc180f
Showing 3 changed files with 119 additions and 89 deletions.
161 changes: 87 additions & 74 deletions haystack_experimental/components/tools/openapi/generator_factory.py
@@ -5,88 +5,101 @@
 import importlib
 import re
 from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple


-class LLMProvider(Enum):
-    """
-    Enum for different LLM providers
-    """
-    OPENAI = "openai"
-    ANTHROPIC = "anthropic"
-    COHERE = "cohere"
-
-
-PROVIDER_DETAILS: Dict[LLMProvider, Dict[str, Any]] = {
-    LLMProvider.OPENAI: {
-        "class_path": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
-        "patterns": [re.compile(r"^gpt.*")],
-    },
-    LLMProvider.ANTHROPIC: {
-        "class_path": "haystack_integrations.components.generators.anthropic.AnthropicChatGenerator",
-        "patterns": [re.compile(r"^claude.*")],
-    },
-    LLMProvider.COHERE: {
-        "class_path": "haystack_integrations.components.generators.cohere.CohereChatGenerator",
-        "patterns": [re.compile(r"^command-r.*")],
-    },
-}
-
-
-def load_class(full_class_path: str):
+@dataclass
+class ChatGeneratorDescriptor:
     """
-    Load a class from a string representation of its path e.g. "module.submodule.class_name"
+    Dataclass to describe a Chat Generator
     """
-    module_path, _, class_name = full_class_path.rpartition(".")
-    module = importlib.import_module(module_path)
-    return getattr(module, class_name)
-
-
-@dataclass
-class LLMIdentifier:
-    """ Dataclass to hold the LLM provider and model name"""
-    provider: LLMProvider
+    class_path: str
+    patterns: List[re.Pattern]
+    name: str
     model_name: str

-    def __post_init__(self):
-        if not isinstance(self.provider, LLMProvider):
-            raise ValueError(f"Invalid provider: {self.provider}")
-
-        if not isinstance(self.model_name, str):
-            raise ValueError(f"Model name must be a string: {self.model_name}")
-
-        details = PROVIDER_DETAILS.get(self.provider)
-        if not details or not any(
-            pattern.match(self.model_name) for pattern in details["patterns"]
-        ):
-            raise ValueError(
-                f"Invalid combination of provider {self.provider} and model name {self.model_name}"
-            )
-
-
-def create_generator(
-    model_name: str, provider: Optional[str] = None, **model_kwargs
-) -> Tuple[LLMIdentifier, Any]:
+
+class ChatGeneratorDescriptorManager:
     """
-    Create ChatGenerator instance based on the model name and provider.
+    Class to manage Chat Generator Descriptors
     """
-    provider_enum = None
-    if provider:
-        if provider.lower() not in LLMProvider.__members__:
-            raise ValueError(f"Invalid provider: {provider}")
-        provider_enum = LLMProvider[provider.lower()]
-    else:
-        for prov, details in PROVIDER_DETAILS.items():
-            if any(pattern.match(model_name) for pattern in details["patterns"]):
-                provider_enum = prov
-                break
-
-    if provider_enum is None:
-        raise ValueError(f"Could not infer provider for model name: {model_name}")
-
-    llm_identifier = LLMIdentifier(provider=provider_enum, model_name=model_name)
-    class_path = PROVIDER_DETAILS[llm_identifier.provider]["class_path"]
-    return llm_identifier, load_class(class_path)(
-        model=llm_identifier.model_name, **model_kwargs
-    )
+    def __init__(self):
+        self._descriptors: Dict[str, ChatGeneratorDescriptor] = {}
+        self._register_default_descriptors()
+
+    def _register_default_descriptors(self):
+        """
+        Register default Chat Generator Descriptors.
+        """
+        default_descriptors = [
+            ChatGeneratorDescriptor(
+                class_path="haystack.components.generators.chat.openai.OpenAIChatGenerator",
+                patterns=[re.compile(r"^gpt.*")],
+                name="openai",
+                model_name="gpt-3.5-turbo",
+            ),
+            ChatGeneratorDescriptor(
+                class_path="haystack_integrations.components.generators.anthropic.AnthropicChatGenerator",
+                patterns=[re.compile(r"^claude.*")],
+                name="anthropic",
+                model_name="claude-1",
+            ),
+            ChatGeneratorDescriptor(
+                class_path="haystack_integrations.components.generators.cohere.CohereChatGenerator",
+                patterns=[re.compile(r"^command-r.*")],
+                name="cohere",
+                model_name="command-r",
+            ),
+        ]
+
+        for descriptor in default_descriptors:
+            self.register_descriptor(descriptor)
+
+    def _load_class(self, full_class_path: str):
+        """
+        Load a class from a string representation of its path e.g. "module.submodule.class_name"
+        """
+        module_path, _, class_name = full_class_path.rpartition(".")
+        module = importlib.import_module(module_path)
+        return getattr(module, class_name)
+
+    def register_descriptor(self, descriptor: ChatGeneratorDescriptor):
+        """
+        Register a new Chat Generator Descriptor.
+        """
+        if descriptor.name in self._descriptors:
+            raise ValueError(f"Descriptor {descriptor.name} already exists.")
+
+        self._descriptors[descriptor.name] = descriptor
+
+    def _infer_descriptor(self, model_name: str) -> Optional[ChatGeneratorDescriptor]:
+        """
+        Infer the descriptor based on the model name.
+        """
+        for descriptor in self._descriptors.values():
+            if any(pattern.match(model_name) for pattern in descriptor.patterns):
+                return descriptor
+        return None
+
+    def create_generator(
+        self, model_name: str, descriptor_name: Optional[str] = None, **model_kwargs
+    ) -> Tuple[ChatGeneratorDescriptor, Any]:
+        """
+        Create ChatGenerator instance based on the model name and descriptor.
+        """
+        if descriptor_name:
+            descriptor = self._descriptors.get(descriptor_name)
+            if not descriptor:
+                raise ValueError(f"Invalid descriptor name: {descriptor_name}")
+        else:
+            descriptor = self._infer_descriptor(model_name)
+            if not descriptor:
+                raise ValueError(
+                    f"Could not infer descriptor for model name: {model_name}"
+                )

+        return descriptor, self._load_class(descriptor.class_path)(
+            model=model_name, **model_kwargs
+        )
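Note: the refactor replaces the module-level create_generator function and LLMProvider enum with a registry object. A minimal usage sketch of the new API, not part of the commit: it assumes the relevant generator packages are importable and their API keys are exported, and the class path, pattern, and names in the last call are hypothetical placeholders.

import re

from haystack_experimental.components.tools.openapi.generator_factory import (
    ChatGeneratorDescriptor,
    ChatGeneratorDescriptorManager,
)

manager = ChatGeneratorDescriptorManager()

# Inference path: "gpt-4" matches the ^gpt.* pattern of the default
# "openai" descriptor, so no descriptor_name is needed.
descriptor, generator = manager.create_generator(model_name="gpt-4")
print(descriptor.name)  # -> "openai"

# Explicit path: select a registered descriptor by name.
descriptor, generator = manager.create_generator(
    model_name="claude-1", descriptor_name="anthropic"
)

# Extension point: register a descriptor the defaults do not cover.
# Class path, pattern, and names below are hypothetical.
manager.register_descriptor(
    ChatGeneratorDescriptor(
        class_path="my_package.generators.MyChatGenerator",
        patterns=[re.compile(r"^my-model.*")],
        name="my-provider",
        model_name="my-model-v1",
    )
)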
41 changes: 29 additions & 12 deletions haystack_experimental/components/tools/openapi/openapi_tool.py
@@ -9,7 +9,9 @@
 from haystack import component, logging
 from haystack.dataclasses import ChatMessage, ChatRole

-from haystack_experimental.components.tools.openapi.generator_factory import create_generator
+from haystack_experimental.components.tools.openapi.generator_factory import (
+    ChatGeneratorDescriptorManager,
+)
 from haystack_experimental.components.tools.openapi.openapi import (
     ClientConfiguration,
     OpenAPIServiceClient,
@@ -48,16 +50,20 @@ def __init__(
         tool_spec: Optional[Union[str, Path]] = None,
         tool_credentials: Optional[Union[str, Dict[str, Any]]] = None,
     ):
-        self.llm_id, self.chat_generator = create_generator(model)
-        self.config_openapi = (
-            ClientConfiguration(
+
+        manager = ChatGeneratorDescriptorManager()
+        self.descriptor, self.chat_generator = manager.create_generator(
+            model_name=model
+        )
+        self.config_openapi: Optional[ClientConfiguration] = None
+        self.open_api_service: Optional[OpenAPIServiceClient] = None
+        if tool_spec:
+            self.config_openapi = ClientConfiguration(
                 openapi_spec=tool_spec,
                 credentials=tool_credentials,
-                llm_provider=self.llm_id.provider.value,
+                llm_provider=self.descriptor.name,
             )
-            if tool_spec
-            else None
-        )
+            self.open_api_service = OpenAPIServiceClient(self.config_openapi)

     @component.output_types(service_response=List[ChatMessage])
     def run(
@@ -88,17 +94,29 @@ def run(
-        config_openapi = (
-            ClientConfiguration(
-                openapi_spec=tool_spec,
-                credentials=tool_credentials,
-                llm_provider=self.llm_id.provider.value,
-            )
-            if tool_spec
-            else self.config_openapi
-        )
-
-        if not config_openapi:
+        openapi_service: Optional[OpenAPIServiceClient] = self.open_api_service
+        if tool_spec:
+            config_openapi = ClientConfiguration(
+                openapi_spec=tool_spec,
+                credentials=tool_credentials,
+                llm_provider=self.descriptor.name,
+            )
+            openapi_service = OpenAPIServiceClient(config_openapi)
+        else:
+            config_openapi = self.config_openapi
+
+        if not openapi_service or not config_openapi:
             raise ValueError(
                 "OpenAPI specification not provided. Please provide an OpenAPI specification either at initialization "
                 "or during runtime."
             )

         # merge fc_generator_kwargs, tools definitions comes from the OpenAPI spec, other kwargs are passed by the user
         fc_generator_kwargs = {
             "tools": config_openapi.get_tools_definitions(),
@@ -110,11 +128,10 @@ def run(
             f"Invoking chat generator with {last_message.content} to generate function calling payload."
         )
         fc_payload = self.chat_generator.run(messages, fc_generator_kwargs)
-
-        openapi_service = OpenAPIServiceClient(config_openapi)
         try:
             invocation_payload = json.loads(fc_payload["replies"][0].content)
             logger.debug(f"Invoking tool with {invocation_payload}")
+            # openapi_service is never None here, ignore mypy error
             service_response = openapi_service.invoke(invocation_payload)
         except Exception as e:  # pylint: disable=broad-exception-caught
             logger.error(f"Error invoking OpenAPI endpoint. Error: {e}")
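Note: the net effect of these changes is that the OpenAPI spec can be supplied either at construction time or per run() call. A hedged sketch of both paths, not part of the commit: the component class name OpenAPITool is an assumption (the class declaration sits outside the visible hunks), run()'s parameter names are read off the hunk bodies, and the model name, spec URL, and credentials are placeholders.

from haystack.dataclasses import ChatMessage

from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool

# Path 1: spec at init; __init__ eagerly builds ClientConfiguration and
# OpenAPIServiceClient, so run() can reuse them.
tool = OpenAPITool(
    model="gpt-3.5-turbo",
    tool_spec="https://example.com/openapi.json",  # placeholder URL
    tool_credentials="placeholder-api-key",
)

# Path 2: no spec at init; config_openapi and open_api_service stay None,
# so run() must receive tool_spec or it raises ValueError.
tool = OpenAPITool(model="gpt-3.5-turbo")
response = tool.run(
    messages=[ChatMessage.from_user("What is the weather in Paris?")],
    tool_spec="https://example.com/openapi.json",
)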
6 changes: 3 additions & 3 deletions
@@ -27,9 +27,9 @@ def _extract_function_invocation(payload: Any) -> Dict[str, Any]:
     )
     return {
         "name": fields_and_values.get("name"),
-        "arguments": json.loads(arguments)
-        if isinstance(arguments, str)
-        else arguments,
+        "arguments": (
+            json.loads(arguments) if isinstance(arguments, str) else arguments
+        ),
     }
     return {}
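Note: this last hunk is a formatting-only change (wrapping the ternary in parentheses, black-style); behavior is unchanged. For illustration, the normalization the ternary performs, extracted into a standalone sketch with a hypothetical helper name:

import json
from typing import Any, Dict, Union

def normalize_arguments(arguments: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    # Function-calling payloads may carry "arguments" either as a
    # JSON-encoded string or as an already-parsed dict; both forms
    # should normalize to the same dict.
    return json.loads(arguments) if isinstance(arguments, str) else arguments

assert normalize_arguments('{"city": "Paris"}') == {"city": "Paris"}
assert normalize_arguments({"city": "Paris"}) == {"city": "Paris"}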
