From 0aad128c0c701a42f682aaa75cf690018025b74c Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 19 Dec 2024 14:26:00 +0100 Subject: [PATCH] fix for new chatmessage; serialize chat_template --- .../generators/chat/hugging_face_local.py | 6 ++- .../hflocalchat-fixes-ddf71e8c4c73e566.yaml | 7 ++++ .../chat/test_hugging_face_local.py | 37 +++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 988bffc8b4..1ad152f1e3 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -25,6 +25,7 @@ from haystack.utils.hf import ( # pylint: disable=ungrouped-imports HFTokenStreamingHandler, StopWordsCriteria, + convert_message_to_hf_format, deserialize_hf_model_kwargs, serialize_hf_model_kwargs, ) @@ -201,6 +202,7 @@ def to_dict(self) -> Dict[str, Any]: generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, token=self.token.to_dict() if self.token else None, + chat_template=self.chat_template, ) huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"] @@ -270,9 +272,11 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words) + hf_messages = [convert_message_to_hf_format(message) for message in messages] + # Prepare the prompt for the model prepared_prompt = tokenizer.apply_chat_template( - messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True + hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True ) # Avoid some unnecessary warnings in the generation pipeline call diff --git a/releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml b/releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml new file mode 100644 index 0000000000..fd8c96a6bb --- /dev/null +++ b/releasenotes/notes/hflocalchat-fixes-ddf71e8c4c73e566.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage format, by converting the messages to + the format expected by Hugging Face. + + Serialize the chat_template parameter. diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 8f6749c2d8..fe5308b7b3 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -135,6 +135,7 @@ def test_to_dict(self, model_info_mock): generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=lambda x: x, + chat_template="irrelevant", ) # Call the to_dict method @@ -146,6 +147,7 @@ def test_to_dict(self, model_info_mock): assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf" assert "token" not in init_params["huggingface_pipeline_kwargs"] assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]} + assert init_params["chat_template"] == "irrelevant" def test_from_dict(self, model_info_mock): generator = HuggingFaceLocalChatGenerator( @@ -153,6 +155,7 @@ def test_from_dict(self, model_info_mock): generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=streaming_callback_handler, + chat_template="irrelevant", ) # Call the to_dict method result = generator.to_dict() @@ -162,6 +165,7 @@ def test_from_dict(self, model_info_mock): assert generator_2.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]} assert generator_2.streaming_callback is streaming_callback_handler + assert generator_2.chat_template == "irrelevant" @patch("haystack.components.generators.chat.hugging_face_local.pipeline") def test_warm_up(self, pipeline_mock, monkeypatch): @@ -218,3 +222,36 @@ def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipel chat_message = results["replies"][0] assert chat_message.is_from(ChatRole.ASSISTANT) assert chat_message.text == "Berlin is cool" + + @patch("haystack.components.generators.chat.hugging_face_local.convert_message_to_hf_format") + def test_messages_conversion_is_called(self, mock_convert, model_info_mock): + generator = HuggingFaceLocalChatGenerator(model="fake-model") + + messages = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")] + + with patch.object(generator, "pipeline") as mock_pipeline: + mock_pipeline.tokenizer.apply_chat_template.return_value = "test prompt" + mock_pipeline.return_value = [{"generated_text": "test response"}] + + generator.warm_up() + generator.run(messages) + + assert mock_convert.call_count == 2 + mock_convert.assert_any_call(messages[0]) + mock_convert.assert_any_call(messages[1]) + + @pytest.mark.integration + @pytest.mark.flaky(reruns=3, reruns_delay=10) + def test_live_run(self): + messages = [ChatMessage.from_user("Please create a summary about the following topic: Climate change")] + + llm = HuggingFaceLocalChatGenerator( + model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50} + ) + llm.warm_up() + + result = llm.run(messages) + + assert "replies" in result + assert isinstance(result["replies"][0], ChatMessage) + assert "climate change" in result["replies"][0].text.lower()