Skip to content

Commit

Permalink
First test for AsyncModel
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Nov 7, 2024
1 parent 91732d0 commit 2e1045d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 3 deletions.
2 changes: 1 addition & 1 deletion llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ async def __aiter__(self) -> AsyncIterator[str]:
yield chunk
return

async for chunk in await self.model.execute(
async for chunk in self.model.execute(
self.prompt,
stream=self.stream,
response=self,
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def get_long_description():
"pytest",
"numpy",
"pytest-httpx>=0.33.0",
"pytest-asyncio",
"cogapp",
"mypy>=1.10.0",
"black>=24.1.0",
Expand Down
32 changes: 30 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,29 @@ def execute(self, prompt, stream, response, conversation):
break


class AsyncMockModel(llm.AsyncModel):
model_id = "mock"

def __init__(self):
self.history = []
self._queue = []

def enqueue(self, messages):
assert isinstance(messages, list)
self._queue.append(messages)

async def execute(self, prompt, stream, response, conversation):
self.history.append((prompt, stream, response, conversation))
while True:
try:
messages = self._queue.pop(0)
for message in messages:
yield message
break
except IndexError:
break


class EmbedDemo(llm.EmbeddingModel):
model_id = "embed-demo"
batch_size = 10
Expand Down Expand Up @@ -118,8 +141,13 @@ def mock_model():
return MockModel()


@pytest.fixture
def async_mock_model():
return AsyncMockModel()


@pytest.fixture(autouse=True)
def register_embed_demo_model(embed_demo, mock_model):
def register_embed_demo_model(embed_demo, mock_model, async_mock_model):
class MockModelsPlugin:
__name__ = "MockModelsPlugin"

Expand All @@ -131,7 +159,7 @@ def register_embedding_models(self, register):

@llm.hookimpl
def register_models(self, register):
register(mock_model)
register(mock_model, async_model=async_mock_model)

pm.register(MockModelsPlugin(), name="undo-mock-models-plugin")
try:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest


@pytest.mark.asyncio
async def test_async_model(async_mock_model):
gathered = []
async_mock_model.enqueue(["hello world"])
async for chunk in async_mock_model.prompt("hello"):
gathered.append(chunk)
assert gathered == ["hello world"]

0 comments on commit 2e1045d

Please sign in to comment.