From 659e87c81c94631311df2d8a49e59d6d87d2796d Mon Sep 17 00:00:00 2001
From: anakin87
Date: Fri, 27 Sep 2024 17:00:24 +0200
Subject: [PATCH] more unit tests

---
 .../dataclasses/chat_message.py               |   2 +-
 .../generators/ollama/test_chat_generator.py  | 131 ++++++++++++++++-
 2 files changed, 126 insertions(+), 7 deletions(-)

diff --git a/haystack_experimental/dataclasses/chat_message.py b/haystack_experimental/dataclasses/chat_message.py
index 5416f639..9f139af1 100644
--- a/haystack_experimental/dataclasses/chat_message.py
+++ b/haystack_experimental/dataclasses/chat_message.py
@@ -193,7 +193,7 @@ def from_assistant(
         :returns: A new ChatMessage instance.
         """
         content: List[ChatMessageContentT] = []
-        if text:
+        if text is not None:
             content.append(TextContent(text=text))
         if tool_calls:
             content.extend(tool_calls)
diff --git a/test/components/generators/ollama/test_chat_generator.py b/test/components/generators/ollama/test_chat_generator.py
index 5d2c2290..0e1f5833 100644
--- a/test/components/generators/ollama/test_chat_generator.py
+++ b/test/components/generators/ollama/test_chat_generator.py
@@ -1,15 +1,15 @@
-from typing import List
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 import sys
 import json
-
 import pytest
+
 from haystack.components.generators.utils import print_streaming_chunk
-from haystack_experimental.dataclasses import ChatMessage, ChatRole, TextContent, ToolCall, Tool
+from haystack.dataclasses import StreamingChunk
 from ollama._types import ResponseError
+from haystack_experimental.dataclasses import ChatMessage, ChatRole, TextContent, ToolCall, Tool
 from haystack_experimental.components.generators.ollama.chat.chat_generator import OllamaChatGenerator, _convert_message_to_ollama_format
 
 
 @pytest.fixture
 def tools():
     tool_parameters = {
@@ -170,9 +170,128 @@ def test_build_message_from_ollama_response(self):
 
         observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
 
-        assert observed._role == "assistant"
+        assert observed.role == "assistant"
         assert observed.text == "Hello! How are you today?"
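 
+    # Ollama signals a tool call by returning a message with empty "content" and a
+    # "tool_calls" list carrying the function name and arguments, as in the response below.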
+    def test_build_message_from_ollama_response_with_tools(self):
+        model = "some_model"
+
+        ollama_response = {
+            "model": model,
+            "created_at": "2023-12-12T14:13:43.416799Z",
+            "message": {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "get_current_weather",
+                            "arguments": {"format": "celsius", "location": "Paris, FR"},
+                        }
+                    }
+                ],
+            },
+            "done": True,
+            "total_duration": 5191566416,
+            "load_duration": 2154458,
+            "prompt_eval_count": 26,
+            "prompt_eval_duration": 383809000,
+            "eval_count": 298,
+            "eval_duration": 4799921000,
+        }
+
+        observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
+
+        assert observed.role == "assistant"
+        assert observed.text == ""
+        assert observed.tool_call == ToolCall(
+            tool_name="get_current_weather", arguments={"format": "celsius", "location": "Paris, FR"}
+        )
+
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+    def test_run(self, mock_client):
+        generator = OllamaChatGenerator()
+
+        mock_response = {
+            "model": "llama3.2",
+            "created_at": "2023-12-12T14:13:43.416799Z",
+            "message": {"role": "assistant", "content": "Fine. How can I help you today?"},
+            "done": True,
+            "total_duration": 5191566416,
+            "load_duration": 2154458,
+            "prompt_eval_count": 26,
+            "prompt_eval_duration": 383809000,
+            "eval_count": 298,
+            "eval_duration": 4799921000,
+        }
+
+        mock_client_instance = mock_client.return_value
+        mock_client_instance.chat.return_value = mock_response
+
+        result = generator.run(messages=[ChatMessage.from_user("Hello! How are you today?")])
+
+        # "orca-mini" is the default model of OllamaChatGenerator
+        mock_client_instance.chat.assert_called_once_with(
+            model="orca-mini",
+            messages=[{"role": "user", "content": "Hello! How are you today?"}],
+            stream=False,
+            tools=None,
+            options={},
+        )
+
+        assert "replies" in result
+        assert len(result["replies"]) == 1
+        assert result["replies"][0].text == "Fine. How can I help you today?"
+        assert result["replies"][0].role == "assistant"
+
+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+    def test_run_streaming(self, mock_client):
+        streaming_callback_called = False
+
+        def streaming_callback(chunk: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        generator = OllamaChatGenerator(streaming_callback=streaming_callback)
+
+        # the mocked client returns an iterator of chunks to simulate the streaming response of Client.chat
+        mock_response = iter(
+            [
+                {
+                    "model": "llama3.2",
+                    "created_at": "2023-12-12T14:13:43.416799Z",
+                    "message": {"role": "assistant", "content": "first chunk "},
+                    "done": False,
+                },
+                {
+                    "model": "llama3.2",
+                    "created_at": "2023-12-12T14:13:43.416799Z",
+                    "message": {"role": "assistant", "content": "second chunk"},
+                    "done": True,
+                    "total_duration": 4883583458,
+                    "load_duration": 1334875,
+                    "prompt_eval_count": 26,
+                    "prompt_eval_duration": 342546000,
+                    "eval_count": 282,
+                    "eval_duration": 4535599000,
+                },
+            ]
+        )
+
+        mock_client_instance = mock_client.return_value
+        mock_client_instance.chat.return_value = mock_response
+
+        result = generator.run(messages=[ChatMessage.from_user("irrelevant")])
+
+        assert streaming_callback_called
+
+        assert "replies" in result
+        assert len(result["replies"]) == 1
+        assert result["replies"][0].text == "first chunk second chunk"
+        assert result["replies"][0].role == "assistant"
+
     def test_run_fail_with_tools_and_streaming(self, tools):
         component = OllamaChatGenerator(tools=tools, streaming_callback=print_streaming_chunk)
 
@@ -185,7 +304,7 @@ def test_run_fail_with_tools_and_streaming(self, tools):
         sys.platform != "linux",
         reason="For simplicity, we only run the integration tests on Linux.",
     )
-    def test_run(self):
+    def test_live_run(self):
         chat_generator = OllamaChatGenerator(model="llama3.2:3b")
 
         user_questions_and_assistant_answers = [