generator: Ollama (#876)
Add two Ollama generators, using either the chat or generate functions
from the Ollama package.
jmartin-tech committed Sep 30, 2024
2 parents cc88954 + c090a4d commit aadc9aa
Showing 6 changed files with 240 additions and 0 deletions.
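
For orientation, here is a minimal usage sketch of the two new generator classes, mirroring what the added tests exercise. It assumes a local Ollama server is reachable on the default host and that the example model name ("llama2") has already been pulled:

from garak.generators.ollama import OllamaGenerator, OllamaGeneratorChat

# OllamaGeneratorChat (the module's DEFAULT_CLASS) talks to Ollama's chat endpoint
chat_gen = OllamaGeneratorChat("llama2")
print(chat_gen.generate('Say "Hello!"'))  # a list containing one response string

# OllamaGenerator uses the plain generate (completion) endpoint instead
gen = OllamaGenerator("llama2")
print(gen.generate('Say "Hello!"'))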
8 changes: 8 additions & 0 deletions docs/source/garak.generators.ollama.rst
@@ -0,0 +1,8 @@
garak.generators.ollama
========================

.. automodule:: garak.generators.ollama
   :members:
   :undoc-members:
   :show-inheritance:

1 change: 1 addition & 0 deletions docs/source/generators.rst
@@ -20,6 +20,7 @@ For a detailed oversight into how a generator operates, see :ref:`garak.generato
garak.generators.langchain_serve
garak.generators.litellm
garak.generators.octo
garak.generators.ollama
garak.generators.openai
garak.generators.nemo
garak.generators.nim
85 changes: 85 additions & 0 deletions garak/generators/ollama.py
@@ -0,0 +1,85 @@
"""Ollama interface"""

from typing import List, Union

import backoff
import ollama

from garak import _config
from garak.generators.base import Generator
from httpx import TimeoutException


def _give_up(error):
    return isinstance(error, ollama.ResponseError) and error.status_code == 404


class OllamaGenerator(Generator):
    """Interface for Ollama endpoints
    Model names can be passed in short form like "llama2" or specific versions or sizes like "gemma:7b" or "llama2:latest"
    """

    DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
        "timeout": 30,  # 30-second timeout; without it, Ollama can hang indefinitely on failures
        "host": "127.0.0.1:11434",  # Default host of a local Ollama server; can be overridden via a passed config or a generator config file
    }

    active = True
    generator_family_name = "Ollama"
    parallel_capable = False

    def __init__(self, name="", config_root=_config):
        super().__init__(name, config_root)  # Sets the name and generations

        self.client = ollama.Client(
            self.host, timeout=self.timeout
        )  # Instantiates the client with the timeout

    @backoff.on_exception(
        backoff.fibo,
        (TimeoutException, ollama.ResponseError),
        max_value=70,
        giveup=_give_up,
    )
    @backoff.on_predicate(
        backoff.fibo, lambda ans: ans == [None] or len(ans) == 0, max_tries=3
    )  # Ollama sometimes returns empty responses; cap at 3 retries so runs that legitimately expect empty responses are not delayed too much
    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        response = self.client.generate(self.name, prompt)
        return [response.get("response", None)]


class OllamaGeneratorChat(OllamaGenerator):
    """Interface for Ollama endpoints, using the chat functionality
    Model names can be passed in short form like "llama2" or specific versions or sizes like "gemma:7b" or "llama2:latest"
    """

    @backoff.on_exception(
        backoff.fibo,
        (TimeoutException, ollama.ResponseError),
        max_value=70,
        giveup=_give_up,
    )
    @backoff.on_predicate(
        backoff.fibo, lambda ans: ans == [None] or len(ans) == 0, max_tries=3
    )  # Ollama sometimes returns empty responses; cap at 3 retries so runs that legitimately expect empty responses are not delayed too much
    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        response = self.client.chat(
            model=self.name,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                },
            ],
        )
        return [response.get("message", {}).get("content", None)]  # Return the assistant message content, or None


DEFAULT_CLASS = "OllamaGeneratorChat"
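
A rough sketch of pointing the generator at a non-default server and of the give-up behaviour defined above. The host value is hypothetical, and replacing the client after construction is just one way to sidestep garak's config plumbing for a quick test; normally host and timeout come from DEFAULT_PARAMS or a generator config:

import ollama

from garak.generators.ollama import OllamaGenerator

gen = OllamaGenerator("llama2")
# Hypothetical remote host; by default the client is built from DEFAULT_PARAMS
gen.client = ollama.Client("10.0.0.5:11434", timeout=60)

try:
    print(gen.generate("Why is the sky blue?"))
except ollama.ResponseError:
    # A 404 (model not pulled on that server) is not retried: _give_up() stops the backoff
    print("Model not available on this Ollama server")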
1 change: 1 addition & 0 deletions pyproject.toml
@@ -77,6 +77,7 @@ dependencies = [
"lorem==0.1.1",
"xdg-base-dirs>=6.0.1",
"wn==0.9.5",
"ollama>=0.1.7"
]

[project.optional-dependencies]
1 change: 1 addition & 0 deletions requirements.txt
@@ -34,6 +34,7 @@ python-magic>=0.4.21; sys_platform != "win32"
lorem==0.1.1
xdg-base-dirs>=6.0.1
wn==0.9.5
ollama>=0.1.7
# tests
pytest>=8.0
requests-mock==1.12.1
144 changes: 144 additions & 0 deletions tests/generators/test_ollama.py
@@ -0,0 +1,144 @@
import pytest
import ollama
import respx
import httpx
from httpx import ConnectError
from garak.generators.ollama import OllamaGeneratorChat, OllamaGenerator

PINGED_OLLAMA_SERVER = False # Avoid calling the server multiple times if it is not running
OLLAMA_SERVER_UP = False


def ollama_is_running():
    global PINGED_OLLAMA_SERVER
    global OLLAMA_SERVER_UP

    if not PINGED_OLLAMA_SERVER:
        try:
            ollama.list()  # Gets a list of all pulled models. Used as a ping
            OLLAMA_SERVER_UP = True
        except ConnectError:
            OLLAMA_SERVER_UP = False
        finally:
            PINGED_OLLAMA_SERVER = True
    return OLLAMA_SERVER_UP


def no_models():
    return len(ollama.list()) == 0 or len(ollama.list()["models"]) == 0


@pytest.mark.skipif(
    not ollama_is_running(),
    reason=f"Ollama server is not currently running",
)
def test_error_on_nonexistant_model_chat():
    model_name = "non-existant-model"
    gen = OllamaGeneratorChat(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldnt work")


@pytest.mark.skipif(
    not ollama_is_running(),
    reason=f"Ollama server is not currently running",
)
def test_error_on_nonexistant_model():
    model_name = "non-existant-model"
    gen = OllamaGenerator(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldnt work")


@pytest.mark.skipif(
    not ollama_is_running(),
    reason=f"Ollama server is not currently running",
)
@pytest.mark.skipif(
    not ollama_is_running() or no_models(),  # Avoid checking models if no server
    reason=f"No Ollama models pulled",
)
# This test might fail if the GPU is busy and the generation takes more than 30 seconds
def test_generation_on_pulled_model_chat():
    model_name = ollama.list()["models"][0]["name"]
    gen = OllamaGeneratorChat(model_name)
    responses = gen.generate('Say "Hello!"')
    assert len(responses) == 1
    assert all(isinstance(response, str) for response in responses)
    assert all(len(response) > 0 for response in responses)


@pytest.mark.skipif(
    not ollama_is_running(),
    reason=f"Ollama server is not currently running",
)
@pytest.mark.skipif(
    not ollama_is_running() or no_models(),  # Avoid checking models if no server
    reason=f"No Ollama models pulled",
)
# This test might fail if the GPU is busy and the generation takes more than 30 seconds
def test_generation_on_pulled_model():
    model_name = ollama.list()["models"][0]["name"]
    gen = OllamaGenerator(model_name)
    responses = gen.generate('Say "Hello!"')
    assert len(responses) == 1
    assert all(isinstance(response, str) for response in responses)
    assert all(len(response) > 0 for response in responses)

@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_ollama_generation_mocked(respx_mock):
    mock_response = {
        'model': 'mistral',
        'response': 'Hello how are you?'
    }
    respx_mock.post('/api/generate').mock(
        return_value=httpx.Response(200, json=mock_response)
    )
    gen = OllamaGenerator("mistral")
    generation = gen.generate("Bla bla")
    assert generation == ['Hello how are you?']


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_ollama_generation_chat_mocked(respx_mock):
    mock_response = {
        'model': 'mistral',
        'message': {
            'role': 'assistant',
            'content': 'Hello how are you?'
        }
    }
    respx_mock.post('/api/chat').mock(
        return_value=httpx.Response(200, json=mock_response)
    )
    gen = OllamaGeneratorChat("mistral")
    generation = gen.generate("Bla bla")
    assert generation == ['Hello how are you?']


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_error_on_nonexistant_model_mocked(respx_mock):
    mock_response = {
        'error': "No such model"
    }
    respx_mock.post('/api/generate').mock(
        return_value=httpx.Response(404, json=mock_response)
    )
    model_name = "non-existant-model"
    gen = OllamaGenerator(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldnt work")


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_error_on_nonexistant_model_chat_mocked(respx_mock):
    mock_response = {
        'error': "No such model"
    }
    respx_mock.post('/api/chat').mock(
        return_value=httpx.Response(404, json=mock_response)
    )
    model_name = "non-existant-model"
    gen = OllamaGeneratorChat(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldnt work")
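
The four mocked tests above run without a live server, since respx intercepts the httpx calls the ollama client makes; the live-server tests skip automatically when no Ollama server answers. A quick way to run only the mocked subset locally might be this sketch, using pytest's standard keyword filter:

import pytest

# Select only tests with "mocked" in the name; no Ollama server required
pytest.main(["-k", "mocked", "tests/generators/test_ollama.py"])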
