Skip to content

Commit

Permalink
message conversion function
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 19, 2024
1 parent 91619a7 commit 4db6f40
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
40 changes: 39 additions & 1 deletion haystack/utils/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Callable, Dict, List, Optional, Union

from haystack import logging
from haystack.dataclasses import StreamingChunk
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice
Expand Down Expand Up @@ -270,6 +270,44 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
)


def convert_message_to_hf_format(message: ChatMessage) -> Dict[str, Any]:
    """
    Convert a message to the format expected by Hugging Face.

    :param message: The `ChatMessage` to convert.
    :returns: A dictionary with a "role" and a "content" key, plus optional
        "tool_calls" and "tool_call_id" keys when the message carries them.
    :raises ValueError: If the message has no content at all, or contains more
        than one `TextContent`/`ToolCallResult`.
    """
    texts = message.texts
    calls = message.tool_calls
    results = message.tool_call_results

    if not (texts or calls or results):
        raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
    if len(texts) + len(results) > 1:
        raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")

    # HF always expects a content field, even if it is empty
    converted: Dict[str, Any] = {"role": message._role.value, "content": ""}

    if results:
        first_result = results[0]
        converted["content"] = first_result.result
        tc_id = first_result.origin.id
        if tc_id:
            converted["tool_call_id"] = tc_id
        # HF does not provide a way to communicate errors in tool invocations,
        # so the error field is ignored
        return converted

    if texts:
        converted["content"] = texts[0]

    if calls:
        # "id" is only included when the originating tool call carried one
        converted["tool_calls"] = [
            {
                "type": "function",
                "function": {"name": call.tool_name, "arguments": call.arguments},
                **({"id": call.id} if call.id is not None else {}),
            }
            for call in calls
        ]

    return converted


with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, StoppingCriteria, TextStreamer

Expand Down
59 changes: 58 additions & 1 deletion test/utils/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
from haystack.utils.hf import resolve_hf_device_map

import pytest

from haystack.utils.hf import resolve_hf_device_map, convert_message_to_hf_format
from haystack.utils.device import ComponentDevice
from haystack.dataclasses import ChatMessage, ToolCall, ChatRole, TextContent


def test_resolve_hf_device_map_only_device():
Expand All @@ -23,3 +27,56 @@ def test_resolve_hf_device_map_device_and_device_map(caplog):
)
assert "The parameters `device` and `device_map` from `model_kwargs` are both provided." in caplog.text
assert model_kwargs["device_map"] == "cuda:0"


def test_convert_message_to_hf_format():
    # Plain-text messages: one {"role", "content"} dict per role.
    system_msg = ChatMessage.from_system("You are good assistant")
    assert convert_message_to_hf_format(system_msg) == {"role": "system", "content": "You are good assistant"}

    user_msg = ChatMessage.from_user("I have a question")
    assert convert_message_to_hf_format(user_msg) == {"role": "user", "content": "I have a question"}

    assistant_msg = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"})
    assert convert_message_to_hf_format(assistant_msg) == {"role": "assistant", "content": "I have an answer"}

    # Tool call with an id: the id is propagated into the HF tool_calls entry.
    call_with_id = ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})
    msg_with_id = ChatMessage.from_assistant(tool_calls=[call_with_id])
    expected_call = {"id": "123", "type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}
    assert convert_message_to_hf_format(msg_with_id) == {
        "role": "assistant",
        "content": "",
        "tool_calls": [expected_call],
    }

    # Tool call without an id: no "id" key in the converted entry.
    call_without_id = ToolCall(tool_name="weather", arguments={"city": "Paris"})
    msg_without_id = ChatMessage.from_assistant(tool_calls=[call_without_id])
    assert convert_message_to_hf_format(msg_without_id) == {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}],
    }

    # Tool result messages carry the result as content and the origin id (if any) as tool_call_id.
    tool_result = {"weather": "sunny", "temperature": "25"}
    result_msg = ChatMessage.from_tool(
        tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})
    )
    assert convert_message_to_hf_format(result_msg) == {"role": "tool", "content": tool_result, "tool_call_id": "123"}

    result_msg_no_id = ChatMessage.from_tool(
        tool_result=tool_result, origin=ToolCall(tool_name="weather", arguments={"city": "Paris"})
    )
    assert convert_message_to_hf_format(result_msg_no_id) == {"role": "tool", "content": tool_result}


def test_convert_message_to_hf_invalid():
    # A message without any content is rejected.
    empty_msg = ChatMessage(_role=ChatRole.ASSISTANT, _content=[])
    with pytest.raises(ValueError):
        convert_message_to_hf_format(empty_msg)

    # A message with more than one textual part is rejected.
    two_texts = [TextContent(text="I have an answer"), TextContent(text="I have another answer")]
    multi_text_msg = ChatMessage(_role=ChatRole.ASSISTANT, _content=two_texts)
    with pytest.raises(ValueError):
        convert_message_to_hf_format(multi_text_msg)

0 comments on commit 4db6f40

Please sign in to comment.