feat: support for tools in OllamaChatGenerator #106

Merged: 15 commits, Oct 2, 2024
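
This PR adds tool-calling support to the experimental `OllamaChatGenerator`. Below is a minimal usage sketch, not taken from the PR itself: the weather function, its JSON schema, and the model name are illustrative, and the `Tool(name, description, parameters, function)` fields are assumed to match the experimental `Tool` dataclass that the new component imports.

```python
from haystack_experimental.components.generators.ollama import OllamaChatGenerator
from haystack_experimental.dataclasses import ChatMessage
from haystack_experimental.dataclasses.tool import Tool


def get_current_weather(city: str) -> str:
    # Dummy tool implementation, for illustration only.
    return f"The weather in {city} is sunny."


weather_tool = Tool(
    name="get_current_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=get_current_weather,
)

# llama3.2:3b matches the model pulled in CI below; any tool-capable Ollama model works.
generator = OllamaChatGenerator(model="llama3.2:3b", tools=[weather_tool])
result = generator.run(messages=[ChatMessage.from_user("What is the weather in Rome?")])

# For a tool-capable model, replies[0].tool_calls should contain a ToolCall
# with tool_name="get_current_weather" and arguments like {"city": "Rome"}.
print(result["replies"][0].tool_calls)
```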
22 changes: 22 additions & 0 deletions .github/workflows/tests.yml
@@ -35,6 +35,7 @@ env:
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
SERPERDEV_API_KEY: ${{ secrets.SERPERDEV_API_KEY }}
OLLAMA_LLM_FOR_TESTS: "llama3.2:3b"

jobs:
linting:
@@ -117,5 +118,26 @@ jobs:
- name: Install Hatch
run: pip install hatch==${{ env.HATCH_VERSION }}

- name: Install Ollama and pull the required models
if: matrix.os == 'ubuntu-latest'
run: |
curl -fsSL https://ollama.com/install.sh | sh
ollama serve &

# Check if the service is up and running with a timeout of 60 seconds
timeout=60
while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:11434/ > /dev/null; do
echo "Waiting for Ollama service to start..."
sleep 5
((timeout-=5))
done

if [ $timeout -eq 0 ]; then
echo "Timed out waiting for Ollama service to start."
exit 1
fi

ollama pull ${{ env.OLLAMA_LLM_FOR_TESTS }}

- name: Run
run: hatch run test:integration
3 changes: 2 additions & 1 deletion docs/pydoc/config/generators_api.yml
@@ -1,7 +1,8 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../]
modules: ["haystack_experimental.components.generators.chat.openai"]
modules: ["haystack_experimental.components.generators.chat.openai",
"haystack_experimental.components.generators.ollama.chat.chat_generator"]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
4 changes: 3 additions & 1 deletion haystack_experimental/components/__init__.py
@@ -5,6 +5,7 @@

from .extractors import LLMMetadataExtractor
from .generators.chat import OpenAIChatGenerator
from .generators.ollama.chat.chat_generator import OllamaChatGenerator
from .retrievers.auto_merging_retriever import AutoMergingRetriever
from .retrievers.chat_message_retriever import ChatMessageRetriever
from .splitters import HierarchicalDocumentSplitter
@@ -15,9 +16,10 @@
"AutoMergingRetriever",
"ChatMessageWriter",
"ChatMessageRetriever",
"OllamaChatGenerator",
"OpenAIChatGenerator",
"LLMMetadataExtractor",
"HierarchicalDocumentSplitter",
"OpenAIFunctionCaller",
"ToolInvoker"
"ToolInvoker",
]
@@ -0,0 +1,9 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .chat.chat_generator import OllamaChatGenerator

__all__ = [
"OllamaChatGenerator",
]
@@ -0,0 +1,9 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .chat_generator import OllamaChatGenerator

__all__ = [
"OllamaChatGenerator",
]
@@ -0,0 +1,244 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional, Type

from haystack import component, default_from_dict
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.callable_serialization import deserialize_callable

from haystack_experimental.dataclasses import ChatMessage, ToolCall
from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace

with LazyImport("Run 'pip install ollama-haystack'") as ollama_integration_import:
# pylint: disable=import-error
from haystack_integrations.components.generators.ollama import OllamaChatGenerator as OllamaChatGeneratorBase


if ollama_integration_import.is_successful():
chatgenerator_base_class: Type[OllamaChatGeneratorBase] = OllamaChatGeneratorBase
else:
chatgenerator_base_class: Type[object] = object # type: ignore[no-redef]


def _convert_message_to_ollama_format(message: ChatMessage) -> Dict[str, Any]:
"""
Convert a message to the format expected by Ollama Chat API.
"""
text_contents = message.texts
tool_calls = message.tool_calls
tool_call_results = message.tool_call_results

if not text_contents and not tool_calls and not tool_call_results:
raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.")
elif len(text_contents) + len(tool_call_results) > 1:
raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")

ollama_msg: Dict[str, Any] = {"role": message._role.value}

if tool_call_results:
result = tool_call_results[0]
ollama_msg["content"] = result.result
# Ollama does not provide a way to communicate errors in tool invocations, so we ignore the error field
return ollama_msg

if text_contents:
ollama_msg["content"] = text_contents[0]
if tool_calls:
ollama_tool_calls = []
for tc in tool_calls:
# Ollama does not support tool call id, so we ignore it
ollama_tool_calls.append(
{
"type": "function",
"function": {"name": tc.tool_name, "arguments": tc.arguments},
}
)
ollama_msg["tool_calls"] = ollama_tool_calls
return ollama_msg


@component()
class OllamaChatGenerator(chatgenerator_base_class):
"""
Supports models running on Ollama.

Find the full list of supported models [here](https://ollama.ai/library).

Usage example:
```python
from haystack_experimental.components.generators.ollama import OllamaChatGenerator
from haystack_experimental.dataclasses import ChatMessage

generator = OllamaChatGenerator(model="zephyr",
url = "http://localhost:11434",
generation_kwargs={
"num_predict": 100,
"temperature": 0.9,
})

messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")]

print(generator.run(messages=messages))
```
"""

def __init__(
self,
model: str = "orca-mini",
url: str = "http://localhost:11434",
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: int = 120,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Creates an instance of OllamaChatGenerator.

:param model:
The name of the model to use. The model should be available in the running Ollama instance.
:param url:
The URL of a running Ollama instance.
:param generation_kwargs:
Optional arguments to pass to the Ollama generation endpoint, such as temperature,
top_p, and others. See the available arguments in
[Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param timeout:
The number of seconds before throwing a timeout error from the Ollama API.
:param streaming_callback:
A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param tools:
A list of tools for which the model can prepare calls.
Not all models support tools. For a list of models compatible with tools, see the
[models page](https://ollama.com/search?c=tools).
"""
ollama_integration_import.check()

if tools:
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")
self.tools = tools

super(OllamaChatGenerator, self).__init__(
model=model,
url=url,
generation_kwargs=generation_kwargs,
timeout=timeout,
streaming_callback=streaming_callback,
)

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.

:returns:
The serialized component as a dictionary.
"""
serialized = super(OllamaChatGenerator, self).to_dict()
serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None
return serialized

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator":
"""
Deserialize this component from a dictionary.

:param data: The dictionary representation of this component.
:returns:
The deserialized component instance.
"""
deserialize_tools_inplace(data["init_parameters"], key="tools")
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

return default_from_dict(cls, data)

def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage:
"""
Converts the non-streaming response from the Ollama API to a ChatMessage.
"""
ollama_message = ollama_response["message"]

text = ollama_message["content"]

tool_calls = []
if ollama_tool_calls := ollama_message.get("tool_calls"):
for ollama_tc in ollama_tool_calls:
tool_calls.append(
ToolCall(tool_name=ollama_tc["function"]["name"], arguments=ollama_tc["function"]["arguments"])
)

message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)

message.meta.update({key: value for key, value in ollama_response.items() if key != "message"})
return message

def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]:
"""
Converts a list of streaming chunks into the response format expected by Haystack.
"""

# Unaltered from the integration code. Overridden to use the experimental ChatMessage dataclass.

replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))]
meta = {key: value for key, value in chunks[0].meta.items() if key != "message"}

return {"replies": replies, "meta": [meta]}

@component.output_types(replies=List[ChatMessage])
def run(
self,
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Runs an Ollama Model on a given chat history.

:param messages:
A list of ChatMessage instances representing the input messages.
:param generation_kwargs:
Optional arguments to pass to the Ollama generation endpoint, such as temperature,
top_p, etc. See the
[Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
during component initialization.
:returns: A dictionary with the following keys:
- `replies`: The responses from the model
"""
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

stream = self.streaming_callback is not None
tools = tools or self.tools

if stream and tools:
raise ValueError("Ollama does not support tools and streaming at the same time. Please choose one.")

ollama_tools = None
if tools:
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}")

ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

ollama_messages = [_convert_message_to_ollama_format(msg) for msg in messages]
response = self._client.chat(
model=self.model, messages=ollama_messages, tools=ollama_tools, stream=stream, options=generation_kwargs
)

if stream:
chunks: List[StreamingChunk] = self._handle_streaming_response(response)
return self._convert_to_streaming_response(chunks)

return {"replies": [self._build_message_from_ollama_response(response)]}
2 changes: 1 addition & 1 deletion haystack_experimental/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def from_assistant(
:returns: A new ChatMessage instance.
"""
content: List[ChatMessageContentT] = []
if text:
if text is not None:
content.append(TextContent(text=text))
if tool_calls:
content.extend(tool_calls)
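
This one-line change preserves an explicit empty string as a `TextContent` (Ollama typically returns `content: ""` alongside tool calls), while `text=None` still adds no text part. A small illustration of the difference, assuming `from_assistant` defaults `text` to `None`; not part of the PR:

```python
from haystack_experimental.dataclasses import ChatMessage, ToolCall

call = ToolCall(tool_name="get_current_weather", arguments={"city": "Rome"})

with_empty_text = ChatMessage.from_assistant(text="", tool_calls=[call])
without_text = ChatMessage.from_assistant(text=None, tool_calls=[call])

print(with_empty_text.texts)  # [''] -- the empty string is kept as a TextContent
print(without_text.texts)     # []   -- no TextContent is added
```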
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ extra-dependencies = [
"cohere-haystack",
"anthropic-haystack",
"fastapi",
# Tool
# Tools support
"jsonschema",
"ollama-haystack>=1.0.0",
# LLMMetadataExtractor dependencies
"amazon-bedrock-haystack>=1.0.2",
"google-vertex-haystack>=2.0.0",