Skip to content

Commit

Permalink
add tests for ggml generator
Browse files Browse the repository at this point in the history
Signed-off-by: Jeffrey Martin <[email protected]>
  • Loading branch information
jmartin-tech committed Apr 22, 2024
1 parent 4e5cec4 commit f5fd86c
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 6 deletions.
5 changes: 3 additions & 2 deletions garak/generators/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from garak.generators.base import Generator

GGUF_MAGIC = bytes([0x47, 0x47, 0x55, 0x46])
ENV_VAR = "GGML_MAIN_PATH"


class GgmlGenerator(Generator):
Expand Down Expand Up @@ -55,9 +56,9 @@ def command_params(self):
}

def __init__(self, name, generations=10):
self.path_to_ggml_main = os.getenv("GGML_MAIN_PATH")
self.path_to_ggml_main = os.getenv(ENV_VAR)
if self.path_to_ggml_main is None:
raise RuntimeError("Executable not provided by environment GGML_MAIN_PATH")
raise RuntimeError(f"Executable not provided by environment {ENV_VAR}")
if not os.path.isfile(self.path_to_ggml_main):
raise FileNotFoundError(
f"Path provided is not a file: {self.path_to_ggml_main}"
Expand Down
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ dependencies = [
"cohere>=4.5.1,<5",
"openai==1.12.0",
"replicate>=0.8.3",
"pytest>=8.0",
"google-api-python-client>=2.0",
"backoff>=2.1.1",
"rapidfuzz>=3.0.0",
Expand All @@ -63,6 +62,14 @@ dependencies = [
"typing>=3.7,<3.8; python_version<'3.5'"
]

[project.optional-dependencies]
tests = [
"pytest>=8.0",
]
lint = [
"black>=22.3",
]

[project.urls]
"Homepage" = "https://github.com/leondz/garak"
"Bug Tracker" = "https://github.com/leondz/garak/issues"
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ tqdm>=4.64.0
cohere>=4.5.1,<5
openai==1.12.0
replicate>=0.8.3
pytest>=8.0
google-api-python-client>=2.0
backoff>=2.1.1
rapidfuzz>=3.0.0
Expand All @@ -29,3 +28,7 @@ deepl==1.17.0
fschat>=0.2.36
litellm>=1.33.8
typing>=3.7,<3.8; python_version<'3.5'
# tests
pytest>=8.0
# lint
black>=22.3
74 changes: 74 additions & 0 deletions tests/generators/test_ggml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/usr/bin/env python3

import os
import pytest
import tempfile
import garak.generators.ggml

DEFAULT_GENERATIONS_QTY = 10

stored_env = os.getenv(garak.generators.ggml.ENV_VAR)
model_name = None


@pytest.fixture(autouse=True)
def reset_env(request) -> None:
if stored_env is not None:
os.environ[garak.generators.ggml.ENV_VAR] = stored_env
model_name = tempfile.NamedTemporaryFile(suffix="_test_model.gguf")


@pytest.fixture(autouse=True)
def set_fake_env() -> None:
os.environ[garak.generators.ggml.ENV_VAR] = os.path.abspath(__file__)


def test_init_bad_app():
with pytest.raises(RuntimeError) as exc_info:
if os.getenv(garak.generators.ggml.ENV_VAR) is not None:
del os.environ[garak.generators.ggml.ENV_VAR]
garak.generators.ggml.GgmlGenerator(model_name)
assert "not provided by environment" in str(exc_info.value)


def test_init_missing_model():
model_name = tempfile.TemporaryFile().name
with pytest.raises(FileNotFoundError) as exc_info:
garak.generators.ggml.GgmlGenerator(model_name)
assert "File not found" in str(exc_info.value)


def test_init_bad_model():
with tempfile.NamedTemporaryFile(mode="w", suffix="_test_model.gguf") as file:
file.write(file.name)
file.seek(0)
with pytest.raises(RuntimeError) as exc_info:
garak.generators.ggml.GgmlGenerator(file.name)
assert "not in GGUF" in str(exc_info.value)


def test_init_good_model():
with tempfile.NamedTemporaryFile(suffix="_test_model.gguf") as file:
file.write(garak.generators.ggml.GGUF_MAGIC)
file.seek(0)
g = garak.generators.ggml.GgmlGenerator(file.name)
assert g is not None


# def test_subprocess_first_err():
# with open(model_name, "wb") as file:
# file.write(garak.generators.ggml.GGUF_MAGIC)
# g = garak.generators.ggml.GgmlGenerator(model_name)
# with pytest.raises(subprocess.CalledProcessError) as exc_info:
# result = g.generate("mock prompt")
# assert exc_info is not None


# def test_subprocess_continue_err():
# with open(model_name, "wb") as file:
# file.write(garak.generators.ggml.GGUF_MAGIC)
# g = garak.generators.ggml.GgmlGenerator(model_name)
# g.generate("mock prompt")
# with pytest.raises(subprocess.CalledProcessError) as exc_info:
# result = g.generate("mock prompt")
# assert exc_info is not None
10 changes: 8 additions & 2 deletions tests/test_reqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@
def test_requirements_txt_pyproject_toml():
with open("requirements.txt", "r", encoding="utf-8") as req_file:
reqtxt_reqs = req_file.readlines()
reqtxt_reqs = list(map(str.strip, reqtxt_reqs))
reqtxt_reqs = list(
filter(lambda x: not x.startswith("#"), map(str.strip, reqtxt_reqs))
)
reqtxt_reqs.sort()
with open("pyproject.toml", "rb") as pyproject_file:
pyproject_reqs = tomllib.load(pyproject_file)["project"]["dependencies"]
pyproject_toml = tomllib.load(pyproject_file)
pyproject_reqs = pyproject_toml["project"]["dependencies"]
for test_deps in pyproject_toml["project"]["optional-dependencies"].values():
for dep in test_deps:
pyproject_reqs.append(dep)
pyproject_reqs.sort()
# assert len(reqtxt_reqs) == len(pyproject_reqs) # same number of requirements
assert (
Expand Down

0 comments on commit f5fd86c

Please sign in to comment.