From a42e389d6c679faff352b0241d1232e8bca7e4a0 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 28 Aug 2024 06:53:51 -0400 Subject: [PATCH] add ainvoke / astream basic tests --- .../test_completions_models.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/libs/ai-endpoints/tests/integration_tests/test_completions_models.py b/libs/ai-endpoints/tests/integration_tests/test_completions_models.py index f7d1308a..3fb4e724 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_completions_models.py +++ b/libs/ai-endpoints/tests/integration_tests/test_completions_models.py @@ -74,6 +74,19 @@ def stream(llm: NVIDIA, prompt: str, **kwargs: Any) -> Tuple[str, int]: return response, count +async def ainvoke(llm: NVIDIA, prompt: str, **kwargs: Any) -> Tuple[str, int]: + return await llm.ainvoke(prompt, **kwargs), 1 + + +async def astream(llm: NVIDIA, prompt: str, **kwargs: Any) -> Tuple[str, int]: + response = "" + count = 0 + async for chunk in llm.astream(prompt, **kwargs): + response += chunk + count += 1 + return response, count + + @pytest.mark.parametrize( "func, count", [(invoke, 0), (stream, 1)], ids=["invoke", "stream"] ) @@ -84,6 +97,16 @@ def test_basic(completions_model: str, mode: dict, func: Callable, count: int) - assert cnt > count, "Should have received more chunks" +@pytest.mark.parametrize( + "func, count", [(ainvoke, 0), (astream, 1)], ids=["ainvoke", "astream"] +) +async def test_abasic(completions_model: str, mode: dict, func: Callable, count: int) -> None: + llm = NVIDIA(model=completions_model, **mode) + response, cnt = await func(llm, "Hello, my name is") + assert isinstance(response, str) + assert cnt > count, "Should have received more chunks" + + @pytest.mark.parametrize( "param, value", [