feat: support for tools in OllamaChatGenerator (#106)
* progress

* some progress

* try fixing tests

* different test model

* tools handling and tests

* refinements

* minor fixes

* fix

* more unit tests

* formatting

* incorporate feedback from review

* update README
anakin87 authored Oct 2, 2024
1 parent 3c834bb commit 5e78544
Showing 10 changed files with 762 additions and 7 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/tests.yml
@@ -35,6 +35,7 @@ env:
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
SERPERDEV_API_KEY: ${{ secrets.SERPERDEV_API_KEY }}
OLLAMA_LLM_FOR_TESTS: "llama3.2:3b"

jobs:
linting:
@@ -117,5 +118,26 @@ jobs:
- name: Install Hatch
run: pip install hatch==${{ env.HATCH_VERSION }}

- name: Install Ollama and pull the required models
if: matrix.os == 'ubuntu-latest'
run: |
curl -fsSL https://ollama.com/install.sh | sh
ollama serve &
# Check if the service is up and running with a timeout of 60 seconds
timeout=60
while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:11434/ > /dev/null; do
echo "Waiting for Ollama service to start..."
sleep 5
((timeout-=5))
done
if [ $timeout -eq 0 ]; then
echo "Timed out waiting for Ollama service to start."
exit 1
fi
ollama pull ${{ env.OLLAMA_LLM_FOR_TESTS }}
- name: Run
run: hatch run test:integration
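
To reproduce this setup outside CI, the same readiness check can be scripted; a minimal sketch in Python, assuming a local Ollama install and the third-party `requests` package (the model name mirrors the `OLLAMA_LLM_FOR_TESTS` value above):

```python
import subprocess
import time

import requests

OLLAMA_URL = "http://localhost:11434"
MODEL = "llama3.2:3b"  # same model the workflow pins via OLLAMA_LLM_FOR_TESTS

# Poll the Ollama root endpoint until it responds, mirroring the curl loop above.
deadline = time.time() + 60
while time.time() < deadline:
    try:
        requests.get(OLLAMA_URL, timeout=2).raise_for_status()
        break
    except requests.RequestException:
        print("Waiting for Ollama service to start...")
        time.sleep(5)
else:
    raise TimeoutError("Timed out waiting for Ollama service to start.")

# Pull the model used by the integration tests.
subprocess.run(["ollama", "pull", MODEL], check=True)
```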
7 changes: 4 additions & 3 deletions README.md
@@ -36,17 +36,17 @@ that includes it. Once it reaches the end of its lifespan, the experiment will b

The latest version of the package contains the following experiments:


=======
| Name | Type | Expected End Date | Dependencies | Cookbook | Discussion |
| --------------------------- | -------------------------- | ---------------------------- | ------------ | -------- | ---------- |
| [`EvaluationHarness`][1] | Evaluation orchestrator | October 2024 | None | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/rag_eval_harness.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/74) |
| [`OpenAIFunctionCaller`][2] | Function Calling Component | October 2024 | None | 🔜 | |
| [`OpenAPITool`][3] | OpenAPITool component | October 2024 | jsonref | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/openapitool.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/79)|
| Support for Tools: [refactored `ChatMessage` dataclass][10], [`Tool` dataclass][4], [refactored `OpenAIChatGenerator`][11], [`ToolInvoker` component][12] | Tool Calling support | November 2024 | jsonschema | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/tools_support.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/98)|
| Support for Tools: [refactored `ChatMessage` dataclass][10], [`Tool` dataclass][4], [refactored `OpenAIChatGenerator`][11], [refactored `OllamaChatGenerator`][14], [`ToolInvoker` component][12] | Tool Calling support | November 2024 | jsonschema | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/tools_support.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/98)|
| [`ChatMessageWriter`][5] | Memory Component | December 2024 | None | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/conversational_rag_using_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
| [`ChatMessageRetriever`][6] | Memory Component | December 2024 | None | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/conversational_rag_using_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
| [`InMemoryChatMessageStore`][7] | Memory Store | December 2024 | None | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/conversational_rag_using_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
| [`InMemoryChatMessageStore`][7] | Memory Store | December 2024 | None | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/conversational_rag_using_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/75) |
| [`Auto-Merging Retriever`][8] & [`HierarchicalDocumentSplitter`][9]| Document Splitting & Retrieval Technique | December 2024 | None | <a href="https://colab.research.google.com/github/deepset-ai/haystack-cookbook/blob/main/notebooks/auto_merging_retriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | [Discuss](https://github.com/deepset-ai/haystack-experimental/discussions/78) |
| [`LLMMetadataExtractor`][13] | Metadata extraction with LLM | December 2024 | None | | |

@@ -63,6 +63,7 @@ The latest version of the package contains the following experiments:
[11]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/chat/openai.py
[12]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/tools/tool_invoker.py
[13]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/extractors/llm_metadata_extractor.py
[14]: https://github.com/deepset-ai/haystack-experimental/blob/main/haystack_experimental/components/generators/ollama/chat/chat_generator.py

## Usage

3 changes: 2 additions & 1 deletion docs/pydoc/config/generators_api.yml
@@ -1,7 +1,8 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../]
modules: ["haystack_experimental.components.generators.chat.openai"]
modules: ["haystack_experimental.components.generators.chat.openai",
"haystack_experimental.components.generators.ollama.chat.chat_generator"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
4 changes: 3 additions & 1 deletion haystack_experimental/components/__init__.py
@@ -5,6 +5,7 @@

from .extractors import LLMMetadataExtractor
from .generators.chat import OpenAIChatGenerator
from .generators.ollama.chat.chat_generator import OllamaChatGenerator
from .retrievers.auto_merging_retriever import AutoMergingRetriever
from .retrievers.chat_message_retriever import ChatMessageRetriever
from .splitters import HierarchicalDocumentSplitter
@@ -15,9 +16,10 @@
"AutoMergingRetriever",
"ChatMessageWriter",
"ChatMessageRetriever",
"OllamaChatGenerator",
"OpenAIChatGenerator",
"LLMMetadataExtractor",
"HierarchicalDocumentSplitter",
"OpenAIFunctionCaller",
"ToolInvoker"
"ToolInvoker",
]
9 changes: 9 additions & 0 deletions haystack_experimental/components/generators/ollama/__init__.py
@@ -0,0 +1,9 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .chat.chat_generator import OllamaChatGenerator

__all__ = [
"OllamaChatGenerator",
]
9 changes: 9 additions & 0 deletions haystack_experimental/components/generators/ollama/chat/__init__.py
@@ -0,0 +1,9 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .chat_generator import OllamaChatGenerator

__all__ = [
"OllamaChatGenerator",
]
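
Together, the two new `__init__.py` files above and the change to `haystack_experimental/components/__init__.py` expose the class at several equivalent import paths:

```python
# After this commit, all of these resolve to the same experimental class.
from haystack_experimental.components import OllamaChatGenerator
from haystack_experimental.components.generators.ollama import OllamaChatGenerator
from haystack_experimental.components.generators.ollama.chat import OllamaChatGenerator
```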
248 changes: 248 additions & 0 deletions haystack_experimental/components/generators/ollama/chat/chat_generator.py
@@ -0,0 +1,248 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional, Type

from haystack import component, default_from_dict
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.callable_serialization import deserialize_callable

from haystack_experimental.dataclasses import ChatMessage, ToolCall
from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace

with LazyImport("Run 'pip install ollama-haystack'") as ollama_integration_import:
# pylint: disable=import-error
from haystack_integrations.components.generators.ollama import OllamaChatGenerator as OllamaChatGeneratorBase


# The following code block ensures that:
# - we reuse existing code where possible
# - people can use haystack-experimental without installing ollama-haystack.
#
#
# If ollama-haystack is installed: all works correctly.
#
# If ollama-haystack is not installed:
# - haystack-experimental package works fine (no import errors).
# - OllamaChatGenerator fails with ImportError at init (due to ollama_integration_import.check()).

if ollama_integration_import.is_successful():
chatgenerator_base_class: Type[OllamaChatGeneratorBase] = OllamaChatGeneratorBase
else:
chatgenerator_base_class: Type[object] = object # type: ignore[no-redef]


def _convert_message_to_ollama_format(message: ChatMessage) -> Dict[str, Any]:
"""
Convert a message to the format expected by the Ollama Chat API.
"""
text_contents = message.texts
tool_calls = message.tool_calls
tool_call_results = message.tool_call_results

if not text_contents and not tool_calls and not tool_call_results:
raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
elif len(text_contents) + len(tool_call_results) > 1:
raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")

ollama_msg: Dict[str, Any] = {"role": message._role.value}

if tool_call_results:
# Ollama does not provide a way to communicate errors in tool invocations, so we ignore the error field
ollama_msg["content"] = tool_call_results[0].result
return ollama_msg

if text_contents:
ollama_msg["content"] = text_contents[0]
if tool_calls:
# Ollama does not support tool call id, so we ignore it
ollama_msg["tool_calls"] = [
{"type": "function", "function": {"name": tc.tool_name, "arguments": tc.arguments}} for tc in tool_calls
]
return ollama_msg


@component()
class OllamaChatGenerator(chatgenerator_base_class):
"""
Supports models running on Ollama.
Find the full list of supported models [here](https://ollama.ai/library).
Usage example:
```python
from haystack_experimental.components.generators.ollama import OllamaChatGenerator
from haystack_experimental.dataclasses import ChatMessage
generator = OllamaChatGenerator(model="zephyr",
url = "http://localhost:11434",
generation_kwargs={
"num_predict": 100,
"temperature": 0.9,
})
messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")]
print(generator.run(messages=messages))
```
"""

def __init__(
self,
model: str = "orca-mini",
url: str = "http://localhost:11434",
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: int = 120,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Creates an instance of OllamaChatGenerator.
:param model:
The name of the model to use. The model should be available in the running Ollama instance.
:param url:
The URL of a running Ollama instance.
:param generation_kwargs:
Optional arguments to pass to the Ollama generation endpoint, such as temperature,
top_p, and others. See the available arguments in
[Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param timeout:
The number of seconds before throwing a timeout error from the Ollama API.
:param streaming_callback:
A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param tools:
A list of tools for which the model can prepare calls.
Not all models support tools. For a list of models compatible with tools, see the
[models page](https://ollama.com/search?c=tools).
"""
ollama_integration_import.check()

if tools:
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
self.tools = tools

super(OllamaChatGenerator, self).__init__(
model=model,
url=url,
generation_kwargs=generation_kwargs,
timeout=timeout,
streaming_callback=streaming_callback,
)

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:returns:
The serialized component as a dictionary.
"""
serialized = super(OllamaChatGenerator, self).to_dict()
serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None
return serialized

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:returns:
The deserialized component instance.
"""
deserialize_tools_inplace(data["init_parameters"], key="tools")
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

return default_from_dict(cls, data)

def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
"""
Converts the non-streaming response from the Ollama API to a ChatMessage.
"""
ollama_message = ollama_response["message"]

text = ollama_message["content"]

tool_calls = []
if ollama_tool_calls := ollama_message.get("tool_calls"):
for ollama_tc in ollama_tool_calls:
tool_calls.append(
ToolCall(tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"])
)

message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)

message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
return message

def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
"""
Converts a list of streaming chunks into the response format required by Haystack.
"""

# Unaltered from the integration code. Overridden to use the experimental ChatMessage dataclass.

replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))]
meta = {key: value for key, value in chunks[0].meta.items() if key != "message"}

return {"replies": replies, "meta": [meta]}

@component.output_types(replies=List[ChatMessage])
def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Runs an Ollama model on the given chat history.
:param messages:
A list of ChatMessage instances representing the input messages.
:param generation_kwargs:
Optional arguments to pass to the Ollama generation endpoint, such as temperature,
top_p, etc. See the
[Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:returns: A dictionary with the following keys:
- `replies`: The responses from the model
"""
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

stream = self.streaming_callback is not None
tools = tools or self.tools

if stream and tools:
raise ValueError("Ollama does not support tools and streaming at the same time. Please choose one.")

ollama_tools = None
if tools:
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")

ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

ollama_messages = [_convert_message_to_ollama_format(msg) for msg in messages]
response = self._client.chat(
model=self.model, messages=ollama_messages, tools=ollama_tools, stream=stream, options=generation_kwargs
)

if stream:
chunks: List[StreamingChunk] = self._handle_streaming_response(response)
return self._convert_to_streaming_response(chunks)

return {"replies": [self._build_message_from_ollama_response(response)]}
2 changes: 1 addition & 1 deletion haystack_experimental/dataclasses/chat_message.py
@@ -193,7 +193,7 @@ def from_assistant(
:returns: A new ChatMessage instance.
"""
content: List[ChatMessageContentT] = []
if text:
if text is not None:
content.append(TextContent(text=text))
if tool_calls:
content.extend(tool_calls)
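This one-character change means `from_assistant` now keeps an empty string as a `TextContent` part and only skips the text when it is explicitly `None`, which is what tool-call-only replies rely on; a short sketch, assuming the experimental dataclasses used above:

```python
from haystack_experimental.dataclasses import ChatMessage, ToolCall

# A tool-call-only reply: text=None must not create a TextContent part.
reply = ChatMessage.from_assistant(
    text=None,
    tool_calls=[ToolCall(tool_name="get_weather", arguments={"city": "Paris"})],
)
print(reply.texts)       # [] -- no text part
print(reply.tool_calls)  # the single ToolCall

# An empty string is still real text: `if text:` dropped it,
# `if text is not None:` preserves it.
print(ChatMessage.from_assistant(text="").texts)  # ['']
```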
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -57,8 +57,9 @@ extra-dependencies = [
"cohere-haystack",
"anthropic-haystack",
"fastapi",
# Tool
# Tools support
"jsonschema",
"ollama-haystack>=1.0.0",
# LLMMetadataExtractor dependencies
"amazon-bedrock-haystack>=1.0.2",
"google-vertex-haystack>=2.0.0",