From 9ae65a01ada61eb0f603f78f077e11dccc22fc30 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 26 Sep 2024 18:30:56 +0200 Subject: [PATCH] progress --- .github/workflows/tests.yml | 22 +++ haystack_experimental/components/__init__.py | 3 +- .../components/generators/ollama/__init__.py | 9 + .../generators/ollama/chat/__init__.py | 9 + .../generators/ollama/chat/chat_generator.py | 183 ++++++++++++++++++ pyproject.toml | 3 +- .../generators/ollama/test_chat_generator.py | 170 ++++++++++++++++ 7 files changed, 396 insertions(+), 3 deletions(-) create mode 100644 haystack_experimental/components/generators/ollama/__init__.py create mode 100644 haystack_experimental/components/generators/ollama/chat/__init__.py create mode 100644 haystack_experimental/components/generators/ollama/chat/chat_generator.py create mode 100644 test/components/generators/ollama/test_chat_generator.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ba960d2c..a688bb41 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,7 @@ env: COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }} SERPERDEV_API_KEY: ${{ secrets.SERPERDEV_API_KEY }} + OLLAMA_LLM_FOR_TESTS: "llama3.2:1b" jobs: linting: @@ -117,5 +118,26 @@ jobs: - name: Install Hatch run: pip install hatch==${{ env.HATCH_VERSION }} + - name: Install Ollama and pull the required models + if: matrix.os == 'ubuntu-latest' + run: | + curl -fsSL https://ollama.com/install.sh | sh + ollama serve & + + # Check if the service is up and running with a timeout of 60 seconds + timeout=60 + while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:11434/ > /dev/null; do + echo "Waiting for Ollama service to start..." + sleep 5 + ((timeout-=5)) + done + + if [ $timeout -eq 0 ]; then + echo "Timed out waiting for Ollama service to start." 
+ exit 1 + fi + + ollama pull ${{ env.OLLAMA_LLM_FOR_TESTS }} + - name: Run run: hatch run test:integration diff --git a/haystack_experimental/components/__init__.py b/haystack_experimental/components/__init__.py index 75afcfb1..b3065f22 100644 --- a/haystack_experimental/components/__init__.py +++ b/haystack_experimental/components/__init__.py @@ -9,7 +9,6 @@ from .tools import OpenAIFunctionCaller, ToolInvoker from .writers import ChatMessageWriter - _all_ = [ "AutoMergingRetriever", "ChatMessageWriter", @@ -17,5 +16,5 @@ "OpenAIChatGenerator", "HierarchicalDocumentSplitter", "OpenAIFunctionCaller", - "ToolInvoker" + "ToolInvoker", ] diff --git a/haystack_experimental/components/generators/ollama/__init__.py b/haystack_experimental/components/generators/ollama/__init__.py new file mode 100644 index 00000000..e57d5847 --- /dev/null +++ b/haystack_experimental/components/generators/ollama/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .chat.chat_generator import OllamaChatGenerator + +__all__ = [ + "OllamaChatGenerator", +] diff --git a/haystack_experimental/components/generators/ollama/chat/__init__.py b/haystack_experimental/components/generators/ollama/chat/__init__.py new file mode 100644 index 00000000..85b66d68 --- /dev/null +++ b/haystack_experimental/components/generators/ollama/chat/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .chat_generator import OllamaChatGenerator + +__all__ = [ + "OllamaChatGenerator", +] diff --git a/haystack_experimental/components/generators/ollama/chat/chat_generator.py b/haystack_experimental/components/generators/ollama/chat/chat_generator.py new file mode 100644 index 00000000..43043b00 --- /dev/null +++ b/haystack_experimental/components/generators/ollama/chat/chat_generator.py @@ -0,0 +1,183 @@ +# from haystack.dataclasses import ChatMessage, StreamingChunk +from typing import Any, Callable, Dict, List, Optional, Type, Union + +from haystack import component, default_from_dict +from haystack.dataclasses import StreamingChunk +from haystack.lazy_imports import LazyImport +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable + +from haystack_experimental.dataclasses import ChatMessage, ToolCall +from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace + +with LazyImport("Run 'pip install ollama-haystack'") as ollama_integration_import: + # pylint: disable=import-error + from haystack_integrations.components.generators.ollama import OllamaChatGenerator as OllamaChatGeneratorBase + +base_class: Union[Type[object], Type["OllamaChatGeneratorBase"]] = object +if ollama_integration_import.is_successful(): + base_class = OllamaChatGeneratorBase + +print(base_class) + + +@component() +class OllamaChatGenerator(base_class): + """ + Supports models running on Ollama. + + Find the full list of supported models [here](https://ollama.ai/library). 
+ + Usage example: + ```python + from haystack_experimental.components.generators.ollama import OllamaChatGenerator + from haystack_experimental.dataclasses import ChatMessage + + generator = OllamaChatGenerator(model="zephyr", + url = "http://localhost:11434", + generation_kwargs={ + "num_predict": 100, + "temperature": 0.9, + }) + + messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"), + ChatMessage.from_user("What's Natural Language Processing?")] + + print(generator.run(messages=messages)) + ``` + """ + + def __init__( + self, + model: str = "orca-mini", + url: str = "http://localhost:11434", + generation_kwargs: Optional[Dict[str, Any]] = None, + timeout: int = 120, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, + ): + """ + Creates an instance of OllamaChatGenerator. + + :param model: + The name of the model to use. The model should be available in the running Ollama instance. + :param url: + The URL of a running Ollama instance. + :param generation_kwargs: + Optional arguments to pass to the Ollama generation endpoint, such as temperature, + top_p, and others. See the available arguments in + [Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). + :param timeout: + The number of seconds before throwing a timeout error from the Ollama API. + :param streaming_callback: + A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param tools: + A list of tools for which the model can prepare calls. + Not all models support tools. For a list of models compatible with tools, see the + [models page](https://ollama.com/search?c=tools). + """ + ollama_integration_import.check() + + if tools: + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + self.tools = tools + + super(OllamaChatGenerator, self).__init__( + model=model, + url=url, + generation_kwargs=generation_kwargs, + timeout=timeout, + streaming_callback=streaming_callback, + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + serialized = super(OllamaChatGenerator, self).to_dict() + serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None + return serialized + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance. 
+ """ + deserialize_tools_inplace(data["init_parameters"], key="tools") + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + + return default_from_dict(cls, data) + + # TODO: rework + def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: + return {"role": message.role.value, "content": message.text or ""} + + # TODO: rework + def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage: + """ + Converts the non-streaming response from the Ollama API to a ChatMessage. + """ + message = ChatMessage.from_assistant(text=ollama_response["message"]["content"]) + message.meta.update({key: value for key, value in ollama_response.items() if key != "message"}) + return message + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): + """ + Runs an Ollama Model on a given chat history. + + :param messages: + A list of ChatMessage instances representing the input messages. + :param generation_kwargs: + Optional arguments to pass to the Ollama generation endpoint, such as temperature, + top_p, etc. See the + [Ollama docs](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. + :returns: A dictionary with the following keys: + - `replies`: The responses from the model + """ + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + stream = self.streaming_callback is not None + + tools = tools or self.tools + if tools: + if stream: + raise ValueError("Ollama does not support tools and streaming at the same time. 
Please choose one.") + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + # ollama_tools = None + # if tools: + # ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] + + messages = [self._message_to_dict(message) for message in messages] + response = self._client.chat(model=self.model, messages=messages, stream=stream, options=generation_kwargs) + + if stream: + chunks: List[StreamingChunk] = self._handle_streaming_response(response) + return self._convert_to_streaming_response(chunks) + + return {"replies": [self._build_message_from_ollama_response(response)]} diff --git a/pyproject.toml b/pyproject.toml index fb97b136..6cb9faf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,9 @@ extra-dependencies = [ "cohere-haystack", "anthropic-haystack", "fastapi", - # Tool + # Tools support "jsonschema", + "ollama-haystack>=1.0.0", ] [tool.hatch.envs.test.scripts] diff --git a/test/components/generators/ollama/test_chat_generator.py b/test/components/generators/ollama/test_chat_generator.py new file mode 100644 index 00000000..5a4d85ee --- /dev/null +++ b/test/components/generators/ollama/test_chat_generator.py @@ -0,0 +1,170 @@ +from typing import List +from unittest.mock import Mock + +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack_experimental.dataclasses import ChatMessage, ChatRole, TextContent +from ollama._types import ResponseError + +from haystack_experimental.components.generators.ollama.chat.chat_generator import OllamaChatGenerator + +@pytest.fixture +def chat_messages() -> List[ChatMessage]: + return [ + ChatMessage.from_user("Tell me about why Super Mario is the greatest superhero"), + ChatMessage.from_assistant( + text="Super Mario has prevented Bowser from destroying the world", meta={"something": "something"} + ), + ] + + +class TestOllamaChatGenerator: + def test_init_default(self): + component = OllamaChatGenerator() + assert component.model == "orca-mini" + assert component.url == "http://localhost:11434" + assert component.generation_kwargs == {} + assert component.timeout == 120 + + def test_init(self): + component = OllamaChatGenerator( + model="llama2", + url="http://my-custom-endpoint:11434", + generation_kwargs={"temperature": 0.5}, + timeout=5, + ) + + assert component.model == "llama2" + assert component.url == "http://my-custom-endpoint:11434" + assert component.generation_kwargs == {"temperature": 0.5} + assert component.timeout == 5 + + def test_to_dict(self): + component = OllamaChatGenerator( + model="llama2", + streaming_callback=print_streaming_chunk, + url="custom_url", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", + "init_parameters": { + "timeout": 120, + "model": "llama2", + "url": "custom_url", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + + def test_from_dict(self): + data = { + "type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator", + "init_parameters": { + "timeout": 120, + "model": "llama2", + "url": "custom_url", + 
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + component = OllamaChatGenerator.from_dict(data) + assert component.model == "llama2" + assert component.streaming_callback is print_streaming_chunk + assert component.url == "custom_url" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_build_message_from_ollama_response(self): + model = "some_model" + + ollama_response = { + "model": model, + "created_at": "2023-12-12T14:13:43.416799Z", + "message": {"role": "assistant", "content": "Hello! How are you today?"}, + "done": True, + "total_duration": 5191566416, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000, + } + + observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response) + + assert observed.role == "assistant" + assert observed.content == "Hello! How are you today?" + + @pytest.mark.integration + def test_run(self): + chat_generator = OllamaChatGenerator(model="llama3.2:1b") + + user_questions_and_assistant_answers = [ + ("What's the capital of France?", "Paris"), + ("What is the capital of Canada?", "Ottawa"), + ("What is the capital of Ghana?", "Accra"), + ] + + for question, answer in user_questions_and_assistant_answers: + message = ChatMessage.from_user(question) + + response = chat_generator.run([message]) + + assert isinstance(response, dict) + assert isinstance(response["replies"], list) + assert answer in response["replies"][0].text + + @pytest.mark.integration + def test_run_with_chat_history(self): + chat_generator = OllamaChatGenerator(model="llama3.2:1b") + + chat_history = [ + {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, + {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, + {"role": "user", "content": "And what is the second largest?"}, + ] + + chat_messages = [ + ChatMessage(_role=ChatRole(message["role"]), _content=TextContent(message["content"])) + for message in chat_history + ] + response = chat_generator.run(chat_messages) + + assert isinstance(response, dict) + assert isinstance(response["replies"], list) + assert "Manchester" in response["replies"][-1].text or "Glasgow" in response["replies"][-1].text + + @pytest.mark.integration + def test_run_model_unavailable(self): + component = OllamaChatGenerator(model="Alistair_and_Stefano_are_great") + + with pytest.raises(ResponseError): + message = ChatMessage.from_user( + "Based on your infinite wisdom, can you tell me why Alistair and Stefano are so great?" 
+ ) + component.run([message]) + + @pytest.mark.integration + def test_run_with_streaming(self): + streaming_callback = Mock() + chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback) + + chat_history = [ + {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, + {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, + {"role": "user", "content": "And what is the second largest?"}, + ] + + chat_messages = [ + ChatMessage(_role=ChatRole(message["role"]), _content=TextContent(message["content"])) + for message in chat_history + ] + + response = chat_generator.run(chat_messages) + + streaming_callback.assert_called() + + assert isinstance(response, dict) + assert isinstance(response["replies"], list) + assert "Manchester" in response["replies"][-1].text or "Glasgow" in response["replies"][-1].text