From a665f1f65be1f169f21bc9d3484051f1ad1c3636 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Fri, 2 Aug 2024 16:47:20 +0200
Subject: [PATCH] introduce utility function (#939)

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
---
 .../llama_cpp/chat/chat_generator.py       | 17 +++++++++++++++-
 .../llama_cpp/tests/test_chat_generator.py | 20 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
index e305c2a3d..d43700215 100644
--- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
+++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py
@@ -9,6 +9,21 @@
 logger = logging.getLogger(__name__)
 
 
+def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]:
+    """
+    Convert a message to the format expected by Llama.cpp.
+    :returns: A dictionary with the following keys:
+        - `role`
+        - `content`
+        - `name` (optional)
+    """
+    formatted_msg = {"role": message.role.value, "content": message.content}
+    if message.name:
+        formatted_msg["name"] = message.name
+
+    return formatted_msg
+
+
 @component
 class LlamaCppChatGenerator:
     """
@@ -96,7 +111,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
             return {"replies": []}
 
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
-        formatted_messages = [msg.to_openai_format() for msg in messages]
+        formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]
 
         response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs)
         replies = [
diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py
index a30195c8e..7bd6ef122 100644
--- a/integrations/llama_cpp/tests/test_chat_generator.py
+++ b/integrations/llama_cpp/tests/test_chat_generator.py
@@ -10,7 +10,10 @@
 from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
 from haystack.dataclasses import ChatMessage, ChatRole
 from haystack.document_stores.in_memory import InMemoryDocumentStore
-from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator
+from haystack_integrations.components.generators.llama_cpp.chat.chat_generator import (
+    LlamaCppChatGenerator,
+    _convert_message_to_llamacpp_format,
+)
 
 
 @pytest.fixture
@@ -29,6 +32,21 @@ def download_file(file_link, filename, capsys):
             print("\nModel file already exists.")
 
 
+def test_convert_message_to_llamacpp_format():
+    message = ChatMessage.from_system("You are good assistant")
+    assert _convert_message_to_llamacpp_format(message) == {"role": "system", "content": "You are good assistant"}
+
+    message = ChatMessage.from_user("I have a question")
+    assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"}
+
+    message = ChatMessage.from_function("Function call", "function_name")
+    assert _convert_message_to_llamacpp_format(message) == {
+        "role": "function",
+        "content": "Function call",
+        "name": "function_name",
+    }
+
+
 class TestLlamaCppChatGenerator:
     @pytest.fixture
     def generator(self, model_path, capsys):
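
Usage note (not part of the patch): a minimal sketch of how the rewired run() behaves after this change. It assumes the component's model/n_ctx constructor parameters and its warm_up() lifecycle as exposed by the llama_cpp integration; the GGUF path and the prompt are illustrative placeholders, not values from this patch.

from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator

# "model.gguf" is a placeholder path to a local GGUF model file.
generator = LlamaCppChatGenerator(model="model.gguf", n_ctx=512)
generator.warm_up()

# run() now routes every ChatMessage through _convert_message_to_llamacpp_format()
# (instead of ChatMessage.to_openai_format()) before passing the resulting dicts
# to Llama.create_chat_completion().
result = generator.run(messages=[ChatMessage.from_user("Write a haiku about llamas.")])
print(result["replies"][0].content)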