diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index ba960d2c..de95e1bb 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -35,6 +35,7 @@ env:
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
SERPERDEV_API_KEY: ${{ secrets.SERPERDEV_API_KEY }}
+ OLLAMA_LLM_FOR_TESTS: "llama3.2:3b"
jobs:
linting:
@@ -117,5 +118,26 @@ jobs:
- name: Install Hatch
run: pip install hatch==${{ env.HATCH_VERSION }}
+ - name: Install Ollama and pull the required models
+ if: matrix.os == 'ubuntu-latest'
+ run: |
+ curl -fsSL https://ollama.com/install.sh | sh
+ ollama serve &
+
+ # Check if the service is up and running with a timeout of 60 seconds
+ timeout=60
+ while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:11434/ > /dev/null; do
+ echo "Waiting for Ollama service to start..."
+ sleep 5
+          timeout=$((timeout - 5))
+ done
+
+ if [ $timeout -eq 0 ]; then
+ echo "Timed out waiting for Ollama service to start."
+ exit 1
+ fi
+
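+          # Pull the model used by the Ollama integration tests (OLLAMA_LLM_FOR_TESTS, defined at the workflow level)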
+ ollama pull ${{ env.OLLAMA_LLM_FOR_TESTS }}
+
- name: Run
run: hatch run test:integration
diff --git a/README.md b/README.md
index e6cba746..c49f6d9d 100644
--- a/README.md
+++ b/README.md
@@ -36,17 +36,17 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b
The latest version of the package contains the following experiments:
-
+
| Name | Type | Expected End Date | Dependencies | Cookbook | Discussion |
| --------------------------- | -------------------------- | ---------------------------- | ------------ | -------- | ---------- |
| [`EvaluationHarness`][1] | Evaluation orchestrator | October 2024 | None | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/74) |
| [`OpenAIFunctionCaller`][2] | Function Calling Component | October 2024 | None | 🔜 | |
| [`OpenAPITool`][3] | OpenAPITool component | October 2024 | jsonref | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/79)|
-| Support for Tools: [refactored `ChatMessage` dataclass][10], [`Tool` dataclass][4], [refactored `OpenAIChatGenerator`][11], [`ToolInvoker` component][12] | Tool Calling support | November 2024 | jsonschema | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/98)|
+| Support for Tools: [refactored `ChatMessage` dataclass][10], [`Tool` dataclass][4], [refactored `OpenAIChatGenerator`][11], [refactored `OllamaChatGenerator`][14], [`ToolInvoker` component][12] | Tool Calling support | November 2024 | jsonschema | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/98)|
| [`ChatMessageWriter`][5] | Memory Component | December 2024 | None | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
| [`ChatMessageRetriever`][6] | Memory Component | December 2024 | None | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
-| [`InMemoryChatMessageStore`][7] | Memory Store | December 2024 | None | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
+| [`InMemoryChatMessageStore`][7] | Memory Store | December 2024 | None | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
| [`Auto-Merging Retriever`][8] & [`HierarchicalDocumentSplitter`][9]| Document Splitting & Retrieval Technique | December 2024 | None | | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/78) |
| [`LLMetadataExtractor`][13] | Metadata extraction with LLM | December 2024 | None | | |
@@ -63,6 +63,7 @@ The latest version of the package contains the following experiments:
[11]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/chat/openai.py
[12]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/tools/tool_invoker.py
[13]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/extractors/llm_metadata_extractor.py
+[14]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/ollama/chat/chat_generator.py
## Usage
diff --git a/docs/pydoc/config/generators_api.yml b/docs/pydoc/config/generators_api.yml
index 1a0ec7c5..3ef80214 100644
--- a/docs/pydoc/config/generators_api.yml
+++ b/docs/pydoc/config/generators_api.yml
@@ -1,7 +1,8 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../]
- modules: ["haystack_experimental.components.generators.chat.openai"]
+ modules: ["haystack_experimental.components.generators.chat.openai",
+ "haystack_experimental.components.generators.ollama.chat.chat_generator"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
diff --git a/haystack_experimental/components/__init__.py b/haystack_experimental/components/__init__.py
index 82324492..a6323214 100644
--- a/haystack_experimental/components/__init__.py
+++ b/haystack_experimental/components/__init__.py
@@ -5,6 +5,7 @@
from .extractors import LLMMetadataExtractor
from .generators.chat import OpenAIChatGenerator
+from .generators.ollama.chat.chat_generator import OllamaChatGenerator
from .retrievers.auto_merging_retriever import AutoMergingRetriever
from .retrievers.chat_message_retriever import ChatMessageRetriever
from .splitters import HierarchicalDocumentSplitter
@@ -15,9 +16,10 @@
"AutoMergingRetriever",
"ChatMessageWriter",
"ChatMessageRetriever",
+ "OllamaChatGenerator",
"OpenAIChatGenerator",
"LLMMetadataExtractor",
"HierarchicalDocumentSplitter",
"OpenAIFunctionCaller",
- "ToolInvoker"
+ "ToolInvoker",
]
diff --git a/haystack_experimental/components/generators/ollama/__init__.py b/haystack_experimental/components/generators/ollama/__init__.py
new file mode 100644
index 00000000..e57d5847
--- /dev/null
+++ b/haystack_experimental/components/generators/ollama/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from .chat.chat_generator import OllamaChatGenerator
+
+__all__ = [
+ "OllamaChatGenerator",
+]
diff --git a/haystack_experimental/components/generators/ollama/chat/__init__.py b/haystack_experimental/components/generators/ollama/chat/__init__.py
new file mode 100644
index 00000000..85b66d68
--- /dev/null
+++ b/haystack_experimental/components/generators/ollama/chat/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from .chat_generator import OllamaChatGenerator
+
+__all__ = [
+ "OllamaChatGenerator",
+]
diff --git a/haystack_experimental/components/generators/ollama/chat/chat_generator.py b/haystack_experimental/components/generators/ollama/chat/chat_generator.py
new file mode 100644
index 00000000..967cd1ab
--- /dev/null
+++ b/haystack_experimental/components/generators/ollama/chat/chat_generator.py
@@ -0,0 +1,248 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Dict, List, Optional, Type
+
+from haystack import component, default_from_dict
+from haystack.dataclasses import StreamingChunk
+from haystack.lazy_imports import LazyImport
+from haystack.utils.callable_serialization import deserialize_callable
+
+from haystack_experimental.dataclasses import ChatMessage, ToolCall
+from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace
+
+with LazyImport("Run 'pip install ollama-haystack'") as ollama_integration_import:
+ # pylint: disable=import-error
+ from haystack_integrations.components.generators.ollama import OllamaChatGenerator as OllamaChatGeneratorBase
+
+
+# The following code block ensures that:
+# - we reuse existing code where possible
+# - people can use haystack-experimental without installing ollama-haystack.
+#
+#
+# If ollama-haystack is installed: all works correctly.
+#
+# If ollama-haystack is not installed:
+# - haystack-experimental package works fine (no import errors).
+# - OllamaChatGenerator fails with ImportError at init (due to ollama_integration_import.check()).
+
+if ollama_integration_import.is_successful():
+ chatgenerator_base_class: Type[OllamaChatGeneratorBase] = OllamaChatGeneratorBase
+else:
+ chatgenerator_base_class: Type[object] = object # type: ignore[no-redef]
+
+
+def _convert_message_to_ollama_format(message: ChatMessage) -> Dict[str, Any]:
+ """
+ Convert a message to the format expected by Ollama Chat API.
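+
+    For example, `ChatMessage.from_user("I have a question")` is converted to
+    `{"role": "user", "content": "I have a question"}`.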
+ """
+ text_contents = message.texts
+ tool_calls = message.tool_calls
+ tool_call_results = message.tool_call_results
+
+ if not text_contents and not tool_calls and not tool_call_results:
+ raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
+ elif len(text_contents) + len(tool_call_results) > 1:
+ raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")
+
+ ollama_msg: Dict[str, Any] = {"role": message._role.value}
+
+ if tool_call_results:
+ # Ollama does not provide a way to communicate errors in tool invocations, so we ignore the error field
+ ollama_msg["content"] = tool_call_results[0].result
+ return ollama_msg
+
+ if text_contents:
+ ollama_msg["content"] = text_contents[0]
+ if tool_calls:
+ # Ollama does not support tool call id, so we ignore it
+ ollama_msg["tool_calls"] = [
+ {"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}} for tc in tool_calls
+ ]
+ return ollama_msg
+
+
+@component()
+class OllamaChatGenerator(chatgenerator_base_class):
+ """
+    A chat generator for models served by a running Ollama instance, with support for tools (function calling).
+
+ Find the full list of supported models [here](https://ollama.ai/library).
+
+ Usage example:
+ ```python
+ from haystack_experimental.components.generators.ollama import OllamaChatGenerator
+ from haystack_experimental.dataclasses import ChatMessage
+
+    generator = OllamaChatGenerator(
+        model="zephyr",
+        url="http://localhost:11434",
+        generation_kwargs={"num_predict": 100, "temperature": 0.9},
+    )
+
+ messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
+ ChatMessage.from_user("What's Natural Language Processing?")]
+
+ print(generator.run(messages=messages))
+ ```
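+
+    Tool calling is also supported. A minimal sketch (the weather tool below is
+    illustrative; model support for tools varies, see https://ollama.com/search?c=tools):
+    ```python
+    from haystack_experimental.components.generators.ollama import OllamaChatGenerator
+    from haystack_experimental.dataclasses import ChatMessage, Tool
+
+    def get_weather(city: str) -> str:
+        return f"The weather in {city} is sunny"
+
+    weather_tool = Tool(
+        name="weather",
+        description="Get the current weather for a city",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=get_weather,
+    )
+
+    generator = OllamaChatGenerator(model="llama3.2:3b", tools=[weather_tool])
+    result = generator.run([ChatMessage.from_user("What is the weather in Paris?")])
+    print(result["replies"][0].tool_calls)
+    ```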
+ """
+
+ def __init__(
+ self,
+ model: str = "orca-mini",
+ url: str = "http://localhost:11434",
+ generation_kwargs: Optional[Dict[str, Any]] = None,
+ timeout: int = 120,
+ streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+ tools: Optional[List[Tool]] = None,
+ ):
+ """
+ Creates an instance of OllamaChatGenerator.
+
+ :param model:
+ The name of the model to use. The model should be available in the running Ollama instance.
+ :param url:
+ The URL of a running Ollama instance.
+ :param generation_kwargs:
+ Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+ top_p, and others. See the available arguments in
+ [Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+ :param timeout:
+ The number of seconds before throwing a timeout error from the Ollama API.
+ :param streaming_callback:
+ A callback function that is called when a new token is received from the stream.
+ The callback function accepts StreamingChunk as an argument.
+ :param tools:
+ A list of tools for which the model can prepare calls.
+ Not all models support tools. For a list of models compatible with tools, see the
+ [models page](https://ollama.com/search?c=tools).
+ """
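+        # check() raises an informative ImportError if ollama-haystack is not installed.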
+ ollama_integration_import.check()
+
+ if tools:
+ tool_names = [tool.name for tool in tools]
+ duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
+ if duplicate_tool_names:
+ raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
+ self.tools = tools
+
+ super(OllamaChatGenerator, self).__init__(
+ model=model,
+ url=url,
+ generation_kwargs=generation_kwargs,
+ timeout=timeout,
+ streaming_callback=streaming_callback,
+ )
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serialize this component to a dictionary.
+
+ :returns:
+ The serialized component as a dictionary.
+ """
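+        # The base class knows nothing about tools, so their serialized form is added to the init parameters here.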
+ serialized = super(OllamaChatGenerator, self).to_dict()
+ serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None
+ return serialized
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
+ """
+ Deserialize this component from a dictionary.
+
+ :param data: The dictionary representation of this component.
+ :returns:
+ The deserialized component instance.
+ """
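+        # Tools are serialized as dictionaries and the streaming callback as a dotted path,
+        # so both are rebuilt before delegating to default_from_dict.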
+ deserialize_tools_inplace(data["init_parameters"], key="tools")
+ init_params = data.get("init_parameters", {})
+ serialized_callback_handler = init_params.get("streaming_callback")
+ if serialized_callback_handler:
+ data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
+
+ return default_from_dict(cls, data)
+
+ def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
+ """
+ Converts the non-streaming response from the Ollama API to a ChatMessage.
+ """
+ ollama_message = ollama_response["message"]
+
+ text = ollama_message["content"]
+
+ tool_calls = []
+ if ollama_tool_calls := ollama_message.get("tool_calls"):
+ for ollama_tc in ollama_tool_calls:
+ tool_calls.append(
+ ToolCall(tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"])
+ )
+
+ message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
+
+ message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
+ return message
+
+ def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
+ """
+        Converts a list of streaming chunks into the response format expected by Haystack.
+ """
+
+ # Unaltered from the integration code. Overridden to use the experimental ChatMessage dataclass.
+
+ replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))]
+ meta = {key: value for key, value in chunks[0].meta.items() if key != "message"}
+
+ return {"replies": replies, "meta": [meta]}
+
+ @component.output_types(replies=List[ChatMessage])
+ def run(
+ self,
+ messages: List[ChatMessage],
+ generation_kwargs: Optional[Dict[str, Any]] = None,
+ tools: Optional[List[Tool]] = None,
+ ):
+ """
+ Runs an Ollama Model on a given chat history.
+
+ :param messages:
+ A list of ChatMessage instances representing the input messages.
+ :param generation_kwargs:
+ Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+ top_p, etc. See the
+ [Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+ :param tools:
+ A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
+ during component initialization.
+ :returns: A dictionary with the following keys:
+ - `replies`: The responses from the model
+ """
+ generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+ stream = self.streaming_callback is not None
+ tools = tools or self.tools
+
+ if stream and tools:
+ raise ValueError("Ollama does not support tools and streaming at the same time. Please choose one.")
+
+ ollama_tools = None
+ if tools:
+ tool_names = [tool.name for tool in tools]
+ duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
+ if duplicate_tool_names:
+ raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
+
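+            # Convert Haystack Tool definitions to the function-calling format expected by the Ollama client.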
+ ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
+
+ ollama_messages = [_convert_message_to_ollama_format(msg) for msg in messages]
+ response = self._client.chat(
+ model=self.model, messages=ollama_messages, tools=ollama_tools, stream=stream, options=generation_kwargs
+ )
+
+ if stream:
+ chunks: List[StreamingChunk] = self._handle_streaming_response(response)
+ return self._convert_to_streaming_response(chunks)
+
+ return {"replies": [self._build_message_from_ollama_response(response)]}
diff --git a/haystack_experimental/dataclasses/chat_message.py b/haystack_experimental/dataclasses/chat_message.py
index 5416f639..9f139af1 100644
--- a/haystack_experimental/dataclasses/chat_message.py
+++ b/haystack_experimental/dataclasses/chat_message.py
@@ -193,7 +193,7 @@ def from_assistant(
:returns: A new ChatMessage instance.
"""
content: List[ChatMessageContentT] = []
- if text:
+ if text is not None:
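+            # Explicit None check: empty-string text (e.g. in tool-call-only replies) should still yield a TextContent.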
content.append(TextContent(text=text))
if tool_calls:
content.extend(tool_calls)
diff --git a/pyproject.toml b/pyproject.toml
index 93c3b910..d2b7f892 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,8 +57,9 @@ extra-dependencies = [
"cohere-haystack",
"anthropic-haystack",
"fastapi",
- # Tool
+ # Tools support
"jsonschema",
+ "ollama-haystack>=1.0.0",
# LLMMetadataExtractor dependencies
"amazon-bedrock-haystack>=1.0.2",
"google-vertex-haystack>=2.0.0",
diff --git a/test/components/generators/ollama/test_chat_generator.py b/test/components/generators/ollama/test_chat_generator.py
new file mode 100644
index 00000000..ca74edac
--- /dev/null
+++ b/test/components/generators/ollama/test_chat_generator.py
@@ -0,0 +1,462 @@
+from unittest.mock import Mock, patch
+import sys
+import json
+import pytest
+
+from haystack.components.generators.utils import print_streaming_chunk
+from haystack.dataclasses import StreamingChunk
+from ollama._types import ResponseError
+
+from haystack_experimental.dataclasses import (
+ ChatMessage,
+ ChatRole,
+ TextContent,
+ ToolCall,
+ Tool,
+)
+from haystack_experimental.components.generators.ollama.chat.chat_generator import (
+ OllamaChatGenerator,
+ _convert_message_to_ollama_format,
+)
+
+
+@pytest.fixture
+def tools():
+ tool_parameters = {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ }
+ tool = Tool(
+ name="weather",
+ description="useful to determine the weather in a given location",
+ parameters=tool_parameters,
+ function=lambda x: x,
+ )
+
+ return [tool]
+
+
+def test_convert_message_to_ollama_format():
+ message = ChatMessage.from_system("You are good assistant")
+ assert _convert_message_to_ollama_format(message) == {
+ "role": "system",
+ "content": "You are good assistant",
+ }
+
+ message = ChatMessage.from_user("I have a question")
+ assert _convert_message_to_ollama_format(message) == {
+ "role": "user",
+ "content": "I have a question",
+ }
+
+ message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"})
+ assert _convert_message_to_ollama_format(message) == {
+ "role": "assistant",
+ "content": "I have an answer",
+ }
+
+ message = ChatMessage.from_assistant(
+ tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})]
+ )
+ assert _convert_message_to_ollama_format(message) == {
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "type": "function",
+ "function": {"name": "weather", "arguments": {"city": "Paris"}},
+ }
+ ],
+ }
+
+ tool_result = json.dumps({"weather": "sunny", "temperature": "25"})
+ message = ChatMessage.from_tool(
+ tool_result=tool_result,
+ origin=ToolCall(tool_name="weather", arguments={"city": "Paris"}),
+ )
+ assert _convert_message_to_ollama_format(message) == {
+ "role": "tool",
+ "content": tool_result,
+ }
+
+
+def test_convert_message_to_ollama_invalid():
+ message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[])
+ with pytest.raises(ValueError):
+ _convert_message_to_ollama_format(message)
+
+ message = ChatMessage(
+ _role=ChatRole.ASSISTANT,
+ _content=[
+ TextContent(text="I have an answer"),
+ TextContent(text="I have another answer"),
+ ],
+ )
+ with pytest.raises(ValueError):
+ _convert_message_to_ollama_format(message)
+
+
+class TestOllamaChatGenerator:
+ def test_init_default(self):
+ component = OllamaChatGenerator()
+ assert component.model == "orca-mini"
+ assert component.url == "http://localhost:11434"
+ assert component.generation_kwargs == {}
+ assert component.timeout == 120
+ assert component.streaming_callback is None
+ assert component.tools is None
+
+ def test_init(self, tools):
+ component = OllamaChatGenerator(
+ model="llama2",
+ url="http://my-custom-endpoint:11434",
+ generation_kwargs={"temperature": 0.5},
+ timeout=5,
+ streaming_callback=print_streaming_chunk,
+ tools=tools,
+ )
+
+ assert component.model == "llama2"
+ assert component.url == "http://my-custom-endpoint:11434"
+ assert component.generation_kwargs == {"temperature": 0.5}
+ assert component.timeout == 5
+ assert component.streaming_callback is print_streaming_chunk
+ assert component.tools == tools
+
+ def test_init_fail_with_duplicate_tool_names(self, tools):
+
+ duplicate_tools = [tools[0], tools[0]]
+ with pytest.raises(ValueError):
+ OllamaChatGenerator(tools=duplicate_tools)
+
+ def test_to_dict(self):
+ tool = Tool(
+ name="name",
+ description="description",
+ parameters={"x": {"type": "string"}},
+ function=print,
+ )
+
+ component = OllamaChatGenerator(
+ model="llama2",
+ streaming_callback=print_streaming_chunk,
+ url="custom_url",
+ generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+ tools=[tool],
+ )
+ data = component.to_dict()
+ assert data == {
+ "type": "haystack_experimental.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
+ "init_parameters": {
+ "timeout": 120,
+ "model": "llama2",
+ "url": "custom_url",
+ "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+ "generation_kwargs": {
+ "max_tokens": 10,
+ "some_test_param": "test-params",
+ },
+ "tools": [
+ {
+ "description": "description",
+ "function": "builtins.print",
+ "name": "name",
+ "parameters": {
+ "x": {
+ "type": "string",
+ },
+ },
+ },
+ ],
+ },
+ }
+
+ def test_from_dict(self):
+ tool = Tool(
+ name="name",
+ description="description",
+ parameters={"x": {"type": "string"}},
+ function=print,
+ )
+
+ data = {
+ "type": "haystack_experimental.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
+ "init_parameters": {
+ "timeout": 120,
+ "model": "llama2",
+ "url": "custom_url",
+ "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
+ "generation_kwargs": {
+ "max_tokens": 10,
+ "some_test_param": "test-params",
+ },
+ "tools": [
+ {
+ "description": "description",
+ "function": "builtins.print",
+ "name": "name",
+ "parameters": {
+ "x": {
+ "type": "string",
+ },
+ },
+ },
+ ],
+ },
+ }
+ component = OllamaChatGenerator.from_dict(data)
+ assert component.model == "llama2"
+ assert component.streaming_callback is print_streaming_chunk
+ assert component.url == "custom_url"
+ assert component.generation_kwargs == {
+ "max_tokens": 10,
+ "some_test_param": "test-params",
+ }
+ assert component.timeout == 120
+ assert component.tools == [tool]
+
+ def test_build_message_from_ollama_response(self):
+ model = "some_model"
+
+ ollama_response = {
+ "model": model,
+ "created_at": "2023-12-12T14:13:43.416799Z",
+ "message": {"role": "assistant", "content": "Hello! How are you today?"},
+ "done": True,
+ "total_duration": 5191566416,
+ "load_duration": 2154458,
+ "prompt_eval_count": 26,
+ "prompt_eval_duration": 383809000,
+ "eval_count": 298,
+ "eval_duration": 4799921000,
+ }
+
+ observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
+
+ assert observed.role == "assistant"
+ assert observed.text == "Hello! How are you today?"
+
+ def test_build_message_from_ollama_response_with_tools(self):
+ model = "some_model"
+
+ ollama_response = {
+ "model": model,
+ "created_at": "2023-12-12T14:13:43.416799Z",
+ "message": {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "function": {
+ "name": "get_current_weather",
+ "arguments": {"format": "celsius", "location": "Paris, FR"},
+ }
+ }
+ ],
+ },
+ "done": True,
+ "total_duration": 5191566416,
+ "load_duration": 2154458,
+ "prompt_eval_count": 26,
+ "prompt_eval_duration": 383809000,
+ "eval_count": 298,
+ "eval_duration": 4799921000,
+ }
+
+ observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response)
+
+ assert observed.role == "assistant"
+ assert observed.text == ""
+ assert observed.tool_call == ToolCall(
+ tool_name="get_current_weather",
+ arguments={"format": "celsius", "location": "Paris, FR"},
+ )
+
+ @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+ def test_run(self, mock_client):
+ generator = OllamaChatGenerator()
+
+ mock_response = {
+ "model": "llama3.2",
+ "created_at": "2023-12-12T14:13:43.416799Z",
+ "message": {
+ "role": "assistant",
+ "content": "Fine. How can I help you today?",
+ },
+ "done": True,
+ "total_duration": 5191566416,
+ "load_duration": 2154458,
+ "prompt_eval_count": 26,
+ "prompt_eval_duration": 383809000,
+ "eval_count": 298,
+ "eval_duration": 4799921000,
+ }
+
+ mock_client_instance = mock_client.return_value
+ mock_client_instance.chat.return_value = mock_response
+
+ result = generator.run(messages=[ChatMessage.from_user("Hello! How are you today?")])
+
+ mock_client_instance.chat.assert_called_once_with(
+ model="orca-mini",
+ messages=[{"role": "user", "content": "Hello! How are you today?"}],
+ stream=False,
+ tools=None,
+ options={},
+ )
+
+ assert "replies" in result
+ assert len(result["replies"]) == 1
+ assert result["replies"][0].text == "Fine. How can I help you today?"
+ assert result["replies"][0].role == "assistant"
+
+ @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+ def test_run_streaming(self, mock_client):
+ streaming_callback_called = False
+
+ def streaming_callback(chunk: StreamingChunk) -> None:
+ nonlocal streaming_callback_called
+ streaming_callback_called = True
+
+ generator = OllamaChatGenerator(streaming_callback=streaming_callback)
+
+ mock_response = iter(
+ [
+ {
+ "model": "llama3.2",
+ "created_at": "2023-12-12T14:13:43.416799Z",
+ "message": {"role": "assistant", "content": "first chunk "},
+ "done": False,
+ },
+ {
+ "model": "llama3.2",
+ "created_at": "2023-12-12T14:13:43.416799Z",
+ "message": {"role": "assistant", "content": "second chunk"},
+ "done": True,
+ "total_duration": 4883583458,
+ "load_duration": 1334875,
+ "prompt_eval_count": 26,
+ "prompt_eval_duration": 342546000,
+ "eval_count": 282,
+ "eval_duration": 4535599000,
+ },
+ ]
+ )
+
+ mock_client_instance = mock_client.return_value
+ mock_client_instance.chat.return_value = mock_response
+
+ result = generator.run(messages=[ChatMessage.from_user("irrelevant")])
+
+ assert streaming_callback_called
+
+ assert "replies" in result
+ assert len(result["replies"]) == 1
+ assert result["replies"][0].text == "first chunk second chunk"
+ assert result["replies"][0].role == "assistant"
+
+ def test_run_fail_with_tools_and_streaming(self, tools):
+ component = OllamaChatGenerator(tools=tools, streaming_callback=print_streaming_chunk)
+
+ with pytest.raises(ValueError):
+ message = ChatMessage.from_user("irrelevant")
+ component.run([message])
+
+ @pytest.mark.integration
+ @pytest.mark.skipif(
+ sys.platform != "linux",
+ reason="For simplicity, we only run the integration tests on Linux.",
+ )
+ def test_live_run(self):
+ chat_generator = OllamaChatGenerator(model="llama3.2:3b")
+
+ user_questions_and_assistant_answers = [
+ ("What's the capital of France?", "Paris"),
+ ("What is the capital of Canada?", "Ottawa"),
+ ("What is the capital of Ghana?", "Accra"),
+ ]
+
+ for question, answer in user_questions_and_assistant_answers:
+ message = ChatMessage.from_user(question)
+
+ response = chat_generator.run([message])
+
+ assert isinstance(response, dict)
+ assert isinstance(response["replies"], list)
+ assert answer in response["replies"][0].text
+
+ @pytest.mark.integration
+ @pytest.mark.skipif(
+ sys.platform != "linux",
+ reason="For simplicity, we only run the integration tests on Linux.",
+ )
+ def test_run_with_chat_history(self):
+ chat_generator = OllamaChatGenerator(model="llama3.2:3b")
+
+ chat_messages = [
+ ChatMessage.from_user("What is the largest city in the United Kingdom by population?"),
+ ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"),
+ ChatMessage.from_user("And what is the second largest?"),
+ ]
+
+ response = chat_generator.run(chat_messages)
+
+ assert isinstance(response, dict)
+ assert isinstance(response["replies"], list)
+
+ assert any(city in response["replies"][-1].text for city in ["Manchester", "Birmingham", "Glasgow"])
+
+ @pytest.mark.integration
+ @pytest.mark.skipif(
+ sys.platform != "linux",
+ reason="For simplicity, we only run the integration tests on Linux.",
+ )
+ def test_run_model_unavailable(self):
+ component = OllamaChatGenerator(model="unknown_model")
+
+ with pytest.raises(ResponseError):
+ message = ChatMessage.from_user("irrelevant")
+ component.run([message])
+
+ @pytest.mark.integration
+ @pytest.mark.skipif(
+ sys.platform != "linux",
+ reason="For simplicity, we only run the integration tests on Linux.",
+ )
+ def test_run_with_streaming(self):
+ streaming_callback = Mock()
+ chat_generator = OllamaChatGenerator(model="llama3.2:3b", streaming_callback=streaming_callback)
+
+ chat_messages = [
+ ChatMessage.from_user("What is the largest city in the United Kingdom by population?"),
+ ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"),
+ ChatMessage.from_user("And what is the second largest?"),
+ ]
+
+ response = chat_generator.run(chat_messages)
+
+ streaming_callback.assert_called()
+
+ assert isinstance(response, dict)
+ assert isinstance(response["replies"], list)
+ assert any(city in response["replies"][-1].text for city in ["Manchester", "Birmingham", "Glasgow"])
+
+ @pytest.mark.integration
+ @pytest.mark.skipif(
+ sys.platform != "linux",
+ reason="For simplicity, we only run the integration tests on Linux.",
+ )
+ def test_run_with_tools(self, tools):
+ chat_generator = OllamaChatGenerator(model="llama3.2:3b", tools=tools)
+
+ message = ChatMessage.from_user("What is the weather in Paris?")
+ response = chat_generator.run([message])
+
+ assert len(response["replies"]) == 1
+ message = response["replies"][0]
+
+ assert message.tool_calls
+ tool_call = message.tool_call
+ assert isinstance(tool_call, ToolCall)
+ assert tool_call.tool_name == "weather"
+ assert tool_call.arguments == {"city": "Paris"}