From 22d508244c0699cea055ee545b3f9670fbe160f8 Mon Sep 17 00:00:00 2001 From: Umberto Griffo <1609440+umbertogriffo@users.noreply.github.com> Date: Sat, 4 May 2024 14:23:25 +0100 Subject: [PATCH] feat: added tests for the lamacpp client --- poetry.lock | 20 +++++++++++++++- pyproject.toml | 1 + tests/test_lamacpp_client.py | 44 ++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index c1e580d..6d3920c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3226,6 +3226,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.6" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.6.tar.gz", hash = "sha256:ffe523a89c1c222598c76856e76852b787504ddb72dd5d9b6617ffa8aa2cde5f"}, + {file = "pytest_asyncio-0.23.6-py3-none-any.whl", hash = "sha256:68516fdd1018ac57b846c9846b954f0393b26f094764a28c955eabb0536a4e8a"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "4.0.0" @@ -5272,4 +5290,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.11" -content-hash = "32c93716f5dda87fd5225bde0e14a6323641f546d73c6c8f76316f0b97ea2a1f" +content-hash = "5f3dfae2c369b232d7ec9ab1d52b01983af2854d1c97db8862cef966a7dc027d" diff --git a/pyproject.toml b/pyproject.toml index e1ab2a1..edacbf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ nest_asyncio = "~=1.5.8" pytest = "~=7.2.1" pytest-cov = "~=4.0.0" pytest-mock = "~=3.10.0" +pytest-asyncio = "~=0.23.6" pre-commit = "~=3.6.0" ruff = "~=0.1.9" httpx = "~=0.23.3" diff --git a/tests/test_lamacpp_client.py b/tests/test_lamacpp_client.py index 70947ff..c8b73ee 100644 --- a/tests/test_lamacpp_client.py +++ b/tests/test_lamacpp_client.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import patch import pytest @@ -41,3 +42,46 @@ def test_generate_answer(lamacpp_client): prompt = "What is the capital city of Italy?" generated_answer = lamacpp_client.generate_answer(prompt, max_new_tokens=10) assert "rome" in generated_answer.lower() + + +def test_generate_stream_answer(lamacpp_client): + prompt = "What is the capital city of Italy?" + generated_answer = lamacpp_client.stream_answer(prompt, max_new_tokens=10) + assert "rome" in generated_answer.lower() + + +def test_start_answer_iterator_streamer(lamacpp_client): + prompt = "What is the capital city of Italy?" + stream = lamacpp_client.start_answer_iterator_streamer(prompt, max_new_tokens=10) + generated_answer = "" + for output in stream: + generated_answer += output["choices"][0]["text"] + assert "rome" in generated_answer.lower() + + +@pytest.mark.asyncio +async def test_async_generate_answer(lamacpp_client): + prompt = "What is the capital city of Italy?" + task = lamacpp_client.async_generate_answer(prompt, max_new_tokens=10) + generated_answer = await asyncio.gather(task) + assert "rome" in generated_answer[0].lower() + + +@pytest.mark.asyncio +async def test_async_start_answer_iterator_streamer(lamacpp_client): + prompt = "What is the capital city of Italy?" + task = lamacpp_client.async_start_answer_iterator_streamer(prompt, max_new_tokens=10) + stream = await asyncio.gather(task) + generated_answer = "" + for output in stream[0]: + generated_answer += output["choices"][0]["text"] + assert "rome" in generated_answer.lower() + + +def test_parse_token(lamacpp_client): + prompt = "What is the capital city of Italy?" + stream = lamacpp_client.start_answer_iterator_streamer(prompt, max_new_tokens=10) + generated_answer = "" + for output in stream: + generated_answer += lamacpp_client.parse_token(output) + assert "rome" in generated_answer.lower()