From c8d93f11f83fb3907511b0446bd9e8008090ed6b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 20 Aug 2024 21:02:54 +0200 Subject: [PATCH] Anthropic allows multiple system messages, simplify --- .../anthropic/chat/chat_generator.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 5492abf01..8390405c1 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -179,20 +179,20 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, f"Model parameters {disallowed_params} are not allowed and will be ignored. " f"Allowed parameters are {self.ALLOWED_PARAMS}." ) - - # adapt ChatMessage(s) to the format expected by the Anthropic API - anthropic_formatted_messages = self._convert_to_anthropic_format(messages) - - # system message provided by the user overrides the system message from the self.generation_kwargs - system = [anthropic_formatted_messages[0]] if messages and messages[0].is_from(ChatRole.SYSTEM) else None - if system: - anthropic_formatted_messages = anthropic_formatted_messages[1:] + system_messages: List[ChatMessage] = [msg for msg in messages if msg.is_from(ChatRole.SYSTEM)] + non_system_messages: List[ChatMessage] = [msg for msg in messages if not msg.is_from(ChatRole.SYSTEM)] + system_messages_formatted: List[Dict[str, Any]] = ( + self._convert_to_anthropic_format(system_messages) if system_messages else [] + ) + messages_formatted: List[Dict[str, Any]] = ( + self._convert_to_anthropic_format(non_system_messages) if non_system_messages else [] + ) response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create( max_tokens=filtered_generation_kwargs.pop("max_tokens", 512), - system=system if system else filtered_generation_kwargs.pop("system", ""), + system=system_messages_formatted or filtered_generation_kwargs.pop("system", ""), model=self.model, - messages=anthropic_formatted_messages, + messages=messages_formatted, stream=self.streaming_callback is not None, **filtered_generation_kwargs, ) @@ -261,14 +261,15 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict anthropic_formatted_messages = [] for m in messages: message_dict = dataclasses.asdict(m) - filtered_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} if m.is_from(ChatRole.SYSTEM): # system messages need to be in the format expected by the Anthropic API - filtered_message.pop("role") - filtered_message["type"] = "text" - filtered_message["text"] = filtered_message.pop("content") - filtered_message.update(m.meta or {}) - anthropic_formatted_messages.append(filtered_message) + # remove role and content from the message dict, add type and text + formatted_message.pop("role") + formatted_message["type"] = "text" + formatted_message["text"] = formatted_message.pop("content") + formatted_message.update(m.meta or {}) + anthropic_formatted_messages.append(formatted_message) return anthropic_formatted_messages def _connect_chunks(