diff --git a/integrations/ollama/example/chat_generator_example.py b/integrations/ollama/example/chat_generator_example.py
new file mode 100644
index 000000000..1e941214a
--- /dev/null
+++ b/integrations/ollama/example/chat_generator_example.py
@@ -0,0 +1,49 @@
+# In order to run this example, you will need to have an instance of Ollama running with the
+# orca-mini model downloaded. We suggest you use the following commands to serve an orca-mini
+# model from Ollama
+#
+# docker run -d -p 11434:11434 --name ollama ollama/ollama:latest
+# docker exec ollama ollama pull orca-mini
+
+from haystack.dataclasses import ChatMessage
+
+from ollama_haystack import OllamaChatGenerator
+
+messages = [
+    ChatMessage.from_user("What's Natural Language Processing?"),
+    ChatMessage.from_system(
+        "Natural Language Processing (NLP) is a field of computer science and artificial "
+        "intelligence concerned with the interaction between computers and human language"
+    ),
+    ChatMessage.from_user("How do I get started?"),
+]
+client = OllamaChatGenerator(model="orca-mini", timeout=45, url="http://localhost:11434/api/chat")
+
+response = client.run(messages, generation_kwargs={"temperature": 0.2})
+
+print(response["replies"])
+#
+# [
+#     ChatMessage(
+#         content="Natural Language Processing (NLP) is a broad field of computer science and artificial intelligence "
+#         "that involves the interaction between computers and human language. To get started in NLP, "
+#         "you can start by learning about the different techniques and tools used in NLP such as machine "
+#         "learning algorithms, deep learning frameworks, and natural language processing libraries. You can "
+#         "also learn about the applications of NLP in various fields such as chatbots, sentiment analysis, "
+#         "speech recognition, and text classification. Additionally, you can explore the available resources "
+#         "such as online courses, tutorials, and books on NLP to gain a deeper understanding of the field.",
+#         role=<ChatRole.ASSISTANT: 'assistant'>,
+#         name=None,
+#         meta={
+#             "model": "orca-mini",
+#             "created_at": "2024-01-08T15:35:23.378609793Z",
+#             "done": True,
+#             "total_duration": 20026330217,
+#             "load_duration": 1540167,
+#             "prompt_eval_count": 99,
+#             "prompt_eval_duration": 8486609000,
+#             "eval_count": 124,
+#             "eval_duration": 11532988000,
+#         },
+#     )
+# ]
diff --git a/integrations/ollama/example/example.py b/integrations/ollama/example/generator_example.py
similarity index 100%
rename from integrations/ollama/example/example.py
rename to integrations/ollama/example/generator_example.py
diff --git a/integrations/ollama/src/ollama_haystack/__init__.py b/integrations/ollama/src/ollama_haystack/__init__.py
index 8bbb69641..10bc38121 100644
--- a/integrations/ollama/src/ollama_haystack/__init__.py
+++ b/integrations/ollama/src/ollama_haystack/__init__.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+from ollama_haystack.chat.chat_generator import OllamaChatGenerator
 from ollama_haystack.generator import OllamaGenerator
 
-__all__ = ["OllamaGenerator"]
+__all__ = ["OllamaGenerator", "OllamaChatGenerator"]
diff --git a/integrations/ollama/src/ollama_haystack/chat/__init__.py b/integrations/ollama/src/ollama_haystack/chat/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/integrations/ollama/src/ollama_haystack/chat/chat_generator.py b/integrations/ollama/src/ollama_haystack/chat/chat_generator.py
new file mode 100644
index 000000000..6a8c5493b
--- /dev/null
+++ b/integrations/ollama/src/ollama_haystack/chat/chat_generator.py
@@ -0,0 +1,96 @@
+from typing import Any, Dict, List, Optional
+
+import requests
+from haystack import component
+from haystack.dataclasses import ChatMessage
+from requests import Response
+
+
+@component
+class OllamaChatGenerator:
+    """
+    Chat Generator based on Ollama. Ollama is a library for easily running LLMs locally.
+    This component provides an interface to generate text using an LLM running in Ollama.
+    """
+
+    def __init__(
+        self,
+        model: str = "orca-mini",
+        url: str = "http://localhost:11434/api/chat",
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        template: Optional[str] = None,
+        timeout: int = 120,
+    ):
+        """
+        :param model: The name of the model to use. The model should be available in the running Ollama instance.
+            Default is "orca-mini".
+        :param url: The URL of the chat endpoint of a running Ollama instance.
+            Default is "http://localhost:11434/api/chat".
+        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+            top_p, and others. See the available arguments in
+            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+        :param template: The full prompt template (overrides what is defined in the Ollama Modelfile).
+        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
+            Default is 120 seconds.
+        """
+
+        self.timeout = timeout
+        self.template = template
+        self.generation_kwargs = generation_kwargs or {}
+        self.url = url
+        self.model = model
+
+    def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
+        return {"role": message.role.value, "content": message.content}
+
+    def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=None) -> Dict[str, Any]:
+        """
+        Builds the JSON arguments for a POST request to an Ollama service.
+        :param messages: A history/list of chat messages.
+        :param generation_kwargs: Optional model parameters (such as temperature) passed to Ollama as "options".
+        :return: A dictionary of arguments for a POST request to an Ollama service.
+        """
+        generation_kwargs = generation_kwargs or {}
+        return {
+            "messages": [self._message_to_dict(message) for message in messages],
+            "model": self.model,
+            "stream": False,
+            "template": self.template,
+            "options": generation_kwargs,
+        }
+
+    def _build_message_from_ollama_response(self, ollama_response: Response) -> ChatMessage:
+        """
+        Converts the non-streaming response from the Ollama API to a ChatMessage.
+        :param ollama_response: The completion returned by the Ollama API.
+        :return: The ChatMessage.
+        """
+        json_content = ollama_response.json()
+        message = ChatMessage.from_assistant(content=json_content["message"]["content"])
+        message.meta.update({key: value for key, value in json_content.items() if key != "message"})
+        return message
+
+    @component.output_types(replies=List[ChatMessage])
+    def run(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Run an Ollama Model on a given chat history.
+        :param messages: A list of ChatMessage instances representing the input messages.
+        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+            top_p, etc. See the
+            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+        :return: A dictionary with a "replies" key containing the generated ChatMessage instances and their metadata.
+        """
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        json_payload = self._create_json_payload(messages, generation_kwargs)
+
+        response = requests.post(url=self.url, json=json_payload, timeout=self.timeout)
+
+        # throw error on unsuccessful response
+        response.raise_for_status()
+
+        return {"replies": [self._build_message_from_ollama_response(response)]}
diff --git a/integrations/ollama/src/ollama_haystack/generator.py b/integrations/ollama/src/ollama_haystack/generator.py
index f9731d5d3..55bd65d8a 100644
--- a/integrations/ollama/src/ollama_haystack/generator.py
+++ b/integrations/ollama/src/ollama_haystack/generator.py
@@ -20,7 +20,7 @@ def __init__(
         system_prompt: Optional[str] = None,
         template: Optional[str] = None,
         raw: bool = False,
-        timeout: int = 30,
+        timeout: int = 120,
     ):
         """
         :param model: The name of the model to use. The model should be available in the running Ollama instance.
@@ -35,7 +35,7 @@ def __init__(
         :param raw: If True, no formatting will be applied to the prompt. You may choose to use the raw parameter
             if you are specifying a full templated prompt in your API request.
         :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
-            Default is 30 seconds.
+            Default is 120 seconds.
""" self.timeout = timeout self.raw = raw diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py new file mode 100644 index 000000000..f4e361c7d --- /dev/null +++ b/integrations/ollama/tests/test_chat_generator.py @@ -0,0 +1,128 @@ +from typing import List +from unittest.mock import Mock + +import pytest +from haystack.dataclasses import ChatMessage, ChatRole +from requests import HTTPError, Response + +from ollama_haystack import OllamaChatGenerator + + +@pytest.fixture +def chat_messages() -> List[ChatMessage]: + return [ + ChatMessage.from_user("Tell me about why Super Mario is the greatest superhero"), + ChatMessage.from_assistant( + "Super Mario has prevented Bowser from destroying the world", {"something": "something"} + ), + ] + + +class TestOllamaChatGenerator: + def test_init_default(self): + component = OllamaChatGenerator() + assert component.model == "orca-mini" + assert component.url == "http://localhost:11434/api/chat" + assert component.generation_kwargs == {} + assert component.template is None + assert component.timeout == 120 + + def test_init(self): + component = OllamaChatGenerator( + model="llama2", + url="http://my-custom-endpoint:11434/api/chat", + generation_kwargs={"temperature": 0.5}, + timeout=5, + ) + + assert component.model == "llama2" + assert component.url == "http://my-custom-endpoint:11434/api/chat" + assert component.generation_kwargs == {"temperature": 0.5} + assert component.template is None + assert component.timeout == 5 + + def test_create_json_payload(self, chat_messages): + observed = OllamaChatGenerator(model="some_model")._create_json_payload(chat_messages, {"temperature": 0.1}) + expected = { + "messages": [ + {"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"}, + {"role": "assistant", "content": "Super Mario has prevented Bowser from destroying the world"}, + ], + "model": "some_model", + "stream": False, + "template": None, + "options": {"temperature": 0.1}, + } + + assert observed == expected + + def test_build_message_from_ollama_response(self): + model = "some_model" + + mock_ollama_response = Mock(Response) + mock_ollama_response.json.return_value = { + "model": model, + "created_at": "2023-12-12T14:13:43.416799Z", + "message": {"role": "assistant", "content": "Hello! How are you today?"}, + "done": True, + "total_duration": 5191566416, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000, + } + + observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(mock_ollama_response) + + assert observed.role == "assistant" + assert observed.content == "Hello! How are you today?" 
+
+    @pytest.mark.integration
+    def test_run(self):
+        chat_generator = OllamaChatGenerator()
+
+        user_questions_and_assistant_answers = [
+            ("What's the capital of France?", "Paris"),
+            ("What is the capital of Canada?", "Ottawa"),
+            ("What is the capital of Ghana?", "Accra"),
+        ]
+
+        for question, answer in user_questions_and_assistant_answers:
+            message = ChatMessage.from_user(question)
+
+            response = chat_generator.run([message])
+
+            assert isinstance(response, dict)
+            assert isinstance(response["replies"], list)
+            assert answer in response["replies"][0].content
+
+    @pytest.mark.integration
+    def test_run_with_chat_history(self):
+        chat_generator = OllamaChatGenerator()
+
+        chat_history = [
+            {"role": "user", "content": "What is the largest city in the United Kingdom by population?"},
+            {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"},
+            {"role": "user", "content": "And what is the second largest?"},
+        ]
+
+        chat_messages = [
+            ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None)
+            for message in chat_history
+        ]
+        response = chat_generator.run(chat_messages)
+
+        assert isinstance(response, dict)
+        assert isinstance(response["replies"], list)
+        assert "Manchester" in response["replies"][-1].content
+
+    @pytest.mark.integration
+    def test_run_model_unavailable(self):
+        component = OllamaChatGenerator(model="Alistair_and_Stefano_are_great")
+
+        with pytest.raises(HTTPError):
+            message = ChatMessage.from_user(
+                "Based on your infinite wisdom, can you tell me why Alistair and Stefano are so great?"
+            )
+            component.run([message])
diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py
index c2450a3ec..18c4d2826 100644
--- a/integrations/ollama/tests/test_generator.py
+++ b/integrations/ollama/tests/test_generator.py
@@ -42,7 +42,7 @@ def test_init_default(self):
         assert component.system_prompt is None
         assert component.template is None
        assert component.raw is False
-        assert component.timeout == 30
+        assert component.timeout == 120
 
     def test_init(self):
         component = OllamaGenerator(
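
Illustrative usage sketch (not part of the patch): a minimal multi-turn follow-up built only on what the diff above adds (OllamaChatGenerator.run, ChatMessage, and the "replies" output key). It assumes the same local Ollama instance with the orca-mini model that chat_generator_example.py relies on.

from haystack.dataclasses import ChatMessage

from ollama_haystack import OllamaChatGenerator

client = OllamaChatGenerator(model="orca-mini", timeout=45, url="http://localhost:11434/api/chat")

# First turn: ask a question and keep the assistant's reply in the history.
history = [ChatMessage.from_user("What's Natural Language Processing?")]
first_reply = client.run(history, generation_kwargs={"temperature": 0.2})["replies"][0]

# Second turn: append the reply plus a follow-up question, then call run() again
# with the full history so the model sees the previous exchange.
history += [first_reply, ChatMessage.from_user("Summarise that in one sentence.")]
second_reply = client.run(history)["replies"][0]
print(second_reply.content)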