Skip to content

Commit

Permalink
introduce utility function (#939)
Browse files Browse the repository at this point in the history
Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
anakin87 and silvanocerza authored Aug 2, 2024
1 parent 7cee6c8 commit a665f1f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@
logger = logging.getLogger(__name__)


def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]:
"""
Convert a message to the format expected by Llama.cpp.
:returns: A dictionary with the following keys:
- `role`
- `content`
- `name` (optional)
"""
formatted_msg = {"role": message.role.value, "content": message.content}
if message.name:
formatted_msg["name"] = message.name

return formatted_msg


@component
class LlamaCppChatGenerator:
"""
Expand Down Expand Up @@ -96,7 +111,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
return {"replies": []}

updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
formatted_messages = [msg.to_openai_format() for msg in messages]
formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]

response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs)
replies = [
Expand Down
20 changes: 19 additions & 1 deletion integrations/llama_cpp/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator
from haystack_integrations.components.generators.llama_cpp.chat.chat_generator import (
LlamaCppChatGenerator,
_convert_message_to_llamacpp_format,
)


@pytest.fixture
Expand All @@ -29,6 +32,21 @@ def download_file(file_link, filename, capsys):
print("\nModel file already exists.")


def test_convert_message_to_llamacpp_format():
    # Table-driven check: each (message, expected) pair covers one ChatRole,
    # including the optional "name" key for function messages.
    cases = [
        (
            ChatMessage.from_system("You are good assistant"),
            {"role": "system", "content": "You are good assistant"},
        ),
        (
            ChatMessage.from_user("I have a question"),
            {"role": "user", "content": "I have a question"},
        ),
        (
            ChatMessage.from_function("Function call", "function_name"),
            {"role": "function", "content": "Function call", "name": "function_name"},
        ),
    ]
    for message, expected in cases:
        assert _convert_message_to_llamacpp_format(message) == expected


class TestLlamaCppChatGenerator:
@pytest.fixture
def generator(self, model_path, capsys):
Expand Down

0 comments on commit a665f1f

Please sign in to comment.