Model features: token usage #109
base: main
Changes from 2 commits
@@ -268,6 +268,13 @@ class ChatNVIDIA(BaseChatModel):
    top_p: Optional[float] = Field(None, description="Top-p for distribution sampling")
    seed: Optional[int] = Field(None, description="The seed for deterministic results")
    stop: Optional[Sequence[str]] = Field(None, description="Stop words (cased)")
    stream_usage: bool = Field(
        False,
        description="""Whether to include usage metadata in streaming output.
        If True, additional message chunks will be generated during the
        stream including usage metadata.
        """,
    )

    def __init__(self, **kwargs: Any):
        """
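For context, a minimal usage sketch (not part of the diff): assuming the field ships as proposed, a caller could enable streamed usage metadata per instance. The model name below is only an example.

from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Assumption: stream_usage is off by default and enabled per instance,
# as the field definition above proposes.
llm = ChatNVIDIA(model="meta/llama3-8b-instruct", stream_usage=True)
for chunk in llm.stream("Hello"):
    # Most chunks carry only content; with usage enabled, one chunk in the
    # stream is expected to carry usage_metadata (input/output/total tokens).
    if chunk.usage_metadata is not None:
        print(chunk.usage_metadata)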
@@ -381,18 +388,38 @@ def _generate(
        response = self._client.get_req(payload=payload, extra_headers=extra_headers)
        responses, _ = self._client.postprocess(response)
        self._set_callback_out(responses, run_manager)
        parsed_response = self._custom_postprocess(responses, streaming=False)
        parsed_response = self._custom_postprocess(
            responses, streaming=False, stream_usage=False
        )

Review comment: How about relying on streaming=False to imply stream_usage=False?
Reply: What if it is streaming but the caller does not want stream usage details?

        # for pre 0.2 compatibility w/ ChatMessage
        # ChatMessage had a role property that was not present in AIMessage
        parsed_response.update({"role": "assistant"})
        generation = ChatGeneration(message=AIMessage(**parsed_response))
        return ChatResult(generations=[generation], llm_output=responses)

    def _should_stream_usage(
        self, stream_usage: Optional[bool] = None, **kwargs: Any
    ) -> bool:
        """Determine whether to include usage metadata in streaming output.
        For backwards compatibility, we check for `stream_options` passed
        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
        """
        stream_usage_sources = [  # order of preference
            stream_usage,
            kwargs.get("stream_options", {}).get("include_usage"),
            self.stream_usage,
        ]
        for source in stream_usage_sources:
            if isinstance(source, bool):
                return source
        return self.stream_usage

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream_usage: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Allows streaming to model!"""
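To make the precedence order concrete, here is a standalone sketch of the same resolution logic; the function and argument names are illustrative, not part of the PR.

from typing import Any, Optional


def resolve_stream_usage(
    stream_usage: Optional[bool], instance_default: bool, **kwargs: Any
) -> bool:
    # Mirrors _should_stream_usage above: an explicit stream_usage argument wins,
    # then stream_options["include_usage"], then the instance-level default.
    for source in (
        stream_usage,
        kwargs.get("stream_options", {}).get("include_usage"),
        instance_default,
    ):
        if isinstance(source, bool):
            return source
    return instance_default


assert resolve_stream_usage(None, False, stream_options={"include_usage": True}) is True
assert resolve_stream_usage(False, True, stream_options={"include_usage": True}) is False
assert resolve_stream_usage(None, True) is True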
@@ -401,11 +428,16 @@ def _stream(
            for message in [convert_message_to_dict(message) for message in messages]
        ]
        inputs, extra_headers = _process_for_vlm(inputs, self._client.model)

        # check stream_usage, stream_options params to
        # include token_usage for streaming.
        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
        kwargs["stream_options"] = {"include_usage": stream_usage}

        payload = self._get_payload(
            inputs=inputs,
            stop=stop,
            stream=True,
            stream_options={"include_usage": True},
            **kwargs,
        )
        # todo: get vlm endpoints fixed and remove this
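For illustration, the net effect of this hunk on the request body: stream_options is no longer hard-coded to {"include_usage": True} but carries the resolved flag. The fragment below is an assumption about the resulting payload shape (OpenAI-style chat completions fields), not code from the PR.

# Assumed values: with stream_usage resolved to True, the payload carries
stream_usage = True
payload_fragment = {
    "stream": True,
    "stream_options": {"include_usage": stream_usage},
}
# With stream_usage resolved to False, include_usage is sent as False instead.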
@@ -420,7 +452,9 @@ def _stream(
            payload=payload, extra_headers=extra_headers
        ):
            self._set_callback_out(response, run_manager)
            parsed_response = self._custom_postprocess(response, streaming=True)
            parsed_response = self._custom_postprocess(
                response, streaming=True, stream_usage=stream_usage
            )
            # for pre 0.2 compatibility w/ ChatMessageChunk
            # ChatMessageChunk had a role property that was not
            # present in AIMessageChunk
@@ -444,7 +478,7 @@ def _set_callback_out(
                cb.llm_output = result

    def _custom_postprocess(
        self, msg: dict, streaming: bool = False
        self, msg: dict, streaming: bool = False, stream_usage: bool = False
    ) -> dict:  # todo: remove
        kw_left = msg.copy()
        out_dict = {
@@ -456,11 +490,12 @@ def _custom_postprocess(
            "response_metadata": {},
        }
        if token_usage := kw_left.pop("token_usage", None):
            out_dict["usage_metadata"] = {
                "input_tokens": token_usage.get("prompt_tokens", 0),
                "output_tokens": token_usage.get("completion_tokens", 0),
                "total_tokens": token_usage.get("total_tokens", 0),
            }
            if (streaming and stream_usage) or not streaming:

Review comment: How about always returning token usage if it is in the response?
Reply: That makes sense, because the parameter does not matter at this point.

                out_dict["usage_metadata"] = {
                    "input_tokens": token_usage.get("prompt_tokens", 0),
                    "output_tokens": token_usage.get("completion_tokens", 0),
                    "total_tokens": token_usage.get("total_tokens", 0),
                }
        # "tool_calls" is set for invoke and stream responses
        if tool_calls := kw_left.pop("tool_calls", None):
            assert isinstance(
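As a quick illustration of the mapping above, with the token counts reused from the unit test later in this diff:

# token_usage is the OpenAI-style usage block returned by the endpoint;
# usage_metadata is the LangChain-side shape _custom_postprocess produces.
token_usage = {"prompt_tokens": 76, "completion_tokens": 29, "total_tokens": 105}
usage_metadata = {
    "input_tokens": token_usage.get("prompt_tokens", 0),
    "output_tokens": token_usage.get("completion_tokens", 0),
    "total_tokens": token_usage.get("total_tokens", 0),
}
assert usage_metadata == {"input_tokens": 76, "output_tokens": 29, "total_tokens": 105}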
@@ -1,13 +1,15 @@
"""Test ChatNVIDIA chat model."""

from typing import List
from typing import AsyncIterator, Iterator, List, Optional

import pytest
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
    SystemMessage,
)
@@ -441,3 +443,121 @@ def test_stop(
        assert isinstance(token.content, str)
        result += f"{token.content}|"
    assert all(target not in result for target in targets)


def test_ai_endpoints_stream_token_usage() -> None:

Review comment: All integration tests should accept and use a model and mode.

    """Test streaming tokens from NVIDIA Endpoints."""

    def _test_stream(stream: Iterator, expect_usage: bool) -> None:
        full: Optional[BaseMessageChunk] = None
        chunks_with_token_counts = 0
        chunks_with_response_metadata = 0
        for chunk in stream:
            assert isinstance(chunk.content, str)
            full = chunk if full is None else full + chunk
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                chunks_with_token_counts += 1
            if chunk.response_metadata:
                chunks_with_response_metadata += 1
        assert isinstance(full, AIMessageChunk)
        if chunks_with_response_metadata != 1:
            raise AssertionError(
                "Expected exactly one chunk with metadata. "
                "AIMessageChunk aggregation can add these metadata. Check that "
                "this is behaving properly."
            )
        assert full.response_metadata.get("finish_reason") is not None
        assert full.response_metadata.get("model_name") is not None
        if expect_usage:
            if chunks_with_token_counts != 1:
                raise AssertionError(
                    "Expected exactly one chunk with token counts. "
                    "AIMessageChunk aggregation adds counts. Check that "
                    "this is behaving properly."
                )
            assert full.usage_metadata is not None
            assert full.usage_metadata["input_tokens"] > 0
            assert full.usage_metadata["output_tokens"] > 0
            assert full.usage_metadata["total_tokens"] > 0
        else:
            assert chunks_with_token_counts == 0
            assert full.usage_metadata is None

    llm = ChatNVIDIA(temperature=0, max_tokens=5)
    _test_stream(llm.stream("Hello"), expect_usage=False)
    _test_stream(
        llm.stream("Hello", stream_options={"include_usage": True}), expect_usage=True
    )
    _test_stream(llm.stream("Hello", stream_usage=True), expect_usage=True)
    llm = ChatNVIDIA(
        temperature=0,
        max_tokens=5,
        model_kwargs={"stream_options": {"include_usage": True}},
    )
    _test_stream(
        llm.stream("Hello", stream_options={"include_usage": False}),
        expect_usage=False,
    )
    llm = ChatNVIDIA(temperature=0, max_tokens=5, stream_usage=True)
    _test_stream(llm.stream("Hello"), expect_usage=True)
    _test_stream(llm.stream("Hello", stream_usage=False), expect_usage=False)


async def test_ai_endpoints_astream_token_usage() -> None:

Review comment: All integration tests should accept and use a model and mode.

    """Test async streaming tokens from NVIDIA Endpoints."""

    async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
        full: Optional[BaseMessageChunk] = None
        chunks_with_token_counts = 0
        chunks_with_response_metadata = 0
        async for chunk in stream:
            assert isinstance(chunk.content, str)
            full = chunk if full is None else full + chunk
            assert isinstance(chunk, AIMessageChunk)
            if chunk.usage_metadata is not None:
                chunks_with_token_counts += 1
            if chunk.response_metadata:
                chunks_with_response_metadata += 1
        assert isinstance(full, AIMessageChunk)
        if chunks_with_response_metadata != 1:
            raise AssertionError(
                "Expected exactly one chunk with metadata. "
                "AIMessageChunk aggregation can add these metadata. Check that "
                "this is behaving properly."
            )
        assert full.response_metadata.get("finish_reason") is not None
        assert full.response_metadata.get("model_name") is not None
        if expect_usage:
            if chunks_with_token_counts != 1:
                raise AssertionError(
                    "Expected exactly one chunk with token counts. "
                    "AIMessageChunk aggregation adds counts. Check that "
                    "this is behaving properly."
                )
            assert full.usage_metadata is not None
            assert full.usage_metadata["input_tokens"] > 0
            assert full.usage_metadata["output_tokens"] > 0
            assert full.usage_metadata["total_tokens"] > 0
        else:
            assert chunks_with_token_counts == 0
            assert full.usage_metadata is None

    llm = ChatNVIDIA(temperature=0, max_tokens=5)
    await _test_stream(llm.astream("Hello"), expect_usage=False)
    await _test_stream(
        llm.astream("Hello", stream_options={"include_usage": True}), expect_usage=True
    )
    await _test_stream(llm.astream("Hello", stream_usage=True), expect_usage=True)
    llm = ChatNVIDIA(
        temperature=0,
        max_tokens=5,
        model_kwargs={"stream_options": {"include_usage": True}},
    )
    await _test_stream(
        llm.astream("Hello", stream_options={"include_usage": False}),
        expect_usage=False,
    )
    llm = ChatNVIDIA(temperature=0, max_tokens=5, stream_usage=True)
    await _test_stream(llm.astream("Hello"), expect_usage=True)
    await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)
@@ -294,14 +294,39 @@ def test_stream_usage_metadata(
    )

    llm = ChatNVIDIA(api_key="BOGUS")
    response = reduce(add, llm.stream("IGNROED"))
    response = reduce(add, llm.stream("IGNROED", stream_usage=True))

Review comment: The default should be True.

    assert isinstance(response, AIMessage)
    assert response.usage_metadata is not None
    assert response.usage_metadata["input_tokens"] == 76
    assert response.usage_metadata["output_tokens"] == 29
    assert response.usage_metadata["total_tokens"] == 105


def test_stream_usage_metadata_false(
    requests_mock: requests_mock.Mocker,
) -> None:
    requests_mock.post(
        "https://integrate.api.nvidia.com/v1/chat/completions",
        text="\n\n".join(
            [
                r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
                r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"ID1","type":"function","function":{"name":"magic_function","arguments":""}}]},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
                r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"in"}}]},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
                r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"put\":"}}]},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
                r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" 3}"}}]},"logprobs":null,"finish_reason":null}]}',  # noqa: E501
                r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null}',  # noqa: E501
                r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"choices":[]}',  # noqa: E501
                r"data: [DONE]",
            ]
        ),
    )

    llm = ChatNVIDIA(api_key="BOGUS")
    response = reduce(add, llm.stream("IGNROED"))
    assert isinstance(response, AIMessage)
    assert response.usage_metadata is None


@pytest.mark.parametrize(
    "strict",
    [False, None, "BOGUS"],
Review comment: The LangChain user's expectation is that usage details are returned by default.