From 2e1045d8eee089ce8d2c53efe90a3e415e3dfbc4 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Wed, 6 Nov 2024 17:31:47 -0800 Subject: [PATCH] First test for AsyncModel --- llm/models.py | 2 +- setup.py | 1 + tests/conftest.py | 32 ++++++++++++++++++++++++++++++-- tests/test_async.py | 10 ++++++++++ 4 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 tests/test_async.py diff --git a/llm/models.py b/llm/models.py index 25a016b7..1e6c165e 100644 --- a/llm/models.py +++ b/llm/models.py @@ -373,7 +373,7 @@ async def __aiter__(self) -> AsyncIterator[str]: yield chunk return - async for chunk in await self.model.execute( + async for chunk in self.model.execute( self.prompt, stream=self.stream, response=self, diff --git a/setup.py b/setup.py index 6f500815..24b5acd2 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ def get_long_description(): "pytest", "numpy", "pytest-httpx>=0.33.0", + "pytest-asyncio", "cogapp", "mypy>=1.10.0", "black>=24.1.0", diff --git a/tests/conftest.py b/tests/conftest.py index bcdb8854..7eb3dd56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -75,6 +75,29 @@ def execute(self, prompt, stream, response, conversation): break +class AsyncMockModel(llm.AsyncModel): + model_id = "mock" + + def __init__(self): + self.history = [] + self._queue = [] + + def enqueue(self, messages): + assert isinstance(messages, list) + self._queue.append(messages) + + async def execute(self, prompt, stream, response, conversation): + self.history.append((prompt, stream, response, conversation)) + while True: + try: + messages = self._queue.pop(0) + for message in messages: + yield message + break + except IndexError: + break + + class EmbedDemo(llm.EmbeddingModel): model_id = "embed-demo" batch_size = 10 @@ -118,8 +141,13 @@ def mock_model(): return MockModel() +@pytest.fixture +def async_mock_model(): + return AsyncMockModel() + + @pytest.fixture(autouse=True) -def register_embed_demo_model(embed_demo, mock_model): +def register_embed_demo_model(embed_demo, mock_model, async_mock_model): class MockModelsPlugin: __name__ = "MockModelsPlugin" @@ -131,7 +159,7 @@ def register_embedding_models(self, register): @llm.hookimpl def register_models(self, register): - register(mock_model) + register(mock_model, async_model=async_mock_model) pm.register(MockModelsPlugin(), name="undo-mock-models-plugin") try: diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 00000000..c7d3f9d9 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,10 @@ +import pytest + + +@pytest.mark.asyncio +async def test_async_model(async_mock_model): + gathered = [] + async_mock_model.enqueue(["hello world"]) + async for chunk in async_mock_model.prompt("hello"): + gathered.append(chunk) + assert gathered == ["hello world"]