-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Model features: token usage #109
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -268,6 +268,13 @@ class ChatNVIDIA(BaseChatModel): | |
top_p: Optional[float] = Field(None, description="Top-p for distribution sampling") | ||
seed: Optional[int] = Field(None, description="The seed for deterministic results") | ||
stop: Optional[Sequence[str]] = Field(None, description="Stop words (cased)") | ||
stream_usage: bool = Field( | ||
False, | ||
description="""Whether to include usage metadata in streaming output. | ||
If True, additional message chunks will be generated during the | ||
stream including usage metadata. | ||
""", | ||
) | ||
|
||
def __init__(self, **kwargs: Any): | ||
""" | ||
|
@@ -381,18 +388,38 @@ def _generate( | |
response = self._client.get_req(payload=payload, extra_headers=extra_headers) | ||
responses, _ = self._client.postprocess(response) | ||
self._set_callback_out(responses, run_manager) | ||
parsed_response = self._custom_postprocess(responses, streaming=False) | ||
parsed_response = self._custom_postprocess( | ||
responses, streaming=False, stream_usage=False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about relying on streaming=False to imply stream_usage=False? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if it is streaming but do not want stream usage details? |
||
) | ||
# for pre 0.2 compatibility w/ ChatMessage | ||
# ChatMessage had a role property that was not present in AIMessage | ||
parsed_response.update({"role": "assistant"}) | ||
generation = ChatGeneration(message=AIMessage(**parsed_response)) | ||
return ChatResult(generations=[generation], llm_output=responses) | ||
|
||
def _should_stream_usage( | ||
self, stream_usage: Optional[bool] = None, **kwargs: Any | ||
) -> bool: | ||
"""Determine whether to include usage metadata in streaming output. | ||
For backwards compatibility, we check for `stream_options` passed | ||
explicitly to kwargs or in the model_kwargs and override self.stream_usage. | ||
""" | ||
stream_usage_sources = [ # order of preference | ||
stream_usage, | ||
kwargs.get("stream_options", {}).get("include_usage"), | ||
self.stream_usage, | ||
] | ||
for source in stream_usage_sources: | ||
if isinstance(source, bool): | ||
return source | ||
return self.stream_usage | ||
|
||
def _stream( | ||
self, | ||
messages: List[BaseMessage], | ||
stop: Optional[Sequence[str]] = None, | ||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||
stream_usage: Optional[bool] = None, | ||
**kwargs: Any, | ||
) -> Iterator[ChatGenerationChunk]: | ||
"""Allows streaming to model!""" | ||
|
@@ -401,11 +428,16 @@ def _stream( | |
for message in [convert_message_to_dict(message) for message in messages] | ||
] | ||
inputs, extra_headers = _process_for_vlm(inputs, self._client.model) | ||
|
||
# check stream_usage, stream_options params to | ||
# include token_usage for streaming. | ||
stream_usage = self._should_stream_usage(stream_usage, **kwargs) | ||
kwargs["stream_options"] = {"include_usage": stream_usage} | ||
|
||
payload = self._get_payload( | ||
inputs=inputs, | ||
stop=stop, | ||
stream=True, | ||
stream_options={"include_usage": True}, | ||
**kwargs, | ||
) | ||
# todo: get vlm endpoints fixed and remove this | ||
|
@@ -420,7 +452,9 @@ def _stream( | |
payload=payload, extra_headers=extra_headers | ||
): | ||
self._set_callback_out(response, run_manager) | ||
parsed_response = self._custom_postprocess(response, streaming=True) | ||
parsed_response = self._custom_postprocess( | ||
response, streaming=True, stream_usage=stream_usage | ||
) | ||
# for pre 0.2 compatibility w/ ChatMessageChunk | ||
# ChatMessageChunk had a role property that was not | ||
# present in AIMessageChunk | ||
|
@@ -444,7 +478,7 @@ def _set_callback_out( | |
cb.llm_output = result | ||
|
||
def _custom_postprocess( | ||
self, msg: dict, streaming: bool = False | ||
self, msg: dict, streaming: bool = False, stream_usage: bool = False | ||
) -> dict: # todo: remove | ||
kw_left = msg.copy() | ||
out_dict = { | ||
|
@@ -456,11 +490,12 @@ def _custom_postprocess( | |
"response_metadata": {}, | ||
} | ||
if token_usage := kw_left.pop("token_usage", None): | ||
out_dict["usage_metadata"] = { | ||
"input_tokens": token_usage.get("prompt_tokens", 0), | ||
"output_tokens": token_usage.get("completion_tokens", 0), | ||
"total_tokens": token_usage.get("total_tokens", 0), | ||
} | ||
if (streaming and stream_usage) or not streaming: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about always returning tokens usage if its in the response? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that makes sense, because does not matter parameter at this point |
||
out_dict["usage_metadata"] = { | ||
"input_tokens": token_usage.get("prompt_tokens", 0), | ||
"output_tokens": token_usage.get("completion_tokens", 0), | ||
"total_tokens": token_usage.get("total_tokens", 0), | ||
} | ||
# "tool_calls" is set for invoke and stream responses | ||
if tool_calls := kw_left.pop("tool_calls", None): | ||
assert isinstance( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -294,14 +294,39 @@ def test_stream_usage_metadata( | |
) | ||
|
||
llm = ChatNVIDIA(api_key="BOGUS") | ||
response = reduce(add, llm.stream("IGNROED")) | ||
response = reduce(add, llm.stream("IGNROED", stream_usage=True)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the default should be True |
||
assert isinstance(response, AIMessage) | ||
assert response.usage_metadata is not None | ||
assert response.usage_metadata["input_tokens"] == 76 | ||
assert response.usage_metadata["output_tokens"] == 29 | ||
assert response.usage_metadata["total_tokens"] == 105 | ||
|
||
|
||
def test_stream_usage_metadata_false( | ||
requests_mock: requests_mock.Mocker, | ||
) -> None: | ||
requests_mock.post( | ||
"https://integrate.api.nvidia.com/v1/chat/completions", | ||
text="\n\n".join( | ||
[ | ||
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}', # noqa: E501 | ||
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"ID1","type":"function","function":{"name":"magic_function","arguments":""}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 | ||
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"in"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 | ||
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"put\":"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 | ||
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"usage":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" 3}"}}]},"logprobs":null,"finish_reason":null}]}', # noqa: E501 | ||
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null}', # noqa: E501 | ||
r'data: {"id":"ID0","object":"chat.completion.chunk","created":1234567890,"model":"BOGUS","system_fingerprint":null,"choices":[]}', # noqa: E501 | ||
r"data: [DONE]", | ||
] | ||
), | ||
) | ||
|
||
llm = ChatNVIDIA(api_key="BOGUS") | ||
response = reduce(add, llm.stream("IGNROED")) | ||
assert isinstance(response, AIMessage) | ||
assert response.usage_metadata is None | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"strict", | ||
[False, None, "BOGUS"], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the langchain user's expectation is that usage details are returned by default