From 1e30dedfedb06eb75b4d52aae881ac6c41d5d85c Mon Sep 17 00:00:00 2001
From: anakin87
Date: Wed, 18 Dec 2024 16:41:58 +0100
Subject: [PATCH 1/3] progress

---
 .../generators/google_vertex/chat/gemini.py |  6 +--
 .../llama_cpp/chat/chat_generator.py        | 46 ++++++++++---------
 .../llama_cpp/tests/test_chat_generator.py  | 10 ++--
 3 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
index 2309ca718..845e24f5f 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
@@ -280,12 +280,10 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
                 # Remove content from metadata
                 metadata.pop("content", None)
                 if part._raw_part.text != "":
-                    replies.append(ChatMessage.from_assistant(content=part._raw_part.text, meta=metadata))
+                    replies.append(ChatMessage.from_assistant(part._raw_part.text, meta=metadata))
                 elif part.function_call:
                     metadata["function_call"] = part.function_call
-                    new_message = ChatMessage.from_assistant(
-                        content=json.dumps(dict(part.function_call.args)), meta=metadata
-                    )
+                    new_message = ChatMessage.from_assistant(json.dumps(dict(part.function_call.args)), meta=metadata)
                     new_message.name = part.function_call.name
                     replies.append(new_message)
         return replies
diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
index 014dd7169..4c96b5daa 100644
--- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
+++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List, Optional
 
 from haystack import component
-from haystack.dataclasses import ChatMessage, ChatRole
+from haystack.dataclasses import ChatMessage
 from llama_cpp import Llama
 from llama_cpp.llama_tokenizer import LlamaHFTokenizer
 
@@ -21,6 +21,10 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]:
     if message.name:
         formatted_msg["name"] = message.name
 
+    if formatted_msg["role"] == "tool":
+        formatted_msg["name"] = message.tool_call_result.origin.tool_name
+        formatted_msg["content"] = message.tool_call_result.result
+
     return formatted_msg
 
 
@@ -114,26 +118,26 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
         formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]
 
         response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs)
-        replies = [
-            ChatMessage(
-                content=choice["message"]["content"],
-                role=ChatRole[choice["message"]["role"].upper()],
-                name=None,
-                meta={
-                    "response_id": response["id"],
-                    "model": response["model"],
-                    "created": response["created"],
-                    "index": choice["index"],
-                    "finish_reason": choice["finish_reason"],
-                    "usage": response["usage"],
-                },
-            )
-            for choice in response["choices"]
-        ]
-
-        for reply, choice in zip(replies, response["choices"]):
+
+        replies = []
+
+        for choice in response["choices"]:
+            meta = {
+                "response_id": response["id"],
+                "model": response["model"],
response["model"], + "created": response["created"], + "index": choice["index"], + "finish_reason": choice["finish_reason"], + "usage": response["usage"], + } + + name = None tool_calls = choice.get("message", {}).get("tool_calls", []) if tool_calls: - reply.meta["tool_calls"] = tool_calls - reply.name = tool_calls[0]["function"]["name"] if tool_calls else None + meta["tool_calls"] = tool_calls + name = tool_calls[0]["function"]["name"] + + reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta, name=name) + replies.append(reply) + return {"replies": replies} diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py index 0ddd78c4f..87639f684 100644 --- a/integrations/llama_cpp/tests/test_chat_generator.py +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -41,11 +41,11 @@ def test_convert_message_to_llamacpp_format(): assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"} message = ChatMessage.from_function("Function call", "function_name") - assert _convert_message_to_llamacpp_format(message) == { - "role": "function", - "content": "Function call", - "name": "function_name", - } + converted_message = _convert_message_to_llamacpp_format(message) + + assert converted_message["role"] in ("function", "tool") + assert converted_message["name"] == "function_name" + assert converted_message["content"] == "Function call" class TestLlamaCppChatGenerator: From 8ec7433685b3d22db19f683617ed305c1908b1c4 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 17:15:25 +0100 Subject: [PATCH 2/3] remove vertex changes from this PR --- .../components/generators/google_vertex/chat/gemini.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 845e24f5f..2309ca718 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -280,10 +280,12 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: # Remove content from metadata metadata.pop("content", None) if part._raw_part.text != "": - replies.append(ChatMessage.from_assistant(part._raw_part.text, meta=metadata)) + replies.append(ChatMessage.from_assistant(content=part._raw_part.text, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call - new_message = ChatMessage.from_assistant(json.dumps(dict(part.function_call.args)), meta=metadata) + new_message = ChatMessage.from_assistant( + content=json.dumps(dict(part.function_call.args)), meta=metadata + ) new_message.name = part.function_call.name replies.append(new_message) return replies From ebf48e25c58f5c9c881d92b30f32f06ad17a17cd Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 17:24:29 +0100 Subject: [PATCH 3/3] fix --- .../components/generators/llama_cpp/chat/chat_generator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py index 4c96b5daa..d2150f61f 100644 --- 
+++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
@@ -137,7 +137,12 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
                 meta["tool_calls"] = tool_calls
                 name = tool_calls[0]["function"]["name"]
 
-            reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta, name=name)
+            reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta)
+            if name:
+                if hasattr(reply, "_name"):
+                    reply._name = name  # new ChatMessage
+                elif hasattr(reply, "name"):
+                    reply.name = name  # legacy ChatMessage
             replies.append(reply)
 
         return {"replies": replies}
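
Note, not part of the patch series: a minimal runnable sketch of the two behaviors these patches touch, for anyone reviewing them outside the generator. The example question and the tool name "get_weather" are made up, and the converter's import path is inferred from the file layout shown in the diffs above; treat this as an illustration under those assumptions, not as the integration's documented API.

    from haystack.dataclasses import ChatMessage

    from haystack_integrations.components.generators.llama_cpp.chat.chat_generator import (
        _convert_message_to_llamacpp_format,
    )

    # A plain user message converts to the dict shape llama.cpp expects;
    # this mirrors the assertion in test_convert_message_to_llamacpp_format.
    user_msg = ChatMessage.from_user("I have a question")
    assert _convert_message_to_llamacpp_format(user_msg) == {"role": "user", "content": "I have a question"}

    # The PATCH 3/3 shim: attach the tool name to a reply without passing
    # name= to from_assistant, since only some Haystack versions accept it.
    reply = ChatMessage.from_assistant("It is sunny.", meta={"finish_reason": "stop"})
    name = "get_weather"  # hypothetical tool name, for illustration only
    if hasattr(reply, "_name"):
        reply._name = name  # new ChatMessage
    elif hasattr(reply, "name"):
        reply.name = name  # legacy ChatMessage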