From 0cf48732275d7e2bbb9892971d3da25a10dd89c2 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Wed, 31 Jul 2024 06:25:39 -0400
Subject: [PATCH] update chat stream test to a prompt that should generate
 multiple chunks

---
 .../tests/integration_tests/test_chat_models.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/libs/ai-endpoints/tests/integration_tests/test_chat_models.py b/libs/ai-endpoints/tests/integration_tests/test_chat_models.py
index 6ed7c4ae..f22df3f3 100644
--- a/libs/ai-endpoints/tests/integration_tests/test_chat_models.py
+++ b/libs/ai-endpoints/tests/integration_tests/test_chat_models.py
@@ -154,14 +154,14 @@ def test_ai_endpoints_streaming(chat_model: str, mode: dict) -> None:
     """Test streaming tokens from ai endpoints."""
     llm = ChatNVIDIA(model=chat_model, max_tokens=36, **mode)
 
-    generator = llm.stream("I'm Pickle Rick")
+    generator = llm.stream("Count to 100, e.g. 1 2 3 4")
     response = next(generator)
     cnt = 0
     for chunk in generator:
         assert isinstance(chunk.content, str)
         response += chunk
         cnt += 1
-    assert cnt > 1
+    assert cnt > 1, response
     # compatibility test for ChatMessageChunk (pre 0.2)
     # assert hasattr(response, "role")
     # assert response.role == "assistant"  # does not work, role not passed through
@@ -171,11 +171,14 @@ async def test_ai_endpoints_astream(chat_model: str, mode: dict) -> None:
     """Test streaming tokens from ai endpoints."""
     llm = ChatNVIDIA(model=chat_model, max_tokens=35, **mode)
 
+    generator = llm.astream("Count to 100, e.g. 1 2 3 4")
+    response = await anext(generator)
     cnt = 0
-    async for token in llm.astream("I'm Pickle Rick"):
-        assert isinstance(token.content, str)
+    async for chunk in generator:
+        assert isinstance(chunk.content, str)
+        response += chunk
         cnt += 1
-    assert cnt > 1
+    assert cnt > 1, response
 
 
 async def test_ai_endpoints_abatch(chat_model: str, mode: dict) -> None: