Skip to content

Commit

Permalink
fix for new chatmessage; serialize chat_template
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 19, 2024
1 parent 39184a6 commit 0aad128
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
6 changes: 5 additions & 1 deletion haystack/components/generators/chat/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from haystack.utils.hf import ( # pylint: disable=ungrouped-imports
HFTokenStreamingHandler,
StopWordsCriteria,
convert_message_to_hf_format,
deserialize_hf_model_kwargs,
serialize_hf_model_kwargs,
)
Expand Down Expand Up @@ -201,6 +202,7 @@ def to_dict(self) -> Dict[str, Any]:
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
token=self.token.to_dict() if self.token else None,
chat_template=self.chat_template,
)

huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
Expand Down Expand Up @@ -270,9 +272,11 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)

hf_messages = [convert_message_to_hf_format(message) for message in messages]

# Prepare the prompt for the model
prepared_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
)

# Avoid some unnecessary warnings in the generation pipeline call
Expand Down
7 changes: 7 additions & 0 deletions releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage format, by converting the messages to
the format expected by Hugging Face.
Serialize the chat_template parameter.
37 changes: 37 additions & 0 deletions test/components/generators/chat/test_hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def test_to_dict(self, model_info_mock):
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=lambda x: x,
chat_template="irrelevant",
)

# Call the to_dict method
Expand All @@ -146,13 +147,15 @@ def test_to_dict(self, model_info_mock):
assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert "token" not in init_params["huggingface_pipeline_kwargs"]
assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
assert init_params["chat_template"] == "irrelevant"

def test_from_dict(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=streaming_callback_handler,
chat_template="irrelevant",
)
# Call the to_dict method
result = generator.to_dict()
Expand All @@ -162,6 +165,7 @@ def test_from_dict(self, model_info_mock):
assert generator_2.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
assert generator_2.streaming_callback is streaming_callback_handler
assert generator_2.chat_template == "irrelevant"

@patch("haystack.components.generators.chat.hugging_face_local.pipeline")
def test_warm_up(self, pipeline_mock, monkeypatch):
Expand Down Expand Up @@ -218,3 +222,36 @@ def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipel
chat_message = results["replies"][0]
assert chat_message.is_from(ChatRole.ASSISTANT)
assert chat_message.text == "Berlin is cool"

@patch("haystack.components.generators.chat.hugging_face_local.convert_message_to_hf_format")
def test_messages_conversion_is_called(self, mock_convert, model_info_mock):
generator = HuggingFaceLocalChatGenerator(model="fake-model")

messages = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")]

with patch.object(generator, "pipeline") as mock_pipeline:
mock_pipeline.tokenizer.apply_chat_template.return_value = "test prompt"
mock_pipeline.return_value = [{"generated_text": "test response"}]

generator.warm_up()
generator.run(messages)

assert mock_convert.call_count == 2
mock_convert.assert_any_call(messages[0])
mock_convert.assert_any_call(messages[1])

@pytest.mark.integration
@pytest.mark.flaky(reruns=3, reruns_delay=10)
def test_live_run(self):
messages = [ChatMessage.from_user("Please create a summary about the following topic: Climate change")]

llm = HuggingFaceLocalChatGenerator(
model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50}
)
llm.warm_up()

result = llm.run(messages)

assert "replies" in result
assert isinstance(result["replies"][0], ChatMessage)
assert "climate change" in result["replies"][0].text.lower()

0 comments on commit 0aad128

Please sign in to comment.