Skip to content

Commit

Permalink
test usage metadata accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Aug 29, 2024
1 parent a9783a7 commit 9c9c0fd
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
51 changes: 51 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/test_bind_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import warnings
from functools import reduce
from operator import add
from typing import Any, Callable, List, Literal, Optional, Union

import pytest
Expand Down Expand Up @@ -736,3 +738,52 @@ def test_accuracy_parallel_tool_calls_easy(
tool_call1 = response.tool_calls[1]
assert tool_call1["name"] == "get_current_weather"
assert tool_call1["args"]["location"] in valid_args


@pytest.mark.xfail(reason="Server producing invalid response")
def test_stream_usage_metadata(
    tool_model: str,
    mode: dict,
) -> None:
    """
    Regression test for server-side usage accounting.

    The server was emitting usage metadata on multiple stream chunks,
    so the aggregated stream totals came out inflated. We take the
    non-streaming ``invoke`` result as the ground truth and require the
    aggregated ``stream`` totals to stay within a 25% margin of it.
    """

    @tool
    def magic(
        num: int = Field(..., description="Number to magic"),
    ) -> int:
        """Magic a number"""
        return (num**num) % num

    prompt = "What is magic(42)?"
    llm = ChatNVIDIA(model=tool_model, **mode).bind_tools(
        [magic], tool_choice="required"
    )

    # Ground truth: a single non-streaming call.
    baseline = llm.invoke(prompt)
    assert isinstance(baseline, AIMessage)
    assert baseline.usage_metadata is not None
    baseline_usage = baseline.usage_metadata
    # Internal consistency of the baseline numbers.
    assert (
        baseline_usage["input_tokens"] + baseline_usage["output_tokens"]
        == baseline_usage["total_tokens"]
    )

    # Aggregate every streamed chunk into one message.
    streamed = reduce(add, llm.stream(prompt))
    assert isinstance(streamed, AIMessage)
    assert streamed.usage_metadata is not None
    streamed_usage = streamed.usage_metadata
    # Internal consistency of the streamed numbers.
    assert (
        streamed_usage["input_tokens"] + streamed_usage["output_tokens"]
        == streamed_usage["total_tokens"]
    )

    tolerance = 1.25  # allow for streaming to be 25% higher than invoke
    for key in ("input_tokens", "output_tokens", "total_tokens"):
        assert streamed_usage[key] < baseline_usage[key] * tolerance
28 changes: 28 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_bind_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,31 @@ def test_regression_ai_null_content(
assistant.content = None # type: ignore
llm.invoke([assistant])
llm.stream([assistant])


def test_stream_usage_metadata(
    requests_mock: requests_mock.Mocker,
) -> None:
    """
    Verify that aggregating a mocked SSE stream yields the usage totals
    reported in the final (choice-less) usage chunk, and only those —
    earlier chunks all carry ``"usage": null`` and must not contribute.
    """
    # Canned SSE chunks: role, tool-call header, three argument fragments,
    # a finish chunk, then the lone usage-bearing chunk and the terminator.
    sse_chunks = [
        r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
        r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"ID1","type":"function","function":{"name":"magic_function","arguments":""}}]},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
        r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"in"}}]},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
        r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"put\":"}}]},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
        r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" 3}"}}]},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
        r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null}',  # noqa: E501
        r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"choices":[],"usage":{"prompt_tokens":76,"completion_tokens":29,"total_tokens":105}}',  # noqa: E501
        r"data: [DONE]",
    ]
    requests_mock.post(
        "https://integrate.api.nvidia.com/v1/chat/completions",
        text="\n\n".join(sse_chunks),
    )

    llm = ChatNVIDIA(api_key="BOGUS")
    # Fold all streamed chunks into a single aggregated message.
    aggregated = reduce(add, llm.stream("IGNROED"))
    assert isinstance(aggregated, AIMessage)
    usage = aggregated.usage_metadata
    assert usage is not None
    # Totals must match the single usage chunk exactly — no double counting.
    expected = {"input_tokens": 76, "output_tokens": 29, "total_tokens": 105}
    for field, count in expected.items():
        assert usage[field] == count

0 comments on commit 9c9c0fd

Please sign in to comment.