Add two Ollama generators, using either the chat or generate functions from the Ollama package.
Showing 6 changed files with 240 additions and 0 deletions.
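Once merged, the new generators should be selectable like any other garak model type. A minimal sketch of a command-line run, assuming garak's usual --model_type/--model_name/--probes options and a locally pulled "llama2" model (none of which are part of this diff); the dotted module.ClassName form in the second command follows garak's usual plugin-selection convention:

# "ollama" resolves to the module's DEFAULT_CLASS, i.e. OllamaGeneratorChat
python -m garak --model_type ollama --model_name llama2 --probes encoding

# the completion-style generator can be requested explicitly
python -m garak --model_type ollama.OllamaGenerator --model_name llama2 --probes encoding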
@@ -0,0 +1,8 @@
garak.generators.ollama
========================

.. automodule:: garak.generators.ollama
   :members:
   :undoc-members:
   :show-inheritance:
@@ -0,0 +1,85 @@
"""Ollama interface""" | ||
|
||
from typing import List, Union | ||
|
||
import backoff | ||
import ollama | ||
|
||
from garak import _config | ||
from garak.generators.base import Generator | ||
from httpx import TimeoutException | ||
|
||
|
||
def _give_up(error): | ||
return isinstance(error, ollama.ResponseError) and error.status_code == 404 | ||
|
||
|
||
class OllamaGenerator(Generator): | ||
"""Interface for Ollama endpoints | ||
Model names can be passed in short form like "llama2" or specific versions or sizes like "gemma:7b" or "llama2:latest" | ||
""" | ||
|
||
DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | { | ||
"timeout": 30, # Add a timeout of 30 seconds. Ollama can tend to hang forever on failures, if this is not present | ||
"host": "127.0.0.1:11434", # The default host of an Ollama server. This can be overwritten with a passed config or generator config file. | ||
} | ||
|
||
active = True | ||
generator_family_name = "Ollama" | ||
parallel_capable = False | ||
|
||
def __init__(self, name="", config_root=_config): | ||
super().__init__(name, config_root) # Sets the name and generations | ||
|
||
self.client = ollama.Client( | ||
self.host, timeout=self.timeout | ||
) # Instantiates the client with the timeout | ||
|
||
@backoff.on_exception( | ||
backoff.fibo, | ||
(TimeoutException, ollama.ResponseError), | ||
max_value=70, | ||
giveup=_give_up, | ||
) | ||
@backoff.on_predicate( | ||
backoff.fibo, lambda ans: ans == [None] or len(ans) == 0, max_tries=3 | ||
) # Ollama sometimes returns empty responses. Only 3 retries to not delay generations expecting empty responses too much | ||
def _call_model( | ||
self, prompt: str, generations_this_call: int = 1 | ||
) -> List[Union[str, None]]: | ||
response = self.client.generate(self.name, prompt) | ||
return [response.get("response", None)] | ||
|
||
|
||
class OllamaGeneratorChat(OllamaGenerator): | ||
"""Interface for Ollama endpoints, using the chat functionality | ||
Model names can be passed in short form like "llama2" or specific versions or sizes like "gemma:7b" or "llama2:latest" | ||
""" | ||
|
||
@backoff.on_exception( | ||
backoff.fibo, | ||
(TimeoutException, ollama.ResponseError), | ||
max_value=70, | ||
giveup=_give_up, | ||
) | ||
@backoff.on_predicate( | ||
backoff.fibo, lambda ans: ans == [None] or len(ans) == 0, max_tries=3 | ||
) # Ollama sometimes returns empty responses. Only 3 retries to not delay generations expecting empty responses too much | ||
def _call_model( | ||
self, prompt: str, generations_this_call: int = 1 | ||
) -> List[Union[str, None]]: | ||
response = self.client.chat( | ||
model=self.name, | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": prompt, | ||
}, | ||
], | ||
) | ||
return [response.get("message", {}).get("content", None)] # Return the response or None | ||
|
||
|
||
DEFAULT_CLASS = "OllamaGeneratorChat" |
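Outside the garak CLI, the new classes can also be driven directly from Python. A minimal sketch, assuming a local Ollama server on the default host with a pulled "llama2" model; generate() is the base Generator entry point exercised by the tests below:

from garak.generators.ollama import OllamaGenerator, OllamaGeneratorChat

# Chat-style generation via Ollama's chat endpoint (the module's DEFAULT_CLASS)
chat_gen = OllamaGeneratorChat("llama2")  # assumes "llama2" has already been pulled
print(chat_gen.generate('Say "Hello!"'))  # list containing a single response string

# Completion-style generation via Ollama's generate endpoint
gen = OllamaGenerator("llama2")
print(gen.generate('Say "Hello!"'))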
@@ -0,0 +1,144 @@
import pytest
import ollama
import respx
import httpx
from httpx import ConnectError
from garak.generators.ollama import OllamaGeneratorChat, OllamaGenerator

PINGED_OLLAMA_SERVER = False  # Avoid calling the server multiple times if it is not running
OLLAMA_SERVER_UP = False


def ollama_is_running():
    global PINGED_OLLAMA_SERVER
    global OLLAMA_SERVER_UP

    if not PINGED_OLLAMA_SERVER:
        try:
            ollama.list()  # Gets a list of all pulled models. Used as a ping
            OLLAMA_SERVER_UP = True
        except ConnectError:
            OLLAMA_SERVER_UP = False
        finally:
            PINGED_OLLAMA_SERVER = True
    return OLLAMA_SERVER_UP


def no_models():
    return len(ollama.list()) == 0 or len(ollama.list()["models"]) == 0


@pytest.mark.skipif(
    not ollama_is_running(),
    reason="Ollama server is not currently running",
)
def test_error_on_nonexistant_model_chat():
    model_name = "non-existant-model"
    gen = OllamaGeneratorChat(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldnt work")


@pytest.mark.skipif(
    not ollama_is_running(),
    reason="Ollama server is not currently running",
)
def test_error_on_nonexistant_model():
    model_name = "non-existant-model"
    gen = OllamaGenerator(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldnt work")


@pytest.mark.skipif(
    not ollama_is_running(),
    reason="Ollama server is not currently running",
)
@pytest.mark.skipif(
    not ollama_is_running() or no_models(),  # Avoid checking models if no server
    reason="No Ollama models pulled",
)
# This test might fail if the GPU is busy and the generation takes more than 30 seconds
def test_generation_on_pulled_model_chat():
    model_name = ollama.list()["models"][0]["name"]
    gen = OllamaGeneratorChat(model_name)
    responses = gen.generate('Say "Hello!"')
    assert len(responses) == 1
    assert all(isinstance(response, str) for response in responses)
    assert all(len(response) > 0 for response in responses)


@pytest.mark.skipif(
    not ollama_is_running(),
    reason="Ollama server is not currently running",
)
@pytest.mark.skipif(
    not ollama_is_running() or no_models(),  # Avoid checking models if no server
    reason="No Ollama models pulled",
)
# This test might fail if the GPU is busy and the generation takes more than 30 seconds
def test_generation_on_pulled_model():
    model_name = ollama.list()["models"][0]["name"]
    gen = OllamaGenerator(model_name)
    responses = gen.generate('Say "Hello!"')
    assert len(responses) == 1
    assert all(isinstance(response, str) for response in responses)
    assert all(len(response) > 0 for response in responses)


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_ollama_generation_mocked(respx_mock):
    mock_response = {"model": "mistral", "response": "Hello how are you?"}
    respx_mock.post("/api/generate").mock(
        return_value=httpx.Response(200, json=mock_response)
    )
    gen = OllamaGenerator("mistral")
    generation = gen.generate("Bla bla")
    assert generation == ["Hello how are you?"]


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_ollama_generation_chat_mocked(respx_mock):
    mock_response = {
        "model": "mistral",
        "message": {
            "role": "assistant",
            "content": "Hello how are you?",
        },
    }
    respx_mock.post("/api/chat").mock(
        return_value=httpx.Response(200, json=mock_response)
    )
    gen = OllamaGeneratorChat("mistral")
    generation = gen.generate("Bla bla")
    assert generation == ["Hello how are you?"]


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_error_on_nonexistant_model_mocked(respx_mock):
    mock_response = {"error": "No such model"}
    respx_mock.post("/api/generate").mock(
        return_value=httpx.Response(404, json=mock_response)
    )
    model_name = "non-existant-model"
    gen = OllamaGenerator(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldnt work")


@pytest.mark.respx(base_url="http://" + OllamaGenerator.DEFAULT_PARAMS["host"])
def test_error_on_nonexistant_model_chat_mocked(respx_mock):
    mock_response = {"error": "No such model"}
    respx_mock.post("/api/chat").mock(
        return_value=httpx.Response(404, json=mock_response)
    )
    model_name = "non-existant-model"
    gen = OllamaGeneratorChat(model_name)
    with pytest.raises(ollama.ResponseError):
        gen.generate("This shouldnt work")
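These tests can be run on their own. A sketch of the invocation, assuming the file lands at tests/generators/test_ollama.py (the path is not shown in this diff) and that pytest and respx are installed; the live-server tests skip themselves when no Ollama instance is reachable:

python -m pytest tests/generators/test_ollama.py -v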