Commit 1e30ded

progress

anakin87 committed Dec 18, 2024
1 parent b12461d commit 1e30ded

Showing 3 changed files with 32 additions and 30 deletions.
@@ -280,12 +280,10 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
             # Remove content from metadata
             metadata.pop("content", None)
             if part._raw_part.text != "":
-                replies.append(ChatMessage.from_assistant(content=part._raw_part.text, meta=metadata))
+                replies.append(ChatMessage.from_assistant(part._raw_part.text, meta=metadata))
             elif part.function_call:
                 metadata["function_call"] = part.function_call
-                new_message = ChatMessage.from_assistant(
-                    content=json.dumps(dict(part.function_call.args)), meta=metadata
-                )
+                new_message = ChatMessage.from_assistant(json.dumps(dict(part.function_call.args)), meta=metadata)
                 new_message.name = part.function_call.name
                 replies.append(new_message)
         return replies
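
Note: the hunk above (in what appears to be the Google Vertex AI Gemini generator, judging by GenerationResponse) switches ChatMessage.from_assistant from the deprecated content= keyword to a positional argument. A minimal sketch of the new call style, assuming a recent Haystack 2.x ChatMessage (the meta values are made up):

from haystack.dataclasses import ChatMessage

# The text is now the first positional argument; newer releases deprecate `content=`.
reply = ChatMessage.from_assistant("The answer is 42.", meta={"model": "example-model"})
print(reply.text)  # "The answer is 42."
print(reply.meta)  # {"model": "example-model"}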
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List, Optional
 
 from haystack import component
-from haystack.dataclasses import ChatMessage, ChatRole
+from haystack.dataclasses import ChatMessage
 from llama_cpp import Llama
 from llama_cpp.llama_tokenizer import LlamaHFTokenizer
 
@@ -21,6 +21,10 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]:
     if message.name:
         formatted_msg["name"] = message.name
 
+    if formatted_msg["role"] == "tool":
+        formatted_msg["name"] = message.tool_call_result.origin.tool_name
+        formatted_msg["content"] = message.tool_call_result.result
+
     return formatted_msg
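
Note: the new branch above maps Haystack tool-result messages onto llama.cpp's "tool" role. A minimal sketch of the expected shape, assuming Haystack 2.x's ToolCall and ChatMessage.from_tool (the tool name and result are invented):

from haystack.dataclasses import ChatMessage, ToolCall

# Hypothetical tool call and its result, for illustration only.
origin = ToolCall(tool_name="get_weather", arguments={"city": "Rome"})
message = ChatMessage.from_tool(tool_result="sunny, 26°C", origin=origin)

formatted = _convert_message_to_llamacpp_format(message)
# Expected: {"role": "tool", "name": "get_weather", "content": "sunny, 26°C"}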


@@ -114,26 +118,26 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
         formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]
 
         response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs)
-        replies = [
-            ChatMessage(
-                content=choice["message"]["content"],
-                role=ChatRole[choice["message"]["role"].upper()],
-                name=None,
-                meta={
-                    "response_id": response["id"],
-                    "model": response["model"],
-                    "created": response["created"],
-                    "index": choice["index"],
-                    "finish_reason": choice["finish_reason"],
-                    "usage": response["usage"],
-                },
-            )
-            for choice in response["choices"]
-        ]
-
-        for reply, choice in zip(replies, response["choices"]):
+        replies = []
+
+        for choice in response["choices"]:
+            meta = {
+                "response_id": response["id"],
+                "model": response["model"],
+                "created": response["created"],
+                "index": choice["index"],
+                "finish_reason": choice["finish_reason"],
+                "usage": response["usage"],
+            }
+
+            name = None
             tool_calls = choice.get("message", {}).get("tool_calls", [])
             if tool_calls:
-                reply.meta["tool_calls"] = tool_calls
-                reply.name = tool_calls[0]["function"]["name"] if tool_calls else None
+                meta["tool_calls"] = tool_calls
+                name = tool_calls[0]["function"]["name"]
+
+            reply = ChatMessage.from_assistant(choice["message"]["content"], meta=meta, name=name)
+            replies.append(reply)
 
         return {"replies": replies}
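
Note: run() now builds each reply per choice with ChatMessage.from_assistant instead of constructing ChatMessage directly, so tool-call metadata and the name are known before the message is created. A hypothetical usage sketch (the GGUF path is a placeholder; the import path follows this integration's documented layout):

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator

generator = LlamaCppChatGenerator(model="path/to/model.gguf", n_ctx=512)  # placeholder model path
generator.warm_up()  # loads the model

result = generator.run(messages=[ChatMessage.from_user("What is the capital of France?")])
reply = result["replies"][0]

print(reply.text)                    # generated text
print(reply.meta["finish_reason"])   # e.g. "stop"
print(reply.meta.get("tool_calls"))  # set only when the model emitted tool calls
print(reply.name)                    # first tool's name when tool calls exist, else None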
10 changes: 5 additions & 5 deletions integrations/llama_cpp/tests/test_chat_generator.py
@@ -41,11 +41,11 @@ def test_convert_message_to_llamacpp_format():
     assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"}
 
     message = ChatMessage.from_function("Function call", "function_name")
-    assert _convert_message_to_llamacpp_format(message) == {
-        "role": "function",
-        "content": "Function call",
-        "name": "function_name",
-    }
+    converted_message = _convert_message_to_llamacpp_format(message)
+
+    assert converted_message["role"] in ("function", "tool")
+    assert converted_message["name"] == "function_name"
+    assert converted_message["content"] == "Function call"
 
 
 class TestLlamaCppChatGenerator:
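
Note: the equality assertion is relaxed to a membership check, presumably because the role string emitted for function messages varies across Haystack versions ("function" in older releases, "tool" where function messages are converted to tool results), so the test passes under both.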