Skip to content

Commit

Permalink
add ainvoke / astream basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Aug 28, 2024
1 parent a785670 commit a42e389
Showing 1 changed file with 23 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ def stream(llm: NVIDIA, prompt: str, **kwargs: Any) -> Tuple[str, int]:
return response, count


async def ainvoke(llm: NVIDIA, prompt: str, **kwargs: Any) -> Tuple[str, int]:
return await llm.ainvoke(prompt, **kwargs), 1


async def astream(llm: NVIDIA, prompt: str, **kwargs: Any) -> Tuple[str, int]:
response = ""
count = 0
async for chunk in llm.astream(prompt, **kwargs):
response += chunk
count += 1
return response, count


@pytest.mark.parametrize(
"func, count", [(invoke, 0), (stream, 1)], ids=["invoke", "stream"]
)
Expand All @@ -84,6 +97,16 @@ def test_basic(completions_model: str, mode: dict, func: Callable, count: int) -
assert cnt > count, "Should have received more chunks"


@pytest.mark.parametrize(
"func, count", [(ainvoke, 0), (astream, 1)], ids=["ainvoke", "astream"]
)
async def test_abasic(completions_model: str, mode: dict, func: Callable, count: int) -> None:

Check failure on line 103 in libs/ai-endpoints/tests/integration_tests/test_completions_models.py

View workflow job for this annotation

GitHub Actions / cd libs/ai-endpoints / make lint #3.8

Ruff (E501)

tests/integration_tests/test_completions_models.py:103:89: E501 Line too long (94 > 88)

Check failure on line 103 in libs/ai-endpoints/tests/integration_tests/test_completions_models.py

View workflow job for this annotation

GitHub Actions / cd libs/ai-endpoints / make lint #3.11

Ruff (E501)

tests/integration_tests/test_completions_models.py:103:89: E501 Line too long (94 > 88)
llm = NVIDIA(model=completions_model, **mode)
response, cnt = await func(llm, "Hello, my name is")
assert isinstance(response, str)
assert cnt > count, "Should have received more chunks"


@pytest.mark.parametrize(
"param, value",
[
Expand Down

0 comments on commit a42e389

Please sign in to comment.