Model features: token usage #109

Status: Open. Wants to merge 3 commits into base branch main.
53 changes: 44 additions & 9 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -268,6 +268,13 @@ class ChatNVIDIA(BaseChatModel):
top_p: Optional[float] = Field(None, description="Top-p for distribution sampling")
seed: Optional[int] = Field(None, description="The seed for deterministic results")
stop: Optional[Sequence[str]] = Field(None, description="Stop words (cased)")
stream_usage: bool = Field(
False,

Collaborator comment: the langchain user's expectation is that usage details are returned by default.

description="""Whether to include usage metadata in streaming output.
If True, additional message chunks will be generated during the
stream including usage metadata.
""",
)

def __init__(self, **kwargs: Any):
"""
@@ -381,18 +388,38 @@ def _generate(
response = self._client.get_req(payload=payload, extra_headers=extra_headers)
responses, _ = self._client.postprocess(response)
self._set_callback_out(responses, run_manager)
parsed_response = self._custom_postprocess(responses, streaming=False)
parsed_response = self._custom_postprocess(
responses, streaming=False, stream_usage=False

Collaborator comment: how about relying on streaming=False to imply stream_usage=False?

Author reply: what if it is streaming but the user does not want stream usage details?

)
# for pre 0.2 compatibility w/ ChatMessage
# ChatMessage had a role property that was not present in AIMessage
parsed_response.update({"role": "assistant"})
generation = ChatGeneration(message=AIMessage(**parsed_response))
return ChatResult(generations=[generation], llm_output=responses)

def _should_stream_usage(
self, stream_usage: Optional[bool] = None, **kwargs: Any
) -> bool:
"""Determine whether to include usage metadata in streaming output.
For backwards compatibility, we check for `stream_options` passed
explicitly to kwargs or in the model_kwargs and override self.stream_usage.
"""
stream_usage_sources = [ # order of preference
stream_usage,
kwargs.get("stream_options", {}).get("include_usage"),
self.stream_usage,
]
for source in stream_usage_sources:
if isinstance(source, bool):
return source
return self.stream_usage

def _stream(
self,
messages: List[BaseMessage],
stop: Optional[Sequence[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream_usage: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Allows streaming to model!"""
@@ -401,11 +428,16 @@ def _stream(
for message in [convert_message_to_dict(message) for message in messages]
]
inputs, extra_headers = _process_for_vlm(inputs, self._client.model)

# check stream_usage, stream_options params to
# include token_usage for streaming.
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
kwargs["stream_options"] = {"include_usage": stream_usage}

payload = self._get_payload(
inputs=inputs,
stop=stop,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
# todo: get vlm endpoints fixed and remove this
@@ -420,7 +452,9 @@ def _stream(
payload=payload, extra_headers=extra_headers
):
self._set_callback_out(response, run_manager)
parsed_response = self._custom_postprocess(response, streaming=True)
parsed_response = self._custom_postprocess(
response, streaming=True, stream_usage=stream_usage
)
# for pre 0.2 compatibility w/ ChatMessageChunk
# ChatMessageChunk had a role property that was not
# present in AIMessageChunk
@@ -444,7 +478,7 @@ def _set_callback_out(
cb.llm_output = result

def _custom_postprocess(
self, msg: dict, streaming: bool = False
self, msg: dict, streaming: bool = False, stream_usage: bool = False
) -> dict: # todo: remove
kw_left = msg.copy()
out_dict = {
@@ -456,11 +490,12 @@ def _custom_postprocess(
"response_metadata": {},
}
if token_usage := kw_left.pop("token_usage", None):
out_dict["usage_metadata"] = {
"input_tokens": token_usage.get("prompt_tokens", 0),
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
if (streaming and stream_usage) or not streaming:

Collaborator comment: how about always returning token usage if it's in the response?

Author reply: that makes sense, since the parameter does not matter at this point.

out_dict["usage_metadata"] = {
"input_tokens": token_usage.get("prompt_tokens", 0),
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
# "tool_calls" is set for invoke and stream responses
if tool_calls := kw_left.pop("tool_calls", None):
assert isinstance(
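
For readers scanning the diff, the precedence that the new _should_stream_usage helper implements is: an explicit per-call stream_usage argument wins, then an explicit stream_options["include_usage"], then the instance-level stream_usage field. A minimal, standalone sketch of that resolution follows; the helper and default below are illustrative only and are not part of the PR.

from typing import Any, Optional

DEFAULT_STREAM_USAGE = False  # mirrors the new instance-level field added in this PR

def should_stream_usage(stream_usage: Optional[bool] = None, **kwargs: Any) -> bool:
    # Order of preference: per-call flag, then stream_options, then the instance default.
    sources = [
        stream_usage,
        kwargs.get("stream_options", {}).get("include_usage"),
        DEFAULT_STREAM_USAGE,
    ]
    for source in sources:
        if isinstance(source, bool):
            return source
    return DEFAULT_STREAM_USAGE

# Illustrative resolutions:
assert should_stream_usage(True) is True
assert should_stream_usage(None, stream_options={"include_usage": True}) is True
assert should_stream_usage(False, stream_options={"include_usage": True}) is False
assert should_stream_usage(None) is False  # falls back to the (False) default
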
130 changes: 129 additions & 1 deletion libs/ai-endpoints/tests/integration_tests/test_chat_models.py
@@ -1,13 +1,15 @@
"""Test ChatNVIDIA chat model."""

from typing import List
from typing import AsyncIterator, Iterator, List, Optional

import pytest
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
)
@@ -441,3 +443,129 @@ def test_stop(
assert isinstance(token.content, str)
result += f"{token.content}|"
assert all(target not in result for target in targets)


def test_ai_endpoints_stream_token_usage(chat_model: str, mode: dict) -> None:
"""Test streaming tokens from NVIDIA Endpoints."""

def _test_stream(stream: Iterator, expect_usage: bool) -> None:
full: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
for chunk in stream:
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
chunks_with_token_counts += 1
if chunk.response_metadata:
chunks_with_response_metadata += 1
assert isinstance(full, AIMessageChunk)
if chunks_with_response_metadata != 1:
raise AssertionError(
"Expected exactly one chunk with metadata. "
"AIMessageChunk aggregation can add these metadata. Check that "
"this is behaving properly."
)
assert full.response_metadata.get("finish_reason") is not None
assert full.response_metadata.get("model_name") is not None
if expect_usage:
if chunks_with_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
assert full.usage_metadata["total_tokens"] > 0
else:
assert chunks_with_token_counts == 0
assert full.usage_metadata is None

llm = ChatNVIDIA(model=chat_model, temperature=0, max_tokens=5, **mode)
_test_stream(llm.stream("Hello"), expect_usage=False)
_test_stream(
llm.stream("Hello", stream_options={"include_usage": True}), expect_usage=True
)
_test_stream(llm.stream("Hello", stream_usage=True), expect_usage=True)
llm = ChatNVIDIA(
model=chat_model,
temperature=0,
max_tokens=5,
model_kwargs={"stream_options": {"include_usage": True}},
**mode,
)
_test_stream(
llm.stream("Hello", stream_options={"include_usage": False}),
expect_usage=False,
)
llm = ChatNVIDIA(
model=chat_model, temperature=0, max_tokens=5, stream_usage=True, **mode
)
_test_stream(llm.stream("Hello"), expect_usage=True)
_test_stream(llm.stream("Hello", stream_usage=False), expect_usage=False)


async def test_ai_endpoints_astream_token_usage(chat_model: str, mode: dict) -> None:
"""Test async streaming tokens from NVIDIA Endpoints."""

async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
full: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
chunks_with_response_metadata = 0
async for chunk in stream:
assert isinstance(chunk.content, str)
full = chunk if full is None else full + chunk
assert isinstance(chunk, AIMessageChunk)
if chunk.usage_metadata is not None:
chunks_with_token_counts += 1
if chunk.response_metadata:
chunks_with_response_metadata += 1
assert isinstance(full, AIMessageChunk)
if chunks_with_response_metadata != 1:
raise AssertionError(
"Expected exactly one chunk with metadata. "
"AIMessageChunk aggregation can add these metadata. Check that "
"this is behaving properly."
)
assert full.response_metadata.get("finish_reason") is not None
assert full.response_metadata.get("model_name") is not None
if expect_usage:
if chunks_with_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
assert full.usage_metadata["total_tokens"] > 0
else:
assert chunks_with_token_counts == 0
assert full.usage_metadata is None

llm = ChatNVIDIA(model=chat_model, temperature=0, max_tokens=5, **mode)
await _test_stream(llm.astream("Hello"), expect_usage=False)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": True}), expect_usage=True
)
await _test_stream(llm.astream("Hello", stream_usage=True), expect_usage=True)
llm = ChatNVIDIA(
model=chat_model,
temperature=0,
max_tokens=5,
model_kwargs={"stream_options": {"include_usage": True}},
**mode,
)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": False}),
expect_usage=False,
)
llm = ChatNVIDIA(
model=chat_model, temperature=0, max_tokens=5, stream_usage=True, **mode
)
await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)
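
As a caller-facing summary of what these integration tests exercise, here is a minimal sketch of opting into streaming token usage. The model name and environment setup are assumptions, not part of the diff; a valid NVIDIA API key is required.

from langchain_nvidia_ai_endpoints import ChatNVIDIA

# Assumed model id, used here only for illustration.
llm = ChatNVIDIA(model="meta/llama3-8b-instruct", max_tokens=5)

# Default: no usage metadata is attached to streamed chunks.
_ = list(llm.stream("Hello"))

# Per-call opt-in: the aggregated message carries token counts.
full = None
for chunk in llm.stream("Hello", stream_usage=True):
    full = chunk if full is None else full + chunk
print(full.usage_metadata)  # e.g. {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}

# Equivalent opt-ins: stream_options={"include_usage": True} per call, or
# stream_usage=True / model_kwargs={"stream_options": {"include_usage": True}}
# on the constructor; a per-call flag overrides the constructor setting.
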
27 changes: 26 additions & 1 deletion libs/ai-endpoints/tests/unit_tests/test_bind_tools.py
@@ -294,14 +294,39 @@ def test_stream_usage_metadata(
)

llm = ChatNVIDIA(api_key="BOGUS")
response = reduce(add, llm.stream("IGNROED"))
response = reduce(add, llm.stream("IGNROED", stream_usage=True))

Collaborator comment: the default should be True.

assert isinstance(response, AIMessage)
assert response.usage_metadata is not None
assert response.usage_metadata["input_tokens"] == 76
assert response.usage_metadata["output_tokens"] == 29
assert response.usage_metadata["total_tokens"] == 105


def test_stream_usage_metadata_false(
requests_mock: requests_mock.Mocker,
) -> None:
requests_mock.post(
"https://integrate.api.nvidia.com/v1/chat/completions",
text="\n\n".join(
[
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}', # noqa: E501
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"ID1","type":"function","function":{"name":"magic_function","arguments":""}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"in"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"put\":"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" 3}"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null}', # noqa: E501
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"choices":[]}', # noqa: E501
r"data: [DONE]",
]
),
)

llm = ChatNVIDIA(api_key="BOGUS")
response = reduce(add, llm.stream("IGNROED"))
assert isinstance(response, AIMessage)
assert response.usage_metadata is None


@pytest.mark.parametrize(
"strict",
[False, None, "BOGUS"],