Ollama Chat Generator (#176)
* update inits to expose ollamachatgenerator

* add ollama chat generator

* add tests for ollama chat generator

* add tests for init method

* Change order of chat history to chronological

* add test for chat history

* add return type to _build_message

* refactor message_to_dict to one liner

* add return types to fixtures

* add test for unavailable model

* drop streaming references for now

* drop streaming callback from tests

* Update integrations/ollama/src/ollama_haystack/chat/chat_generator.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* Update integrations/ollama/src/ollama_haystack/chat/chat_generator.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* Update integrations/ollama/src/ollama_haystack/chat/chat_generator.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* drop _chat_history_to_dict

* drop intermediate ollama to haystack response methods

* change metadata to meta

* lint with black

* refactor chat message fixture into one list

* add chat generator example

* rename example -> generator example

* add new chat generator example

* Update integrations/ollama/src/ollama_haystack/chat/chat_generator.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* update test for new timeout

* Update test_chat_generator.py

* increase generator timeout

* add docstrings

* fix

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
AlistairLR112 and anakin87 authored Jan 8, 2024
1 parent 9411c99 commit 7435282
Showing 8 changed files with 278 additions and 4 deletions.
49 changes: 49 additions & 0 deletions integrations/ollama/example/chat_generator_example.py
@@ -0,0 +1,49 @@
# In order to run this example, you will need to have an instance of Ollama running with the
# orca-mini model downloaded. We suggest you use the following commands to serve an orca-mini
# model from Ollama
#
# docker run -d -p 11434:11434 --name ollama ollama/ollama:latest
# docker exec ollama ollama pull orca-mini

from haystack.dataclasses import ChatMessage

from ollama_haystack import OllamaChatGenerator

messages = [
    ChatMessage.from_user("What's Natural Language Processing?"),
    ChatMessage.from_system(
        "Natural Language Processing (NLP) is a field of computer science and artificial "
        "intelligence concerned with the interaction between computers and human language"
    ),
    ChatMessage.from_user("How do I get started?"),
]
client = OllamaChatGenerator(model="orca-mini", timeout=45, url="http://localhost:11434/api/chat")

response = client.run(messages, generation_kwargs={"temperature": 0.2})

print(response["replies"])
#
# [
#     ChatMessage(
#         content="Natural Language Processing (NLP) is a broad field of computer science and artificial intelligence "
#         "that involves the interaction between computers and human language. To get started in NLP, "
#         "you can start by learning about the different techniques and tools used in NLP such as machine "
#         "learning algorithms, deep learning frameworks, and natural language processing libraries. You can "
#         "also learn about the applications of NLP in various fields such as chatbots, sentiment analysis, "
#         "speech recognition, and text classification. Additionally, you can explore the available resources "
#         "such as online courses, tutorials, and books on NLP to gain a deeper understanding of the field.",
#         role=<ChatRole.ASSISTANT: 'assistant'>,
#         name=None,
#         meta={
#             "model": "orca-mini",
#             "created_at": "2024-01-08T15:35:23.378609793Z",
#             "done": True,
#             "total_duration": 20026330217,
#             "load_duration": 1540167,
#             "prompt_eval_count": 99,
#             "prompt_eval_duration": 8486609000,
#             "eval_count": 124,
#             "eval_duration": 11532988000,
#         },
#     )
# ]
File renamed without changes.
3 changes: 2 additions & 1 deletion integrations/ollama/src/ollama_haystack/__init__.py
@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from ollama_haystack.chat.chat_generator import OllamaChatGenerator
from ollama_haystack.generator import OllamaGenerator

__all__ = ["OllamaGenerator"]
__all__ = ["OllamaGenerator", "OllamaChatGenerator"]
Empty file.
96 changes: 96 additions & 0 deletions integrations/ollama/src/ollama_haystack/chat/chat_generator.py
@@ -0,0 +1,96 @@
from typing import Any, Dict, List, Optional

import requests
from haystack import component
from haystack.dataclasses import ChatMessage
from requests import Response


@component
class OllamaChatGenerator:
    """
    Chat Generator based on Ollama. Ollama is a library for easily running LLMs locally.
    This component provides an interface to generate text using an LLM running in Ollama.
    """

    def __init__(
        self,
        model: str = "orca-mini",
        url: str = "http://localhost:11434/api/chat",
        generation_kwargs: Optional[Dict[str, Any]] = None,
        template: Optional[str] = None,
        timeout: int = 120,
    ):
        """
        :param model: The name of the model to use. The model should be available in the running Ollama instance.
            Default is "orca-mini".
        :param url: The URL of the chat endpoint of a running Ollama instance.
            Default is "http://localhost:11434/api/chat".
        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, and others. See the available arguments in
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :param template: The full prompt template (overrides what is defined in the Ollama Modelfile).
        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
            Default is 120 seconds.
        """

        self.timeout = timeout
        self.template = template
        self.generation_kwargs = generation_kwargs or {}
        self.url = url
        self.model = model

    def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
        return {"role": message.role.value, "content": message.content}

    def _create_json_payload(
        self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Returns a dictionary of JSON arguments for a POST request to an Ollama service.
        :param messages: A history/list of chat messages
        :param generation_kwargs: Optional model parameters (such as temperature) sent to Ollama as the "options" field
        :return: A dictionary of arguments for a POST request to an Ollama service
        """
        generation_kwargs = generation_kwargs or {}
        return {
            "messages": [self._message_to_dict(message) for message in messages],
            "model": self.model,
            "stream": False,
            "template": self.template,
            "options": generation_kwargs,
        }

    def _build_message_from_ollama_response(self, ollama_response: Response) -> ChatMessage:
        """
        Converts the non-streaming response from the Ollama API to a ChatMessage.
        :param ollama_response: The completion returned by the Ollama API.
        :return: The ChatMessage.
        """
        json_content = ollama_response.json()
        message = ChatMessage.from_assistant(content=json_content["message"]["content"])
        message.meta.update({key: value for key, value in json_content.items() if key != "message"})
        return message

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Run an Ollama model on a given chat history.
        :param messages: A list of ChatMessage instances representing the input messages.
        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, etc. See the
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
        :return: A dictionary with a "replies" key containing the generated ChatMessage instances (response
            metadata is stored in their meta field)
        """
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        json_payload = self._create_json_payload(messages, generation_kwargs)

        response = requests.post(url=self.url, json=json_payload, timeout=self.timeout)

        # throw error on unsuccessful response
        response.raise_for_status()

        return {"replies": [self._build_message_from_ollama_response(response)]}
4 changes: 2 additions & 2 deletions integrations/ollama/src/ollama_haystack/generator.py
@@ -20,7 +20,7 @@ def __init__(
        system_prompt: Optional[str] = None,
        template: Optional[str] = None,
        raw: bool = False,
-        timeout: int = 30,
+        timeout: int = 120,
    ):
        """
        :param model: The name of the model to use. The model should be available in the running Ollama instance.
@@ -35,7 +35,7 @@
        :param raw: If True, no formatting will be applied to the prompt. You may choose to use the raw parameter
            if you are specifying a full templated prompt in your API request.
        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
-            Default is 30 seconds.
+            Default is 120 seconds.
        """
        self.timeout = timeout
        self.raw = raw
128 changes: 128 additions & 0 deletions integrations/ollama/tests/test_chat_generator.py
@@ -0,0 +1,128 @@
from typing import List
from unittest.mock import Mock

import pytest
from haystack.dataclasses import ChatMessage, ChatRole
from requests import HTTPError, Response

from ollama_haystack import OllamaChatGenerator


@pytest.fixture
def chat_messages() -> List[ChatMessage]:
    return [
        ChatMessage.from_user("Tell me about why Super Mario is the greatest superhero"),
        ChatMessage.from_assistant(
            "Super Mario has prevented Bowser from destroying the world", {"something": "something"}
        ),
    ]


class TestOllamaChatGenerator:
    def test_init_default(self):
        component = OllamaChatGenerator()
        assert component.model == "orca-mini"
        assert component.url == "http://localhost:11434/api/chat"
        assert component.generation_kwargs == {}
        assert component.template is None
        assert component.timeout == 120

    def test_init(self):
        component = OllamaChatGenerator(
            model="llama2",
            url="http://my-custom-endpoint:11434/api/chat",
            generation_kwargs={"temperature": 0.5},
            timeout=5,
        )

        assert component.model == "llama2"
        assert component.url == "http://my-custom-endpoint:11434/api/chat"
        assert component.generation_kwargs == {"temperature": 0.5}
        assert component.template is None
        assert component.timeout == 5

    def test_create_json_payload(self, chat_messages):
        observed = OllamaChatGenerator(model="some_model")._create_json_payload(chat_messages, {"temperature": 0.1})
        expected = {
            "messages": [
                {"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"},
                {"role": "assistant", "content": "Super Mario has prevented Bowser from destroying the world"},
            ],
            "model": "some_model",
            "stream": False,
            "template": None,
            "options": {"temperature": 0.1},
        }

        assert observed == expected

    def test_build_message_from_ollama_response(self):
        model = "some_model"

        mock_ollama_response = Mock(Response)
        mock_ollama_response.json.return_value = {
            "model": model,
            "created_at": "2023-12-12T14:13:43.416799Z",
            "message": {"role": "assistant", "content": "Hello! How are you today?"},
            "done": True,
            "total_duration": 5191566416,
            "load_duration": 2154458,
            "prompt_eval_count": 26,
            "prompt_eval_duration": 383809000,
            "eval_count": 298,
            "eval_duration": 4799921000,
        }

        observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(mock_ollama_response)

        assert observed.role == "assistant"
        assert observed.content == "Hello! How are you today?"

    @pytest.mark.integration
    def test_run(self):
        chat_generator = OllamaChatGenerator()

        user_questions_and_assistant_answers = [
            ("What's the capital of France?", "Paris"),
            ("What is the capital of Canada?", "Ottawa"),
            ("What is the capital of Ghana?", "Accra"),
        ]

        for question, answer in user_questions_and_assistant_answers:
            message = ChatMessage.from_user(question)

            response = chat_generator.run([message])

            assert isinstance(response, dict)
            assert isinstance(response["replies"], list)
            assert answer in response["replies"][0].content

    @pytest.mark.integration
    def test_run_with_chat_history(self):
        chat_generator = OllamaChatGenerator()

        chat_history = [
            {"role": "user", "content": "What is the largest city in the United Kingdom by population?"},
            {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"},
            {"role": "user", "content": "And what is the second largest?"},
        ]

        chat_messages = [
            ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None)
            for message in chat_history
        ]
        response = chat_generator.run(chat_messages)

        assert isinstance(response, dict)
        assert isinstance(response["replies"], list)
        assert "Manchester" in response["replies"][-1].content

    @pytest.mark.integration
    def test_run_model_unavailable(self):
        component = OllamaChatGenerator(model="Alistair_and_Stefano_are_great")

        with pytest.raises(HTTPError):
            message = ChatMessage.from_user(
                "Based on your infinite wisdom, can you tell me why Alistair and Stefano are so great?"
            )
            component.run([message])
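
The integration tests above are gated behind the pytest "integration" marker and expect a live Ollama server with the default orca-mini model pulled (see the docker commands in the example script). A minimal sketch for selecting or skipping them, assuming the marker is registered in the project's pytest configuration and that the commands run from the repository root:

# Sketch only; the test path is an assumption based on the file layout in this commit.
import pytest

# Only the integration tests (needs a running Ollama instance with orca-mini pulled):
pytest.main(["-m", "integration", "integrations/ollama/tests/test_chat_generator.py"])

# Only the fast unit tests, skipping anything marked as integration:
pytest.main(["-m", "not integration", "integrations/ollama/tests/test_chat_generator.py"])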
2 changes: 1 addition & 1 deletion integrations/ollama/tests/test_generator.py
@@ -42,7 +42,7 @@ def test_init_default(self):
        assert component.system_prompt is None
        assert component.template is None
        assert component.raw is False
-        assert component.timeout == 30
+        assert component.timeout == 120

    def test_init(self):
        component = OllamaGenerator(
