From 5e7854492091ccc7b46f40ff7ec728b7e2bebced Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 2 Oct 2024 15:35:39 +0200 Subject: [PATCH] feat: support for tools in `OllamaChatGenerator` (#106) * progress * some progress * try fixing tests * different test model * tools handling and tests * refinements * minor fixes * fix * more unit tests * formatting * incorporate feedback from review * update README --- .github/workflows/tests.yml | 22 + README.md | 7 +- docs/pydoc/config/generators_api.yml | 3 +- haystack_experimental/components/__init__.py | 4 +- .../components/generators/ollama/__init__.py | 9 + .../generators/ollama/chat/__init__.py | 9 + .../generators/ollama/chat/chat_generator.py | 248 ++++++++++ .../dataclasses/chat_message.py | 2 +- pyproject.toml | 3 +- .../generators/ollama/test_chat_generator.py | 462 ++++++++++++++++++ 10 files changed, 762 insertions(+), 7 deletions(-) create mode 100644 haystack_experimental/components/generators/ollama/__init__.py create mode 100644 haystack_experimental/components/generators/ollama/chat/__init__.py create mode 100644 haystack_experimental/components/generators/ollama/chat/chat_generator.py create mode 100644 test/components/generators/ollama/test_chat_generator.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ba960d2c..de95e1bb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,7 @@ env: COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }} SERPERDEV_API_KEY: ${{ secrets.SERPERDEV_API_KEY }} + OLLAMA_LLM_FOR_TESTS: "llama3.2:3b" jobs: linting: @@ -117,5 +118,26 @@ jobs: - name: Install Hatch run: pip install hatch==${{ env.HATCH_VERSION }} + - name: Install Ollama and pull the required models + if: matrix.os == 'ubuntu-latest' + run: | + curl -fsSL https://ollama.com/install.sh | sh + ollama serve & + + # Check if the service is up and running with a timeout of 60 seconds + timeout=60 + while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:11434/ > /dev/null; do + echo "Waiting for Ollama service to start..." + sleep 5 + ((timeout-=5)) + done + + if [ $timeout -eq 0 ]; then + echo "Timed out waiting for Ollama service to start." + exit 1 + fi + + ollama pull ${{ env.OLLAMA_LLM_FOR_TESTS }} + - name: Run run: hatch run test:integration diff --git a/README.md b/README.md index e6cba746..c49f6d9d 100644 --- a/README.md +++ b/README.md @@ -36,17 +36,17 @@ that includes it. 
Once it reaches the end of its lifespan, the experiment will b

The latest version of the package contains the following experiments:

- 
+ 
| Name | Type | Expected End Date | Dependencies | Cookbook | Discussion |
| --------------------------- | -------------------------- | ---------------------------- | ------------ | -------- | ---------- |
| [`EvaluationHarness`][1] | Evaluation orchestrator | October 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/74) |
| [`OpenAIFunctionCaller`][2] | Function Calling Component | October 2024 | None | 🔜 | |
| [`OpenAPITool`][3] | OpenAPITool component | October 2024 | jsonref | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/79)|
-| Support for Tools: [refactored `ChatMessage` dataclass][10], [`Tool` dataclass][4], [refactored `OpenAIChatGenerator`][11], [`ToolInvoker` component][12] | Tool Calling support | November 2024 | jsonschema | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/98)|
+| Support for Tools: [refactored `ChatMessage` dataclass][10], [`Tool` dataclass][4], [refactored `OpenAIChatGenerator`][11], [refactored `OllamaChatGenerator`][14], [`ToolInvoker` component][12] | Tool Calling support | November 2024 | jsonschema | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/98)|
| [`ChatMessageWriter`][5] | Memory Component | December 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
| [`ChatMessageRetriever`][6] | Memory Component | December 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
-| [`InMemoryChatMessageStore`][7] | Memory Store | December 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
+| [`InMemoryChatMessageStore`][7] | Memory Store | December 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
| [`Auto-Merging Retriever`][8] & [`HierarchicalDocumentSplitter`][9]| Document Splitting & Retrieval Technique | December 2024 | None | Open In Colab | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/78) |
| [`LLMetadataExtractor`][13] | Metadata extraction with LLM | December 2024 | None | | |

@@ -63,6 +63,7 @@ The latest version of the package contains the following experiments:
 [11]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/chat/openai.py
 [12]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/tools/tool_invoker.py
 [13]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/extractors/llm_metadata_extractor.py
+[14]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/ollama/chat/chat_generator.py

 ## Usage

diff --git a/docs/pydoc/config/generators_api.yml b/docs/pydoc/config/generators_api.yml
index 1a0ec7c5..3ef80214 100644
--- a/docs/pydoc/config/generators_api.yml
+++ b/docs/pydoc/config/generators_api.yml
@@ -1,7 +1,8 @@
 loaders:
   - type: haystack_pydoc_tools.loaders.CustomPythonLoader
     search_path: [../../../]
-  modules: ["haystack_experimental.components.generators.chat.openai"]
+  modules: ["haystack_experimental.components.generators.chat.openai",
+    
"haystack_experimental.components.generators.ollama.chat.chat_generator"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack_experimental/components/__init__.py b/haystack_experimental/components/__init__.py index 82324492..a6323214 100644 --- a/haystack_experimental/components/__init__.py +++ b/haystack_experimental/components/__init__.py @@ -5,6 +5,7 @@ from .extractors import LLMMetadataExtractor from .generators.chat import OpenAIChatGenerator +from .generators.ollama.chat.chat_generator import OllamaChatGenerator from .retrievers.auto_merging_retriever import AutoMergingRetriever from .retrievers.chat_message_retriever import ChatMessageRetriever from .splitters import HierarchicalDocumentSplitter @@ -15,9 +16,10 @@ "AutoMergingRetriever", "ChatMessageWriter", "ChatMessageRetriever", + "OllamaChatGenerator", "OpenAIChatGenerator", "LLMMetadataExtractor", "HierarchicalDocumentSplitter", "OpenAIFunctionCaller", - "ToolInvoker" + "ToolInvoker", ] diff --git a/haystack_experimental/components/generators/ollama/__init__.py b/haystack_experimental/components/generators/ollama/__init__.py new file mode 100644 index 00000000..e57d5847 --- /dev/null +++ b/haystack_experimental/components/generators/ollama/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .chat.chat_generator import OllamaChatGenerator + +__all__ = [ + "OllamaChatGenerator", +] diff --git a/haystack_experimental/components/generators/ollama/chat/__init__.py b/haystack_experimental/components/generators/ollama/chat/__init__.py new file mode 100644 index 00000000..85b66d68 --- /dev/null +++ b/haystack_experimental/components/generators/ollama/chat/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .chat_generator import OllamaChatGenerator + +__all__ = [ + "OllamaChatGenerator", +] diff --git a/haystack_experimental/components/generators/ollama/chat/chat_generator.py b/haystack_experimental/components/generators/ollama/chat/chat_generator.py new file mode 100644 index 00000000..967cd1ab --- /dev/null +++ b/haystack_experimental/components/generators/ollama/chat/chat_generator.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional, Type + +from haystack import component, default_from_dict +from haystack.dataclasses import StreamingChunk +from haystack.lazy_imports import LazyImport +from haystack.utils.callable_serialization import deserialize_callable + +from haystack_experimental.dataclasses import ChatMessage, ToolCall +from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace + +with LazyImport("Run 'pip install ollama-haystack'") as ollama_integration_import: + # pylint: disable=import-error + from haystack_integrations.components.generators.ollama import OllamaChatGenerator as OllamaChatGeneratorBase + + +# The following code block ensures that: +# - we reuse existing code where possible +# - people can use haystack-experimental without installing ollama-haystack. +# +# +# If ollama-haystack is installed: all works correctly. +# +# If ollama-haystack is not installed: +# - haystack-experimental package works fine (no import errors). +# - OllamaChatGenerator fails with ImportError at init (due to ollama_integration_import.check()). 
+ +if ollama_integration_import.is_successful(): + chatgenerator_base_class: Type[OllamaChatGeneratorBase] = OllamaChatGeneratorBase +else: + chatgenerator_base_class: Type[object] = object # type: ignore[no-redef] + + +def _convert_message_to_ollama_format(message: ChatMessage) -> Dict[str, Any]: + """ + Convert a message to the format expected by Ollama Chat API. + """ + text_contents = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not text_contents and not tool_calls and not tool_call_results: + raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") + elif len(text_contents) + len(tool_call_results) > 1: + raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") + + ollama_msg: Dict[str, Any] = {"role": message._role.value} + + if tool_call_results: + # Ollama does not provide a way to communicate errors in tool invocations, so we ignore the error field + ollama_msg["content"] = tool_call_results[0].result + return ollama_msg + + if text_contents: + ollama_msg["content"] = text_contents[0] + if tool_calls: + # Ollama does not support tool call id, so we ignore it + ollama_msg["tool_calls"] = [ + {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}} for tc in tool_calls + ] + return ollama_msg + + +@component() +class OllamaChatGenerator(chatgenerator_base_class): + """ + Supports models running on Ollama. + + Find the full list of supported models [here](https://ollama.ai/library). + + Usage example: + ```python + from haystack_experimental.components.generators.ollama import OllamaChatGenerator + from haystack_experimental.dataclasses import ChatMessage + + generator = OllamaChatGenerator(model="zephyr", + url = "http://localhost:11434", + generation_kwargs={ + "num_predict": 100, + "temperature": 0.9, + }) + + messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"), + ChatMessage.from_user("What's Natural Language Processing?")] + + print(generator.run(messages=messages)) + ``` + """ + + def __init__( + self, + model: str = "orca-mini", + url: str = "http://localhost:11434", + generation_kwargs: Optional[Dict[str, Any]] = None, + timeout: int = 120, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, + ): + """ + Creates an instance of OllamaChatGenerator. + + :param model: + The name of the model to use. The model should be available in the running Ollama instance. + :param url: + The URL of a running Ollama instance. + :param generation_kwargs: + Optional arguments to pass to the Ollama generation endpoint, such as temperature, + top_p, and others. See the available arguments in + [Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). + :param timeout: + The number of seconds before throwing a timeout error from the Ollama API. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param tools: + A list of tools for which the model can prepare calls. + Not all models support tools. For a list of models compatible with tools, see the + [models page](https://ollama.com/search?c=tools). 
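+
+            As a minimal sketch (mirroring the tool fixture used in this package's tests), a `Tool`
+            that could be passed here might look like:
+
+            ```python
+            from haystack_experimental.dataclasses import Tool
+
+            weather_tool = Tool(
+                name="weather",
+                description="useful to determine the weather in a given location",
+                parameters={
+                    "type": "object",
+                    "properties": {"city": {"type": "string"}},
+                    "required": ["city"],
+                },
+                function=lambda x: x,  # placeholder function, as in the tests; a real tool would query a weather API
+            )
+            ```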
+        """
+        ollama_integration_import.check()
+
+        if tools:
+            tool_names = [tool.name for tool in tools]
+            duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
+            if duplicate_tool_names:
+                raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
+        self.tools = tools
+
+        super(OllamaChatGenerator, self).__init__(
+            model=model,
+            url=url,
+            generation_kwargs=generation_kwargs,
+            timeout=timeout,
+            streaming_callback=streaming_callback,
+        )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+
+        :returns:
+            The serialized component as a dictionary.
+        """
+        serialized = super(OllamaChatGenerator, self).to_dict()
+        serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None
+        return serialized
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
+        """
+        Deserialize this component from a dictionary.
+
+        :param data: The dictionary representation of this component.
+        :returns:
+            The deserialized component instance.
+        """
+        deserialize_tools_inplace(data["init_parameters"], key="tools")
+        init_params = data.get("init_parameters", {})
+        serialized_callback_handler = init_params.get("streaming_callback")
+        if serialized_callback_handler:
+            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
+
+        return default_from_dict(cls, data)
+
+    def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
+        """
+        Converts the non-streaming response from the Ollama API to a ChatMessage.
+        """
+        ollama_message = ollama_response["message"]
+
+        text = ollama_message["content"]
+
+        tool_calls = []
+        if ollama_tool_calls := ollama_message.get("tool_calls"):
+            for ollama_tc in ollama_tool_calls:
+                tool_calls.append(
+                    ToolCall(tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"])
+                )
+
+        message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
+
+        message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
+        return message
+
+    def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
+        """
+        Converts a list of streaming chunks into the response format expected by Haystack.
+        """
+
+        # Unaltered from the integration code. Overridden to use the experimental ChatMessage dataclass.
+
+        replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))]
+        meta = {key: value for key, value in chunks[0].meta.items() if key != "message"}
+
+        return {"replies": replies, "meta": [meta]}
+
+    @component.output_types(replies=List[ChatMessage])
+    def run(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
+        """
+        Runs an Ollama model on the given chat history.
+
+        :param messages:
+            A list of ChatMessage instances representing the input messages.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+            top_p, etc. See the
+            [Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+        :param tools:
+            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+            during component initialization.
+ :returns: A dictionary with the following keys: + - `replies`: The responses from the model + """ + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + stream = self.streaming_callback is not None + tools = tools or self.tools + + if stream and tools: + raise ValueError("Ollama does not support tools and streaming at the same time. Please choose one.") + + ollama_tools = None + if tools: + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] + + ollama_messages = [_convert_message_to_ollama_format(msg) for msg in messages] + response = self._client.chat( + model=self.model, messages=ollama_messages, tools=ollama_tools, stream=stream, options=generation_kwargs + ) + + if stream: + chunks: List[StreamingChunk] = self._handle_streaming_response(response) + return self._convert_to_streaming_response(chunks) + + return {"replies": [self._build_message_from_ollama_response(response)]} diff --git a/haystack_experimental/dataclasses/chat_message.py b/haystack_experimental/dataclasses/chat_message.py index 5416f639..9f139af1 100644 --- a/haystack_experimental/dataclasses/chat_message.py +++ b/haystack_experimental/dataclasses/chat_message.py @@ -193,7 +193,7 @@ def from_assistant( :returns: A new ChatMessage instance. """ content: List[ChatMessageContentT] = [] - if text: + if text is not None: content.append(TextContent(text=text)) if tool_calls: content.extend(tool_calls) diff --git a/pyproject.toml b/pyproject.toml index 93c3b910..d2b7f892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,9 @@ extra-dependencies = [ "cohere-haystack", "anthropic-haystack", "fastapi", - # Tool + # Tools support "jsonschema", + "ollama-haystack>=1.0.0", # LLMMetadataExtractor dependencies "amazon-bedrock-haystack>=1.0.2", "google-vertex-haystack>=2.0.0", diff --git a/test/components/generators/ollama/test_chat_generator.py b/test/components/generators/ollama/test_chat_generator.py new file mode 100644 index 00000000..ca74edac --- /dev/null +++ b/test/components/generators/ollama/test_chat_generator.py @@ -0,0 +1,462 @@ +from unittest.mock import Mock, patch +import sys +import json +import pytest + +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import StreamingChunk +from ollama._types import ResponseError + +from haystack_experimental.dataclasses import ( + ChatMessage, + ChatRole, + TextContent, + ToolCall, + Tool, +) +from haystack_experimental.components.generators.ollama.chat.chat_generator import ( + OllamaChatGenerator, + _convert_message_to_ollama_format, +) + + +@pytest.fixture +def tools(): + tool_parameters = { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] + + +def test_convert_message_to_ollama_format(): + message = ChatMessage.from_system("You are good assistant") + assert _convert_message_to_ollama_format(message) == { + "role": "system", + "content": "You are good assistant", + } + + message = ChatMessage.from_user("I have a question") + assert _convert_message_to_ollama_format(message) == { + "role": "user", + "content": "I 
have a question", + } + + message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"}) + assert _convert_message_to_ollama_format(message) == { + "role": "assistant", + "content": "I have an answer", + } + + message = ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] + ) + assert _convert_message_to_ollama_format(message) == { + "role": "assistant", + "tool_calls": [ + { + "type": "function", + "function": {"name": "weather", "arguments": {"city": "Paris"}}, + } + ], + } + + tool_result = json.dumps({"weather": "sunny", "temperature": "25"}) + message = ChatMessage.from_tool( + tool_result=tool_result, + origin=ToolCall(tool_name="weather", arguments={"city": "Paris"}), + ) + assert _convert_message_to_ollama_format(message) == { + "role": "tool", + "content": tool_result, + } + + +def test_convert_message_to_ollama_invalid(): + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + _convert_message_to_ollama_format(message) + + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[ + TextContent(text="I have an answer"), + TextContent(text="I have another answer"), + ], + ) + with pytest.raises(ValueError): + _convert_message_to_ollama_format(message) + + +class TestOllamaChatGenerator: + def test_init_default(self): + component = OllamaChatGenerator() + assert component.model == "orca-mini" + assert component.url == "http://localhost:11434" + assert component.generation_kwargs == {} + assert component.timeout == 120 + assert component.streaming_callback is None + assert component.tools is None + + def test_init(self, tools): + component = OllamaChatGenerator( + model="llama2", + url="http://my-custom-endpoint:11434", + generation_kwargs={"temperature": 0.5}, + timeout=5, + streaming_callback=print_streaming_chunk, + tools=tools, + ) + + assert component.model == "llama2" + assert component.url == "http://my-custom-endpoint:11434" + assert component.generation_kwargs == {"temperature": 0.5} + assert component.timeout == 5 + assert component.streaming_callback is print_streaming_chunk + assert component.tools == tools + + def test_init_fail_with_duplicate_tool_names(self, tools): + + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + OllamaChatGenerator(tools=duplicate_tools) + + def test_to_dict(self): + tool = Tool( + name="name", + description="description", + parameters={"x": {"type": "string"}}, + function=print, + ) + + component = OllamaChatGenerator( + model="llama2", + streaming_callback=print_streaming_chunk, + url="custom_url", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + tools=[tool], + ) + data = component.to_dict() + assert data == { + "type": "haystack_experimental.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", + "init_parameters": { + "timeout": 120, + "model": "llama2", + "url": "custom_url", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": { + "max_tokens": 10, + "some_test_param": "test-params", + }, + "tools": [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + ], + }, + } + + def test_from_dict(self): + tool = Tool( + name="name", + description="description", + parameters={"x": {"type": "string"}}, + function=print, + ) + + data = { + "type": 
"haystack_experimental.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", + "init_parameters": { + "timeout": 120, + "model": "llama2", + "url": "custom_url", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": { + "max_tokens": 10, + "some_test_param": "test-params", + }, + "tools": [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + }, + ], + }, + } + component = OllamaChatGenerator.from_dict(data) + assert component.model == "llama2" + assert component.streaming_callback is print_streaming_chunk + assert component.url == "custom_url" + assert component.generation_kwargs == { + "max_tokens": 10, + "some_test_param": "test-params", + } + assert component.timeout == 120 + assert component.tools == [tool] + + def test_build_message_from_ollama_response(self): + model = "some_model" + + ollama_response = { + "model": model, + "created_at": "2023-12-12T14:13:43.416799Z", + "message": {"role": "assistant", "content": "Hello! How are you today?"}, + "done": True, + "total_duration": 5191566416, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000, + } + + observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response) + + assert observed.role == "assistant" + assert observed.text == "Hello! How are you today?" + + def test_build_message_from_ollama_response_with_tools(self): + model = "some_model" + + ollama_response = { + "model": model, + "created_at": "2023-12-12T14:13:43.416799Z", + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "get_current_weather", + "arguments": {"format": "celsius", "location": "Paris, FR"}, + } + } + ], + }, + "done": True, + "total_duration": 5191566416, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000, + } + + observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response) + + assert observed.role == "assistant" + assert observed.text == "" + assert observed.tool_call == ToolCall( + tool_name="get_current_weather", + arguments={"format": "celsius", "location": "Paris, FR"}, + ) + + @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client") + def test_run(self, mock_client): + generator = OllamaChatGenerator() + + mock_response = { + "model": "llama3.2", + "created_at": "2023-12-12T14:13:43.416799Z", + "message": { + "role": "assistant", + "content": "Fine. How can I help you today?", + }, + "done": True, + "total_duration": 5191566416, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000, + } + + mock_client_instance = mock_client.return_value + mock_client_instance.chat.return_value = mock_response + + result = generator.run(messages=[ChatMessage.from_user("Hello! How are you today?")]) + + mock_client_instance.chat.assert_called_once_with( + model="orca-mini", + messages=[{"role": "user", "content": "Hello! How are you today?"}], + stream=False, + tools=None, + options={}, + ) + + assert "replies" in result + assert len(result["replies"]) == 1 + assert result["replies"][0].text == "Fine. How can I help you today?" 
+ assert result["replies"][0].role == "assistant" + + @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client") + def test_run_streaming(self, mock_client): + streaming_callback_called = False + + def streaming_callback(chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + generator = OllamaChatGenerator(streaming_callback=streaming_callback) + + mock_response = iter( + [ + { + "model": "llama3.2", + "created_at": "2023-12-12T14:13:43.416799Z", + "message": {"role": "assistant", "content": "first chunk "}, + "done": False, + }, + { + "model": "llama3.2", + "created_at": "2023-12-12T14:13:43.416799Z", + "message": {"role": "assistant", "content": "second chunk"}, + "done": True, + "total_duration": 4883583458, + "load_duration": 1334875, + "prompt_eval_count": 26, + "prompt_eval_duration": 342546000, + "eval_count": 282, + "eval_duration": 4535599000, + }, + ] + ) + + mock_client_instance = mock_client.return_value + mock_client_instance.chat.return_value = mock_response + + result = generator.run(messages=[ChatMessage.from_user("irrelevant")]) + + assert streaming_callback_called + + assert "replies" in result + assert len(result["replies"]) == 1 + assert result["replies"][0].text == "first chunk second chunk" + assert result["replies"][0].role == "assistant" + + def test_run_fail_with_tools_and_streaming(self, tools): + component = OllamaChatGenerator(tools=tools, streaming_callback=print_streaming_chunk) + + with pytest.raises(ValueError): + message = ChatMessage.from_user("irrelevant") + component.run([message]) + + @pytest.mark.integration + @pytest.mark.skipif( + sys.platform != "linux", + reason="For simplicity, we only run the integration tests on Linux.", + ) + def test_live_run(self): + chat_generator = OllamaChatGenerator(model="llama3.2:3b") + + user_questions_and_assistant_answers = [ + ("What's the capital of France?", "Paris"), + ("What is the capital of Canada?", "Ottawa"), + ("What is the capital of Ghana?", "Accra"), + ] + + for question, answer in user_questions_and_assistant_answers: + message = ChatMessage.from_user(question) + + response = chat_generator.run([message]) + + assert isinstance(response, dict) + assert isinstance(response["replies"], list) + assert answer in response["replies"][0].text + + @pytest.mark.integration + @pytest.mark.skipif( + sys.platform != "linux", + reason="For simplicity, we only run the integration tests on Linux.", + ) + def test_run_with_chat_history(self): + chat_generator = OllamaChatGenerator(model="llama3.2:3b") + + chat_messages = [ + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), + ] + + response = chat_generator.run(chat_messages) + + assert isinstance(response, dict) + assert isinstance(response["replies"], list) + + assert any(city in response["replies"][-1].text for city in ["Manchester", "Birmingham", "Glasgow"]) + + @pytest.mark.integration + @pytest.mark.skipif( + sys.platform != "linux", + reason="For simplicity, we only run the integration tests on Linux.", + ) + def test_run_model_unavailable(self): + component = OllamaChatGenerator(model="unknown_model") + + with pytest.raises(ResponseError): + message = ChatMessage.from_user("irrelevant") + component.run([message]) + + @pytest.mark.integration + @pytest.mark.skipif( + sys.platform != "linux", 
+ reason="For simplicity, we only run the integration tests on Linux.", + ) + def test_run_with_streaming(self): + streaming_callback = Mock() + chat_generator = OllamaChatGenerator(model="llama3.2:3b", streaming_callback=streaming_callback) + + chat_messages = [ + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), + ] + + response = chat_generator.run(chat_messages) + + streaming_callback.assert_called() + + assert isinstance(response, dict) + assert isinstance(response["replies"], list) + assert any(city in response["replies"][-1].text for city in ["Manchester", "Birmingham", "Glasgow"]) + + @pytest.mark.integration + @pytest.mark.skipif( + sys.platform != "linux", + reason="For simplicity, we only run the integration tests on Linux.", + ) + def test_run_with_tools(self, tools): + chat_generator = OllamaChatGenerator(model="llama3.2:3b", tools=tools) + + message = ChatMessage.from_user("What is the weather in Paris?") + response = chat_generator.run([message]) + + assert len(response["replies"]) == 1 + message = response["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"}
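
For reference, a minimal end-to-end sketch (not part of the patch) of the feature it introduces,
assuming a local Ollama instance at `http://localhost:11434` with the `llama3.2:3b` model pulled;
the model and the tool shape mirror the integration tests above, and the tool function is a stub:

```python
from haystack_experimental.components.generators.ollama import OllamaChatGenerator
from haystack_experimental.dataclasses import ChatMessage, Tool

# Stub tool mirroring the test fixture; a real tool would call an actual weather API.
weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=lambda x: x,
)

chat_generator = OllamaChatGenerator(model="llama3.2:3b", tools=[weather_tool])
result = chat_generator.run([ChatMessage.from_user("What is the weather in Paris?")])

reply = result["replies"][0]
if reply.tool_calls:
    # The model prepared a tool call instead of answering in plain text.
    print(reply.tool_call.tool_name, reply.tool_call.arguments)
```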